`MonitoredTrainingSession()`如何与“restore”和“testing mode”一起工作?

2024-06-16 11:01:34 发布

您现在位置:Python中文网/ 问答频道 /正文

在Tensorflow中,我们可以使用Between-graph Replication为分布式培训构建和创建多个Tensorflow会话。MonitoredTrainingSession()协调多个Tensorflow会话,并且有一个参数checkpoint_dir来恢复Tensorflow会话/图。现在我有以下问题:

  1. 我们通常使用tf.train.Saver()对象来恢复张量流图。但是我们如何使用MonitoredTrainingSession()恢复它们呢?在
  2. 由于我们运行多个进程,每个进程都会构建并创建一个用于培训的Tensorflow会话,我想知道我们是否也必须在培训后运行多个进程来进行测试(或预测)。换句话说,MonitoredTrainingSession()如何与测试(或预测)模式一起工作?在

我读了Tensorflow博士,但没有找到这两个问题的答案。如果有人能找到解决办法,我真的很感激。谢谢!在


Tags: 对象参数进程tftensorflowdir分布式train
3条回答

您应该导入元图,然后恢复模型。 从下面的片段中获得灵感,应该对你有用。在

    self.sess = tf.Session()
    ckpt = tf.train.latest_checkpoint("location-of/model")
    saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
    saver.restore(self.sess, ckpt)
  1. 看来恢复是为你处理的。在API文档中,它指出调用MonitoredTrainingSession将返回MonitoredSession的实例,该实例在创建时“…如果存在检查点,则还原变量…”

  2. 查看tf.contrib.learn.Estimator(..).predict(..),更具体地说是tf.contrib.learn.Estimator(..)._infer_model(..)方法here和{a2}。他们还创建了一个MonitoredSession。

简短回答:

  1. 您需要将全局步骤传递给传递给mon的优化器_sess.运行. 这使得保存和检索保存的检查点成为可能。在
  2. 可以通过单个MonitoredTrainingSession同时运行培训+交叉验证会话。首先,您需要通过同一个图的不同流传递训练批和交叉验证批(我建议您查找this guide以了解如何执行此操作的信息)。第二,你必须-去孟_sess.运行()—传递训练流的优化器,以及交叉验证流丢失的参数(/parameter you want track)。如果要与培训分开运行测试会话,只需在图形中运行测试集,并在图形中只运行test_loss(/other parameters you want track)。有关如何完成此操作的更多详细信息,请查看下面的内容。在

长话短说:

我会更新我的答案,因为我自己得到了更好的看法,可以做什么列车监控会话(列车监控训练课程只是创建一个列车监控会话,如source code)所示。在

下面是一个示例代码,演示如何每隔5秒将检查点保存到“./ckpt_dir”。中断时,它将在最后保存的检查点重新启动:

def train(inputs, labels_onehot, global_step):
    out = tf.contrib.layers.fully_connected(
                            inputs,
                            num_outputs=10,
                            activation_fn=tf.nn.sigmoid)
    loss = tf.reduce_mean(
             tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(
                            logits=out,
                            labels=labels_onehot), axis=1))
    train_op = opt.minimize(loss, global_step=global_step)
    return train_op

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    inputs = ...
    labels_onehot = ...
    train_op = train(inputs, labels_onehot, global_step)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir='./ckpt_dir',
        save_checkpoint_secs=5,
        hooks=[ ... ] # Choose your hooks
    ) as mon_sess:
        while not mon_sess.should_stop():
            mon_sess.run(train_op)

为了实现这一目标,在监控培训课程中所发生的事情实际上是三件事:

  1. 在列车监控训练课程创建一个tf.train.Scaffold对象,其工作方式类似于web中的spider;它收集了训练、保存和加载模型所需的部分。在
  2. 它创建一个tf.train.ChiefSessionCreator对象。我对这一点的了解是有限的,但据我的理解,它是用于当您的tf算法在多个服务器上传播时使用的。我的看法是,它告诉运行文件的计算机它是主计算机,检查点目录应该保存在这里,日志记录者应该在这里记录他们的数据,等等
  3. 它创建一个tf.train.CheckpointSaverHook,用于保存检查点。在

为了让它发挥作用tf.train.CheckpointSaverHook以及tf.train.ChiefSessionCreator必须传递对检查点目录和scaffold的相同引用。如果列车监控训练课程由于上述示例中的参数将由上述3个组件实现,因此它看起来如下所示:

^{pr2}$

为了进行训练+交叉验证会话,可以使用列车监控会话。与partial一起运行\u step_fn(),后者运行会话调用而不调用任何钩子。这看起来是你训练你的模型进行n迭代,然后你运行你的测试集,重新初始化你的迭代器,回到训练你的模型,等等。当然,你必须设置你的变量来重用=自动再利用这样做的时候。代码中的方法如下所示:

from functools import partial

# Build model
...

with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
    ...

...

def step_fn(fetches, feed_dict, step_context):
    return step_context.session.run(fetches=fetches, feed_dict=feed_dict)

with tf.train.MonitoredTrainingSession(
                checkpoint_dir=...,
                save_checkpoint_steps=...,
                hooks=[...],
                ...
                ) as mon_sess:

                # Initialize iterators (assuming tf.Databases are used)
                mon_sess.run_step_fn(
                           partial(
                               step_fn, 
                               [train_it.initializer, 
                                test_it.initializer, 
                                ...
                               ], 
                               {}
                            )
                )

                while not mon_sess.should_stop():
                    # Train session
                    for i in range(n):
                        try:
                            train_results = mon_sess.run(<train_fetches>)
                        except Exception as e:
                            break

                    # Test session
                    while True:
                        try:
                            test_results = mon_sess.run(<test_fetches>)
                        except Exception as e:
                            break

                    # Reinitialize parameters
                    mon_sess.run_step_fn(
                               partial(
                                  step_fn, 
                                  [train_it.initializer, 
                                   test_it.initializer, 
                                   ...
                                  ], 
                                  {}
                               )
                    )

部分函数只在mon中使用的步骤_fn上执行currying(函数编程中的经典函数)_sess.run_步骤\fn(). 上面的代码还没有经过测试,在开始测试之前,您可能需要重新初始化train_it,但希望现在可以清楚地知道如何在同一次运行中同时运行训练集和验证集。如果你想在同一个图中同时绘制训练曲线和测试曲线,这还可以与tensorboard的custom_scalar tool一起使用。在

最后,这是这个功能的最好的实现,我个人希望tensorflow将来能使这个功能的实现更容易,因为它非常乏味,而且可能没有那么高效。我知道有一些工具,比如Estimator可以运行train_and_evaluate函数,但是由于这重建了每个列车和交叉验证运行之间的关系图,它非常适合如果只在一台计算机上运行,则效率很高。我在某个地方读到Keras+tf有这个功能,但是由于我不使用Keras+tf,这不是一个选项。不管怎样,我希望这能帮助其他人在同样的事情上挣扎!在

相关问题 更多 >