目前正在尝试使此repo生效
我试图将经过训练的模型保存在本地机器中,以便以后应用。我在tensorflow的doc中读到,通过调用tf.save_model.save(object)
保存模型似乎非常直观。但我不知道如何申请
原始代码如下:model.py 以下是我的变化:
import tensorflow as tf
class ICON(tf.Module): # make it a tensorflow modul
def __init__(self, config, embeddingMatrix, session=None):
def _build_inputs(self):
def _build_vars(self):
def _convolution(self, input_to_conv):
def _inference(self):
def batch_fit(self, queries, ownHistory, otherHistory, labels):
feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory,
self._labels: labels}
loss, _ = self._sess.run([self.loss_op, self.train_op], feed_dict=feed_dict)
return loss
def predict(self, queries, ownHistory, otherHistory, ):
feed_dict = {self._input_queries: queries, self._own_histories: ownHistory, self._other_histories: otherHistory}
return self._sess.run(self.predict_op, feed_dict=feed_dict)
def save(self): # attempt to save the model
tf.saved_model.save(
self, './output/model')
上面的代码生成ValueError,如下所示:
ValueError: Tensor("ICON/CNN/embedding_matrix:0", shape=(16832, 300), dtype=float32_ref) must be from the same graph as Tensor("saver_filename:0", shape=(), dtype=string).
我相信你可以用tf.train.Saver来做这个
然后可以通过这种方式恢复模型
您可能还发现这个tutorial有助于更好地理解这一点
编辑:如果要使用SavedModel
然后,您可以使用tf.contrib.predictor.from_saved_模型使用SavedModel加载和服务
相关问题 更多 >
编程相关推荐