在TensorF中展平数据集

2024-09-30 14:29:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图将TensorFlow中的一个数据集转换成几个单值张量。数据集当前如下所示:

[12 43 64 34 45 2 13 54] [34 65 34 67 87 12 23 43] [23 53 23 1 5] ...

改造后,它应该是这样的:

^{pr2}$

我最初的想法是在数据集中使用flat_map,然后使用reshapeunstack将每个张量转换为一个张量列表:

output_labels = self.dataset.flat_map(convert_labels)

...

def convert_labels(tensor):
    id_list = tf.unstack(tf.reshape(tensor, [-1, 1]))
    return tf.data.Dataset.from_tensors(id_list)

然而,每个张量的形状只有部分已知(即(?, 1)),这就是取消堆叠操作失败的原因。有没有办法在不显式地迭代不同张量的情况下仍然“合并”不同的张量?在


Tags: 数据idmapconvert列表labelstftensorflow
1条回答
网友
1楼 · 发布于 2024-09-30 14:29:28

您的解决方案非常接近,但是^{}使用一个返回tf.data.Dataset对象的函数,而不是张量列表。幸运的是,^{}方法正好适用于您的用例,因为它可以将张量拆分为可变数量的元素:

output_labels = self.dataset.flat_map(tf.data.Dataset.from_tensor_slices)

注意,^{}转换实现了相同的功能,并且在TensorFlow的当前主分支中有一个稍微更有效的实现(将包含在1.9版本中):

^{pr2}$

相关问题 更多 >