How to read a Protobuf written by TFRecordWriter

Posted 2024-09-27 09:34:38


I am trying to read data that was written with tf.io.TFRecordWriter, as shown below:

import tensorflow as tf
import numpy as np

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

tfrecord_filename = "test.tfrecord"
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(4):
        a = np.random.uniform(-1, 1, 5)
        a = tf.convert_to_tensor(a, dtype=tf.float32)
        a = tf.io.serialize_tensor(a)
        feature = {
          'a'   : _bytes_feature(a),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example_proto.SerializeToString())  

I then used the schemas defined in https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/feature.proto and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto, compiled with protoc, to decode the file. The reading program is:

import test_pb2

parser = test_pb2.Example()
with open("test.tfrecord", "rb") as f:
    parser.ParseFromString(f.read())
    print(parser.feature)

Using the ParseFromString method, I expected to recover the data written by the program above, but I always get:

RuntimeWarning: Unexpected end-group tag: Not all data was converted

What am I doing wrong?


Tags: io, test, import, parser, bytes, value, example, tf
1 Answer
User
#1 · Posted 2024-09-27 09:34:38

You can read back the data with tf.data.TFRecordDataset.

Code:

import tensorflow as tf
import numpy as np

print('TensorFlow:',tf.__version__)

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

tfrecord_filename = "test.tfrecord"
with tf.io.TFRecordWriter(tfrecord_filename) as writer:
    for i in range(4):
        a = np.random.uniform(-1, 1, 5)
        a = tf.convert_to_tensor(a, dtype=tf.float32)
        a = tf.io.serialize_tensor(a)
        feature = {
          'a'   : _bytes_feature(a),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example_proto.SerializeToString()) 

feature_description = {
    'a': tf.io.FixedLenFeature([], tf.string),
}

def parse_example(example_proto):
    parsed_example = tf.io.parse_single_example(
        example_proto, feature_description)
    a = tf.io.parse_tensor(parsed_example['a'], tf.float32)
    # Perform any further operations on `a` here.
    return a

dataset = tf.data.TFRecordDataset('test.tfrecord')
dataset = dataset.map(parse_example)

for sample in dataset:
    print(sample)

Output:

TensorFlow: 2.0.0
tf.Tensor([ 0.95179933  0.39751884 -0.3064195   0.8001448   0.3104681 ], shape=(5,), dtype=float32)
tf.Tensor([ 0.4199625  -0.03338468  0.61874187 -0.31352085  0.63478047], shape=(5,), dtype=float32)
tf.Tensor([-0.23349774  0.37200847 -0.00269533 -0.56773156  0.8720373 ], shape=(5,), dtype=float32)
tf.Tensor([-0.142148   -0.6130724   0.5867819  -0.01797233  0.36230987], shape=(5,), dtype=float32)
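
The warning in the question most likely arises because a TFRecord file is not a single serialized Example message. According to TensorFlow's TFRecord documentation, each record is stored as a little-endian uint64 length, a 4-byte masked CRC of that length, the payload bytes, and a 4-byte masked CRC of the payload. The following is a minimal sketch of stripping that framing by hand, not TensorFlow's own reader. The helper name iter_tfrecord_payloads is hypothetical, and the CRC fields are skipped rather than verified.

```python
import struct


def iter_tfrecord_payloads(f):
    """Yield the raw payload bytes of each record from a binary file object.

    Assumed per-record layout (CRCs are skipped here, not verified):
      uint64 length | uint32 CRC of length | payload | uint32 CRC of payload
    """
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)  # little-endian uint64
        f.read(4)  # skip the masked CRC of the length field
        payload = f.read(length)
        if len(payload) < length:
            raise ValueError("truncated record payload")
        f.read(4)  # skip the masked CRC of the payload
        yield payload


# Hypothetical usage with the asker's protoc-generated module (test_pb2):
#
# with open("test.tfrecord", "rb") as f:
#     for payload in iter_tfrecord_payloads(f):
#         example = test_pb2.Example()
#         example.ParseFromString(payload)
#         print(example.features)
```

Each yielded payload should be one serialized Example, so parsing it record by record (with the generated class, or with tf.train.Example.FromString) ought to avoid the warning. Note that example.proto names the field features, so the decoded data would sit under example.features rather than example.feature. The bytes stored under 'a' are themselves a serialized tensor and would still need tf.io.parse_tensor, as in the answer above.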
