我有一个tfrecord文件,其中存储了一个数据列表,每个元素都有二维坐标和三维坐标。坐标是dtype float64的2d numpy数组。在
这些是我用来存储它们的功能。在
feature = {'train/coord2d': _floats_feature(projC),
'train/coord3d': _floats_feature(sChair)}
我把它们放在一个浮动列表中。在
^{pr2}$现在我正在尝试接收它们,以便将它们输入到我的网络中进行训练。我希望2d坐标作为输入,3d作为输出,用于训练我的网络。在
def read_and_decode(filename):
filename_queue = tf.train.string_input_producer(filename, name='queue')
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features= {'train/coord2d': tf.FixedLenFeature([], tf.float32),
'train/coord3d': tf.FixedLenFeature([], tf.float32)})
coord2d = tf.cast(features['train/coord2d'], tf.float32)
coord3d = tf.cast(features['train/coord3d'], tf.float32)
return coord2d, coord3d
with tf.Session() as sess:
filename = ["train.tfrecords"]
dataset = tf.data.TFRecordDataset(filename)
c2d, c3d = read_and_decode(filename)
print(sess.run(c2d))
print(sess.run(c3d))
这是我的代码,但我并不真正理解它,因为我从教程等,所以我试图打印出c2d和c3d,看看它们是什么格式的,但我的程序一直在运行,根本没有打印任何东西,而且从未终止。c2d和c3d是否包含数据集中每个元素的2d和3d坐标?当训练网络作为输入和输出时,它们能直接使用吗?在
我也不知道他们应该是什么格式之前,他们可以作为输入到网络。我应该把它们转换回2d numpy数组还是2d张量?万一我怎么办?总的来说,我只是非常失落,所以任何一个圭达奇将是非常有帮助的!谢谢
您使用
tf.data.TFRecordDataset(filename)
是对的,但问题是dataset
与传递给sess.run()
的张量没有连接。在下面是一个简单的示例程序,可以生成一些输出:
相关问题 更多 >
编程相关推荐