keras.estimator.model_到_estimator无法warmstart或加载上一个checkpoint

2024-09-30 18:13:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我用keras模型估计函数训练张量流模型,然后用训练数据进行训练。这很好,我可以继续使用测试数据成功地进行预测。在

在一个单独的例行程序中,我希望能够用最新的训练检查点加载预先训练的估计器,并进行更多的预测(即,无需重新训练)。我已经看过warm_start_from,但在加载keras模型时似乎不可用。我对https://www.tensorflow.org/get_started/checkpoints的理解是,我可以从相同的keras模型创建一个新的估计器,并且第一次预测它将从我指定的目录加载检查点。在

下面的代码片段是我的尝试(最终estimator_model2将加载到一个单独的例程中,这只是为了演示)。在

modelConfig = tf.estimator.RunConfig('/myCheckpointpath', keep_checkpoint_max=1)

estimator_model = keras.estimator.model_to_estimator(keras_model=myKerasModel(inputShape, nOutputs), config=modelConfig)            
estimator_model.train(input_fn=lambda: input_fn(_trainData_2d, _trainLabels, batch_size=self.batchSize, shuffle=True, num_epochs=2))

estimator_model2 = keras.estimator.model_to_estimator(keras_model=myKerasModel(inputShape, nOutputs), config=modelConfig)                       
predictions = list(estimator_model2.predict(input_fn=lambda: input_fn(_testData_2d)))

从诊断中我可以看到它在执行最后一行时尝试加载检查点。然而,我得到一个错误,表明在训练期间保存的检查点并不包含新估计器所需的所有信息。这是错误:

^{pr2}$

如果有帮助的话,我可以展示keras模型,但我不认为这是问题所在。在

有谁能给我一个解决方案或建议一个更好的方法来加载一个以前训练过的值来做预测的估计器吗?在


Tags: tolambda模型configinputmodel检查点keras
1条回答
网友
1楼 · 发布于 2024-09-30 18:13:36

我对上述问题的解决方案是使用一种混合方法,即使用keras符号指定模型,然后将其放入tensorflow模型函数中,然后将其加载到和估计器中。通过采用这种方法,我可以像使用任何其他tensorflow模型一样保存和重新加载检查点。我认为这是使用直观的keras表示法的最佳组合,同时能够利用tensorflow估计器和数据工具。以下是我描述各种tensorflow调用设置方法的概述:

  1. 创建估计员:

    | estimator: tf.estimator.Estimator
    | config: tf.estimator.RunConfig                         #checkpointPath and saving spec for training
            | model_fn: tf.estimator.EstimatorSpec
                | myKerasModel                               #specify model. Doesn't have to be keras.
                        | keras.models.Model
                | loss: myLossFunction                       #train_and_eval only
                | optimizer: myOptimizerFunction             #train_and_eval only
                | training_hooks:tf.train.SummarySaverHook   #train_and_eval only - for saved diagnostics
                | evaluation_hooks:tf.train.SummarySaverHook #train_and_eval only - for saved diagnostics
                | predictions: model(data, training=False)   #predict only
    
  2. 培训和评估:

    | tf.estimator.train_and_evaluate
            | estimator: tf.estimator.Estimator                  #the estimator we created
            | train_spec: tf.estimator.TrainSpec
                                | input_fn: tf.data.Dataset      #specify the input function 
                                    | features, labels           #data to use
                                    | batch, shuffle, num_epochs #controls for data
            | eval_spec: tf.estimator.EvalSpec
                                | input_fn: tf.data.Dataset      #specify the input function
                                    | features, labels           #data to use
                                    | batch                      #controls for data
                                | throttle and start_delay       #specify when to start evaluation
    

注意,我们可以使用tf.estimator.Estimator.train和{},但是 不允许在培训期间进行评估,因此我们使用tf.estimator.train_and_evaluate。在

  1. 预测:

    | estimator: tf.estimator.Estimator.predict       #the estimator we created
            | input_fn: tf.data.Dataset               #specify the input function
                | features                            #data to use
    

相关问题 更多 >