在Tensorflow中批量迭代是如何工作的?

2024-09-27 21:23:16 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在我的数据上重用PTB language model,但缺乏Tensorflow的知识,无法理解它如何处理训练数据上的批处理迭代。以下是我在培训期间如何理解批处理迭代:

while epoch <= maxepoch do
  for minibatch in data_iterator() do
    model.forward(minibatch)
    (...)
  end
end

再简单不过了,不是吗? 类似的事情在很多其他框架中都有,但在Tensorflow中却没有:) 以下是来自官方PTB语言模型教程的minibatch函数示例:

^{pr2}$

此函数在调用后返回x输入和y目标。我在这里没有看到Python迭代器的迹象,但是有一个对tf.strided_slice的调用,它使用tf.train.range_input_producer生成的i索引,所以这应该模拟数据上的滑动窗口。然而,在训练之前,函数只被调用一次,那么它如何迭代我的数据呢?这还不清楚。有人能解释这个“神奇的”完全模糊的张量流机制吗?在


Tags: 数据函数informodeltftensorflowdo
1条回答
网友
1楼 · 发布于 2024-09-27 21:23:16

“魔力”隐藏在调用^{}的行中:

i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()

。。。它创建一个操作,该操作从队列中弹出值,并保存0..epoch_size-1个整数。换句话说,它在0..epoch_size-1范围内迭代。在


是的,这似乎违反直觉。下面是一个在tensorflow中处理队列的简单可运行示例:

^{pr2}$

在运行时,您应该看到从09的值,然后再从0到{}的5个值。注意,sess.run计算相同的张量index,但每次都得到一个不同的值。可以添加更多依赖于index的操作,它们将使用新值index进行计算。在

还请注意,队列在另一个线程中运行,因此为了使用tf.train.range_input_producer,必须启动一个Coordinator并生成多个线程(最后停止它们)。如果您试图在不使用aCoordinator的情况下运行同一个示例,sess.run(index)将阻止脚本的执行。在

你可以用这个例子来演示,例如set shuffle=True,等等


回到PTB制作人片段:

i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
x = tf.strided_slice(data, [0, i*num_steps], [batch_size, (i+1)*num_steps])
x.set_shape([batch_size, num_steps])
y = tf.strided_slice(data, [0, i*num_steps+1], [batch_size, (i+1)*num_steps+1])
y.set_shape([batch_size, num_steps])

现在应该很清楚,即使xy被定义为简单张量,它们实际上是{}片上的迭代器。所有的线程工作都由^{}负责。因此,调用优化操作(取决于xy)将自动执行新的批处理。在


建议阅读:

相关问题 更多 >

    热门问题