Django会话中的Keras序列模型?

2024-10-02 08:25:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我不知道如何在作为服务器运行的Django应用程序中“保存”Keras模型的图形/会话。你知道吗

我已经构建了Keras序列模型DeepModel,它作为一个独立的Python模块运行良好。但现在我想将其嵌入Django应用程序中,在该应用程序中我为模型定义了以下处理程序:

# instantiated by Django app
class DeepModelManager:

def __init__(self, params):

    self.graph = tf.Graph()
    self.sess = tf.Session()
    K.set_session(self.sess)
    with self.sess.as_default():
        self.instance = DeepModel(params)
        self.model = self.instance.build()

        optimizer = Adam()
        loss = "categorical_crossentropy"
        metrics = ["accuracy"]
        self.model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

def train(self):
    X = ...
    y = ...

    K.set_session(self.sess)
    with self.sess.as_default():
        H = self.model.fit(X, y, epochs=20)

该过程的工作方式如下:

  • 首先Django应用程序等待来自外部客户机的init请求,然后它实例化DeepModelManager(作为views模块变量),该模块实例化/构建/编译Keras DeepModel
  • 然后Django应用程序等待列车请求,该请求应触发上述train功能,即模型拟合

但是每当运行train方法时,我总是会遇到错误

ValueError: Tensor("training/Adam/Const:0", shape=(), dtype=float32) must be from the same graph as Tensor("sub:0", shape=(), dtype=float32).

我怀疑这是因为TensorFlow会话(或图形)在模型初始化和训练之间被清除了。这就是为什么我试着玩tf.Session()tf.Graph()(在独立的DeepModel版本中,我不碰它),但是没有用。你知道吗


Tags: 模块django模型self应用程序modeltfas

热门问题