TensorFlow/Keras:如何使用model.checkpoint()恢复训练?

2024-10-02 02:29:28 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用model.checkpoint()保存我的最佳模型:

checkpoint = '/gdrive/MyDrive/mpmodel.ckpt'
cdir = os.path.dirname(checkpoint)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=False,
    save_best_only=True)

history = model.fit([timt,at],[wt,wbt],epochs=100,callbacks=[cp_callback])

这是我加载模型的方式:

latest = tf.train.latest_checkpoint(cdir)
model.load(latest)

现在,我想从上次结束的地方恢复训练。在this ipynb中提到:

Since the optimizer-state is recovered, you can resume training from exactly where you left off.

然而,它并没有确切地告诉我怎么做。请引导我


Tags: path模型youonlymodelsavetfcallback
1条回答
网友
1楼 · 发布于 2024-10-02 02:29:28

在model.load(最新版本)之后,您可以继续使用model.fit()

无论如何,我认为使用检查点回调并不常见。更常见的做法是使用model.save(model\u NAME),然后使用model=tf.keras.models.load\u model(model\u NAME)重新加载模型

MODEL_NAME是保存模型的文件夹

相关问题 更多 >

    热门问题