在每N个阶段结束时保存模型权重

2024-05-17 03:42:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在训练一个神经网络,并希望在预测阶段每N个阶段保存一个模型权重。我建议这段代码草案,它的灵感来自@grovina的响应here。请你提些建议好吗? 提前谢谢。

from keras.callbacks import Callback

class WeightsSaver(Callback):
    def __init__(self, model, N):
        self.model = model
        self.N = N
        self.epoch = 0

    def on_batch_end(self, epoch, logs={}):
        if self.epoch % self.N == 0:
            name = 'weights%08d.h5' % self.epoch
            self.model.save_weights(name)
        self.epoch += 1

然后将其添加到fit调用中:要每5个阶段节省重量:

model.fit(X_train, Y_train, callbacks=[WeightsSaver(model, 5)])

Tags: name模型selfmodeldefcallbacktrain神经网络
2条回答

你应该在epoch端实现,而不是在batch端实现。同时,将模型作为__init__的参数传递也是多余的。

from keras.callbacks import Callback
class WeightsSaver(Callback):
  def __init__(self, N):
    self.N = N
    self.epoch = 0

  def on_epoch_end(self, epoch, logs={}):
    if self.epoch % self.N == 0:
      name = 'weights%08d.h5' % self.epoch
      self.model.save_weights(name)
    self.epoch += 1

您不需要传递回调的模型。它已经可以通过它的超级访问模型。所以删除__init__(..., model, ...)参数和self.model = model。无论如何,您应该能够通过self.model访问当前模型。你也在每个批处理结束时保存它,这不是你想要的,你可能希望它是on_epoch_end

但无论如何,你所做的都可以通过天真的modelcheckpoint callback来完成。你不需要写一个自定义的。你可以这样使用它

mc = keras.callbacks.ModelCheckpoint('weights{epoch:08d}.h5', 
                                     save_weights_only=True, period=5)
model.fit(X_train, Y_train, callbacks=[mc])

相关问题 更多 >