我想更改数据集中一个元素的数据类型。(元素形状=(32,28,28)-->;这是mnist数据集中的一批28×28图像)
因此,我运行了以下命令:tf.cast(dataset.take(1),tf.float32)
我的数据集的类型是tensorflow.python.data.ops.dataset_ops.PrefetchDataset
它抛出了一个错误:: Attempt to convert a value (<TakeDataset shapes: (32, 28, 28), types: tf.uint8>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.TakeDataset'>) to a Tensor.
因此,我使用以下代码从数据集中提取了一个元素:
for batch_data in dataset:
one_element = dataset
break
然后我运行了tf.cast(one_element,tf.float32)
,它就工作了
我能知道为什么会这样吗
take()返回数据集而不是张量(即使调用take(1)):https://www.tensorflow.org/api_docs/python/tf/data/Dataset
相关问题 更多 >
编程相关推荐