是否可以在会话外部或在创建会话之间使用钩子/回调?

2024-05-18 08:34:19 发布

您现在位置:Python中文网/ 问答频道 /正文

使用train_and_evaluate()可以执行一个时间表,根据我传递的规范对模型进行训练和评估。我可以用EvalSpecTrainSpec注册一些钩子,但也有一些限制。在

问题是我只能有一个SessionRunHook,它可以作为回调函数使用,但总是在会话中。在

我的问题是我的日程安排比较复杂。在评估过程中,我还想量化模型,并进一步评估量化模型。这里的问题是,如前所述,如果我使用SessionRunHook类对象,那么我总是在会话中。在

所以问题是是否有一种方法可以使用train_and_evaluate()并在其中注册一些回调:

train_and_evaluate(..):

  # .. deep down ..

  while <condition>:
    with tf.Session() as train_sess:
      # Do training ..

    if the_callback_i_want:
      the_callback_i_want()

    with tf.Session() as eval_sess:
      # Do evaluation ..

这可能吗?在


Tags: andthe模型sessiontfaswithcallback
1条回答
网友
1楼 · 发布于 2024-05-18 08:34:19

我想您可以实现您自己的SessionHook子类的begin方法。在

为了这个例子,我使用了iris code(参见this doc)。在

import tensorflow as tf

def the_callback_i_want():
    # You need to work in a new graph so let's create a new one
    g = tf.Graph()
    with g.as_default():
        x = tf.get_variable("x", ())
        x = tf.assign_add(x, 1)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:   
            sess.run(init)
            print("I'm here !", sess.run(x))


class MyHook(tf.train.SessionRunHook):

  def begin(self):
    """Called once before using the session.

    When called, the default graph is the one that will be launched in the
    session.  The hook can modify the graph by adding new operations to it.
    After the `begin()` call the graph will be finalized and the other callbacks
    can not modify the graph anymore. Second call of `begin()` on the same
    graph, should not change the graph.
    """
    the_callback_i_want()


import iris_data
# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

# Build 2 hidden layer DNN with 10, 10 units respectively.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns, hidden_units=[10, 10],  n_classes=3)

# Fetch the data
(train_x, train_y), (test_x, test_y) = iris_data.load_data()

# Feature columns describe how to use the input.
my_feature_columns = []
for key in train_x.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))

train_spec = tf.estimator.TrainSpec(input_fn=lambda:iris_data.train_input_fn(train_x, train_y,
                                                 10), max_steps=100)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda:iris_data.eval_input_fn(test_x, test_y,
                                                10), hooks=[MyHook()])
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

它打印出:

^{pr2}$

相关问题 更多 >