TensorFlow:使用监控培训课程时模型的验证

2024-09-30 14:16:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用dataset API导入培训和验证数据。我有TF1.2。所以我只能使用可重新初始化的迭代器,而不能使用feedable iterator,因为feedable iterator只能从tf1.4获得。在

1)如果我们只想培训网络,我们可以简单地使用监控培训课程。但是,当我们想在培训时进行验证时,我们应该怎么做呢?我们是否应该放弃监控的培训课程,而使用低级别的培训课程?在

train_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "train.tfrecords"])
train_dataset = train_dataset.map(_parse_records)
train_dataset.shuffle(buffer_size=1000)
train_dataset = train_dataset.repeat()
train_dataset = train_dataset.batch(FLAGS.batch_size)

validation_dataset = tf.contrib.data.TFRecordDataset([FLAGS.data_dir + "test.tfrecords"])
validation_dataset = test_dataset.map(_parse_records)
validation_dataset = test_dataset.batch(FLAGS.batch_size)

iterator = tf.contrib.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)

train_init_op = iterator.make_initializer(train_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)

training_op = tf.train.AdagradOptimizer(...).minimize(loss)

with tf.train.MonitoredTrainingSession(...) as sess:
    sess.run(train_init_op)
    while not sess.should_stop():
        sess.run(training_op)

   # HOW TO VALIDATE?

2)是否有任何方法可以在epoch中间使用可重新初始化的迭代器验证模型,因为在迭代器之间切换时,需要从数据集的开始初始化迭代器。用可重新初始化的迭代器是可能的还是我们必须切换到可馈迭代器来实现这一点?在

这是TF数据集教程中提供的示例。如果在一个纪元中有100次迭代,我们可以在迭代50时使用可重新初始化的迭代器来验证模型吗?(我认为可以使用feedable iterator)

^{pr2}$

3)当使用可重新初始化的迭代器时,在epoch的最后一次迭代中,如果剩余的训练数据样本小于所需的批大小,会发生什么情况? 剩下的几个样本是在减少批量的情况下使用还是被忽略?在


Tags: 数据datatfbatchtraincontribdataset课程
2条回答

对于你的第3个问题,我认为张量流的表现很差。对于最后一批,它可能有较少的样本。这会经常发生(总是?)在训练中导致“不兼容形状”错误。关于从TensorFlow 1.4开始如何解决这个问题,请参见https://stackoverflow.com/a/48331954/2184122

请看一下How to switch between training and validation dataset with tf.MonitoredTrainingSession? 我想你会找到1)和2)的答案。 您可以使用feed-dict来更改要评估的数据集,或者只是重新初始化它。从链接:

...
training_iterator = training_ds.make_initializable_iterator()
validation_iterator = validation_ds.make_initializable_iterator()
...
sess.run(next_element, feed_dict={handle: training_handle})
...
sess.run(next_element, feed_dict={handle: validation_iterator })

相关问题 更多 >

    热门问题