为什么for循环在tensorflow中如此缓慢

2024-10-06 06:52:24 发布

您现在位置:Python中文网/ 问答频道 /正文

所以我知道这与tensorflow构建图形时,它做得不好有关。。。“有效地”。下面是我正在运行的伪代码:

@tf.function
def parTest(x_in):
    res = 0
    for i in range(5000):
        res += x_in + i
    return res

不使用tensorflow运行该功能需要0.002秒,但是使用tensorflow运行该功能需要10到20秒。这对我来说毫无意义,这是怎么回事?还有,我如何修复它?这里res的实际值显然可以用一种更有效的方法来计算,但我面临的真正问题是,我有一个for循环,其中每个迭代都有许多可以独立运行的迭代,但是tensorflow拒绝这样做,并且一个接一个地运行它们,就像这个虚拟示例一样。那么我如何告诉tensorflow不要这样做呢


Tags: 方法代码in功能图形示例forreturn
1条回答
网友
1楼 · 发布于 2024-10-06 06:52:24

在TensorFlow中,循环从来都不是非常有效的。然而,这个函数对TensorFlow尤其不利,因为它将尝试静态地“展开”整个循环。也就是说,它不会将您的函数“翻译”为^{},而是在每次迭代中创建5000个操作副本。这是一个非常大的图表,在上面总是按顺序运行。事实上,我在TensorFlow 2.2.0中得到了一个警告,它将您指向这个信息页面:"WARNING: Large unrolled loop detected"

正如该链接中提到的,问题在于TensorFlow(至少目前)无法检测任意迭代器上的循环,即使它们是一个简单的range,因此它只能在Python中运行循环并创建相应的操作。您可以通过自己编写^{}来避免这种情况,或者,多亏了AutoGraph,只需将range替换为^{}

import tensorflow as tf
@tf.function
def parTest(x_in):
    res = 0
    for i in tf.range(5000):
        res += x_in + i
    return res

尽管如此,编写自己的^{}(只要绝对必要,因为矢量化操作总是会更快)可以让您更明确地控制parallel_iterations参数之类的细节

相关问题 更多 >