我有一个关于新数据集API(tensorflow 1.4)的问题。我有两个数据集,我需要创建一个组合的不平衡数据集,即。 每个批处理应该包含第一个数据集中的特定数量的元素和第二个数据集中的特定数量的元素。例如
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([1,1,1,1,1,1]
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant([2,2,2,2,2,2]))
假设批大小为4,我希望组合数据集中的一个批看起来像[1,1,1,2]。我知道如何使用zip和flat-map生成一个平衡的数据集 但我对这件事不知所措。在
提前谢谢!在
为了解决这个问题,我的解决方案是单独批处理数据集,压缩它们,然后在生成的数据集上映射一个
tf.concat
运算符。在在您的示例中,它将给出如下内容(我将第二个数据集重命名为
dataset2
):如果数据集是张量的嵌套结构,则可以使用以下版本的concat:
^{pr2}$如果所有的数据集元素(要组合的数据集的一部分)都是张量,并且只有最外层的维度(相对批大小)不同,则有效。它为数据集元素的每个组件构建一个列表,并将这些组件相互独立地连接起来。在
它处理一个层次的嵌套。如果需要更多,可以使用递归来打开嵌套嵌套,但它可能会给出一个不太干净的计算图。。。在
相关问题 更多 >
编程相关推荐