TensorFlow 2: re-saving a SavedModel?

Posted 2024-07-02 04:46:26


I am trying to load a model saved in the SavedModel format, apply some computation on top of it, and re-save the whole pipeline. A minimal example follows (adapted from this Kaggle kernel):

import tensorflow as tf

# Load the model. After that, `delg_model` will be of type
# `tensorflow.python.training.tracking.tracking.AutoTrackable`.
delg_model = tf.saved_model.load('path/to/saved/model/dir')

# We don't need the whole model, so we prune it. After that,
# `global_feature_extraction_fn` will be of type
# `tensorflow.python.eager.wrap_function.WrappedFunction`.
delg_input_tensor_names = ['input_image:0', 'input_scales:0']
global_feature_extraction_fn = delg_model.prune(
    delg_input_tensor_names, ['global_descriptors:0'])
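
For reference, this is roughly how I call the pruned function directly; the image shape and dtype below are just placeholders for illustration, and the scales are the ones I use later:

# Hypothetical test input: a single RGB image as float32 (shape/dtype assumed).
test_image = tf.random.uniform([512, 512, 3], dtype=tf.float32)
input_scales = tf.convert_to_tensor([0.7, 1.0, 1.4])

# The pruned WrappedFunction takes the feeds positionally, in the order of
# `delg_input_tensor_names`, and returns one tensor per requested fetch.
global_descriptors = global_feature_extraction_fn(test_image, input_scales)[0]
print(global_descriptors.shape)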

Question: Now, I want to save the global_feature_extraction_fn, with some other TF ops to post-process the output, in the same SavedModel format. What is the correct way to do that?
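
To make the goal concrete, this is roughly how I would like to consume the re-saved model; the directory and signature key are taken from my save attempt below, and the shapes are just an illustration:

# Hypothetical usage of the re-saved pipeline (input shape/dtype assumed).
reloaded = tf.saved_model.load('./delg_resaved')
serving_fn = reloaded.signatures['serving_default']
result = serving_fn(input_image=tf.random.uniform([512, 512, 3], dtype=tf.float32))
print(result['global_descriptor'].shape)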


What I've tried

I tried to follow the TensorFlow documentation on saving custom models to SavedModel format and defined a tf.Module:

class DelgModule(tf.Module):
    def __init__(self):
        super().__init__()

    @tf.function(input_signature=[
        tf.TensorSpec(shape=[None, None, 3], name='input_image')
    ])
    def call(self, input_tensor):
        # custom function on top of the model's output
        embedding = tf.nn.l2_normalize(
            global_feature_extraction_fn(
                input_tensor,                         # input_image
                tf.convert_to_tensor([0.7, 1.0, 1.4]) # input_scales
            )[0],
            axis=1, name='l2_normalized_output')
        return {'global_descriptor': embedding}

delg_module = DelgModule()
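
For completeness, this is roughly how I exercise the module on a test image before saving (the input shape and dtype are placeholders):

# Calling the tf.function traces the graph; the output should be a dict
# with the 'global_descriptor' key (test input shape/dtype assumed).
test_image = tf.random.uniform([512, 512, 3], dtype=tf.float32)
outputs = delg_module.call(test_image)
print(outputs['global_descriptor'].shape)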

I then ran it on a test image (as sketched above) to trace the tf.function and verified that it works correctly (produces the right output). But when I try to save it like this:

tf.saved_model.save(
    delg_module, export_dir='./delg_resaved',
    signatures={
        'serving_default': delg_module.call
    })

the resulting model is incorrect. The original one is about 90 MB, while the one in delg_resaved is only about 800 KB. I also tried performing the tf.saved_model.load call inside DelgModule.call, so that graph construction and variable loading happen entirely inside the tf.function, but the result stays the same.
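
One thing I can check here (a minimal diagnostic sketch, assuming the size gap comes from the loaded model's variables not being tracked): DelgModule never holds a reference to delg_model, so the module's own variable list can be inspected like this:

# If this prints 0, none of the underlying model's variables are tracked by
# the module, which would be consistent with the tiny re-saved file.
print(len(delg_module.variables))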

