我试图读取用tf.io.TFRecordWriter
编写的数据,如下所示:
import tensorflow as tf
import numpy as np
def _bytes_feature(value):
"""Returns a bytes_list from a string / byte."""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
tfrecord_filename = "test.tfrecord"
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
for i in range(4):
a = np.random.uniform(-1, 1, 5)
a = tf.convert_to_tensor(a, dtype=tf.float32)
a = tf.io.serialize_tensor(a)
feature = {
'a' : _bytes_feature(a),
}
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example_proto.SerializeToString())
然后我使用的模式如下所示: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto和https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto 和protoc一起解码。阅读程序是:
import test_pb2
parser = test_pb2.Example()
with open("test.tfrecord", "rb") as f:
parser.ParseFromString(f.read())
print(parser.feature)
使用ParseFromString
方法,我希望能够恢复执行上述程序后写入的数据,但我始终得到:
RuntimeWarning: Unexpected end-group tag: Not all data was converted
我做错什么了?你知道吗
您可以使用
tf.data.TFRecordDataset
读取/恢复数据代码:
输出:
相关问题 更多 >
编程相关推荐