如何将tfrecord数据读入tensors/numpy数组?

2024-10-03 21:36:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个tfrecord文件,其中存储了一个数据列表,每个元素都有二维坐标和三维坐标。坐标是dtype float64的2d numpy数组。在

这些是我用来存储它们的功能。在

feature = {'train/coord2d': _floats_feature(projC),
                   'train/coord3d': _floats_feature(sChair)}

我把它们放在一个浮动列表中。在

^{pr2}$

现在我正在尝试接收它们,以便将它们输入到我的网络中进行训练。我希望2d坐标作为输入,3d作为输出,用于训练我的网络。在

def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer(filename, name='queue')
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
    serialized_example,

    features= {'train/coord2d': tf.FixedLenFeature([], tf.float32),
            'train/coord3d': tf.FixedLenFeature([], tf.float32)})

    coord2d = tf.cast(features['train/coord2d'], tf.float32)
    coord3d = tf.cast(features['train/coord3d'], tf.float32)

    return coord2d, coord3d



with tf.Session() as sess:
    filename = ["train.tfrecords"]
    dataset = tf.data.TFRecordDataset(filename)
    c2d, c3d = read_and_decode(filename)
   print(sess.run(c2d))
   print(sess.run(c3d))

这是我的代码,但我并不真正理解它,因为我从教程等,所以我试图打印出c2d和c3d,看看它们是什么格式的,但我的程序一直在运行,根本没有打印任何东西,而且从未终止。c2d和c3d是否包含数据集中每个元素的2d和3d坐标?当训练网络作为输入和输出时,它们能直接使用吗?在

我也不知道他们应该是什么格式之前,他们可以作为输入到网络。我应该把它们转换回2d numpy数组还是2d张量?万一我怎么办?总的来说,我只是非常失落,所以任何一个圭达奇将是非常有帮助的!谢谢


Tags: 网络readqueueexampletftrainfilenamefeature
1条回答
网友
1楼 · 发布于 2024-10-03 21:36:53

您使用tf.data.TFRecordDataset(filename)是对的,但问题是dataset与传递给sess.run()的张量没有连接。在

下面是一个简单的示例程序,可以生成一些输出:

def decode(serialized_example):
  # NOTE: You might get an error here, because it seems unlikely that the features
  # called 'coord2d' and 'coord3d', and produced using `ndarray.flatten()`, will
  # have a scalar shape. You might need to change the shape passed to
  # `tf.FixedLenFeature()`.
  features = tf.parse_single_example(
      serialized_example,
      features={'train/coord2d': tf.FixedLenFeature([], tf.float32),
                'train/coord3d': tf.FixedLenFeature([], tf.float32)})

  # NOTE: No need to cast these features, as they are already `tf.float32` values.
  return features['train/coord2d'], features['train/coord3d']

filename = ["train.tfrecords"]
dataset = tf.data.TFRecordDataset(filename).map(decode)
iterator = dataset.make_one_shot_iterator()
c2d, c3d = iterator.get_next()

with tf.Session() as sess:

  try:

    while True:
      print(sess.run((c2d, c3d)))

  except tf.errors.OutOfRangeError:
    # Raised when we reach the end of the file.
    pass

相关问题 更多 >