Tensorflow:如何保存“DNNRegrestorv2”模型?python

2024-10-01 17:25:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我在尝试保存经过培训的模型时遇到问题,我尝试了:

model.save('~/Desktop/models/')

但它给了我一个错误AttributeError: 'DNNRegressorV2' object has no attribute 'save'

我也尝试过:

tf.saved_model.save(model, mobilenet_save_path)

但它给了我一个错误ValueError: Expected a Trackable object for export, got <tensorflow_estimator.python.estimator.canned.dnn.DNNRegressorV2 object at 0x111cc4b70>.

有什么想法吗

>type(model)
<class 'tensorflow_estimator.python.estimator.canned.dnn.DNNRegressorV2'>

Tags: no模型modelobjectmodelssavetensorflow错误
1条回答
网友
1楼 · 发布于 2024-10-01 17:25:59

要保存估计器,您需要创建一个服务输入接收器。此函数构建tf.Graph的一部分,用于解析SavedModel接收的原始数据

tf.estimator.export模块包含帮助构建这些接收器的函数

下面的代码基于feature_列构建一个接收器,它接受序列化的tf.Example协议缓冲区,这通常与tf服务一起使用

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)

您还可以从python加载并运行该模型:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))

print(predict(1.5))
print(predict(3.5))

点击here

相关问题 更多 >

    热门问题