如何将固定长度功能写入tfrecord

2024-05-15 19:54:04 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在努力学习编写tensorflow tfrecord文件的基础知识。我正在用python编写一个带有ndarray的简单示例,但是由于某种原因,当我读取它时,它的长度必须是可变的,并且作为SparseTensor读取

举个例子

def serialize_tf_record(features, targets):
    record = {
        'shape': tf.train.Int64List(value=features.shape),
        'features': tf.train.FloatList(value=features.flatten()),
        'targets': tf.train.Int64List(value=targets),
    }

    return build_tf_example(record)

def deserialize_tf_record(record):
    tfrecord_format = {
        'shape': tf.io.VarLenFeature(tf.int64),
        'features': tf.io.VarLenFeature(tf.float32),
        'targets': tf.io.VarLenFeature(tf.int64),
    }

    features_tensor = tf.io.parse_single_example(record, tfrecord_format)
    return features_tensor

有人能解释为什么这会写可变长度的记录吗?它在代码中是固定的,但我似乎无法用tensorflow知道其固定的方式来编写它。tensorflow文档在这里非常可怕。有人能帮我澄清一下API吗


Tags: ioreturnvalueexampletftensorflowdeftfrecord
1条回答
网友
1楼 · 发布于 2024-05-15 19:54:04

您应该提供更多的上下文代码,比如build_tf_example函数以及特性和目标的示例

下面是一个返回密集张量的示例:


import numpy as np
import tensorflow as tf

def build_tf_example(record):
    return tf.train.Example(features=tf.train.Features(feature=record)).SerializeToString()

def serialize_tf_record(features, targets):
    record = {
        'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=features.shape)),
        'features': tf.train.Feature(float_list=tf.train.FloatList(value=features.flatten())),
        'targets': tf.train.Feature(int64_list=tf.train.Int64List(value=targets)),
    }

    return build_tf_example(record)

def deserialize_tf_record(record):
    tfrecord_format = {
        'shape': tf.io.FixedLenSequenceFeature((), dtype=tf.int64, allow_missing=True),
        'features': tf.io.FixedLenSequenceFeature((), dtype=tf.float32, allow_missing=True),
        'targets': tf.io.FixedLenSequenceFeature((), dtype=tf.int64, allow_missing=True),
    }

    features_tensor = tf.io.parse_single_example(record, tfrecord_format)
    return features_tensor

def main():
    features = np.zeros((3, 5, 7))
    targets = np.ones((4,), dtype=int)
    tf.print(deserialize_tf_record(serialize_tf_record(features, targets)))


if __name__ == '__main__':
    main()
  • 我将record转换为一个功能字典(以便轻松地序列化)
  • 据我所知,您的每个特征都可以是一个数组(与标量值相对),因此您可以使用FixedLenSequenceFeature输入特征来解析它,从而构建一个密集张量而不是稀疏张量

相关问题 更多 >