我不知道如何在作为服务器运行的Django应用程序中“保存”Keras模型的图形/会话。你知道吗
我已经构建了Keras序列模型DeepModel
,它作为一个独立的Python模块运行良好。但现在我想将其嵌入Django应用程序中,在该应用程序中我为模型定义了以下处理程序:
# instantiated by Django app
class DeepModelManager:
def __init__(self, params):
self.graph = tf.Graph()
self.sess = tf.Session()
K.set_session(self.sess)
with self.sess.as_default():
self.instance = DeepModel(params)
self.model = self.instance.build()
optimizer = Adam()
loss = "categorical_crossentropy"
metrics = ["accuracy"]
self.model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
def train(self):
X = ...
y = ...
K.set_session(self.sess)
with self.sess.as_default():
H = self.model.fit(X, y, epochs=20)
该过程的工作方式如下:
DeepModelManager
(作为views
模块变量),该模块实例化/构建/编译Keras DeepModel
train
功能,即模型拟合但是每当运行train
方法时,我总是会遇到错误
ValueError: Tensor("training/Adam/Const:0", shape=(), dtype=float32) must be from the same graph as Tensor("sub:0", shape=(), dtype=float32).
我怀疑这是因为TensorFlow会话(或图形)在模型初始化和训练之间被清除了。这就是为什么我试着玩tf.Session()
和tf.Graph()
(在独立的DeepModel
版本中,我不碰它),但是没有用。你知道吗
目前没有回答
相关问题 更多 >
编程相关推荐