基于张量流估计的转移学习/再训练

2024-09-30 04:29:00 发布

您现在位置:Python中文网/ 问答频道 /正文

我一直不知道如何在新的TFEstimator API中使用转移学习/最后一层再培训。在

Estimator需要一个model_fn,其中包含网络架构、培训和评估操作,如{a2}中所定义。使用CNN架构的model_fn的一个例子是here。在

如果我想重新培训inception架构的最后一层,例如,我不确定是否需要在此model_fn中指定整个模型,然后加载预先训练的权重,或者是否有一种方法像“传统”方法那样使用保存的图(例如here)。在

这是一个issue,但仍然是开放的,我不清楚答案。在


Tags: 方法模型网络apia2modelhere定义
1条回答
网友
1楼 · 发布于 2024-09-30 04:29:00

可以在模型定义期间加载元图,并使用SessionRunHook从ckpt文件加载权重。在

def model(features, labels, mode, params):
    # Create the graph here

    return tf.estimator.EstimatorSpec(mode, 
            predictions,
            loss,
            train_op,
            training_hooks=[RestoreHook()])

SessionRunHook可以是:

^{pr2}$

这样,在第一步加载权重,并在模型检查点进行训练时保存。在

相关问题 更多 >

    热门问题