tensorflow:保存模型和加载模型

2024-10-01 11:29:19 发布

您现在位置:Python中文网/ 问答频道 /正文

目前正在尝试使此repo生效

我试图将经过训练的模型保存在本地机器中,以便以后应用。我在tensorflow的doc中读到,通过调用tf.save_model.save(object)保存模型似乎非常直观。但我不知道如何申请

原始代码如下:model.py 以下是我的变化:

import tensorflow as tf

class ICON(tf.Module): # make it a tensorflow modul

    def __init__(self, config, embeddingMatrix, session=None):

    def _build_inputs(self):

    def _build_vars(self):

    def _convolution(self, input_to_conv):

    def _inference(self):

    def batch_fit(self, queries, ownHistory, otherHistory, labels):

        feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory,
                     self._labels: labels}
        loss, _ = self._sess.run([self.loss_op, self.train_op], feed_dict=feed_dict)
        return loss

    def predict(self, queries, ownHistory, otherHistory, ):

        feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory}
        return self._sess.run(self.predict_op, feed_dict=feed_dict)

    def save(self): # attempt to save the model
        tf.saved_model.save(
            self, './output/model')

上面的代码生成ValueError,如下所示: ValueError: Tensor("ICON/CNN/embedding_matrix:0", shape=(16832, 300), dtype=float32_ref) must be from the same graph as Tensor("saver_filename:0", shape=(), dtype=string).


Tags: selfinputlabelsmodelsavetftensorflowdef
1条回答
网友
1楼 · 发布于 2024-10-01 11:29:19

我相信你可以用tf.train.Saver来做这个

def save(self): # attempt to save the model
    saver = tf.train.Saver()
    saver.save(self._sess, './output/model')

然后可以通过这种方式恢复模型

saver = tf.train.import_meta_graph('./output/model.meta')
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./output'))

您可能还发现这个tutorial有助于更好地理解这一点

编辑:如果要使用SavedModel

def save(self):
    inputs = {'input_queries': self._input_queries, 'own_histories': self._own_histories, 'other_histories': self._other_histories}
    outputs = {'output': self.predict_op}
    tf.saved_model.simple_save(self._sess, './output/model', inputs, outputs)

然后,您可以使用tf.contrib.predictor.from_saved_模型使用SavedModel加载和服务

from tensorflow.contrib.predictor import from_saved_model
predictor = from_saved_model('./output/model')
predictions = predictor({'input_queries': input_queries, 'own_histories': own_histories, 'other_histories': other_histories})

相关问题 更多 >