TensorFlow's map_fn only runs on the CPU

Posted 2024-09-29 01:28:56


I ran into a strange problem when trying to get TensorFlow's map_fn to run on my GPU. Here is a minimal example that reproduces the failure:

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.device("/gpu:0"):
        def test_func(i):
            return i
        test_range = tf.constant(np.arange(5))
        test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)

This results in the error:

InvalidArgumentError: Cannot assign a device for operation 'map/TensorArray_1': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/device:GPU:0'
Colocation Debug Info:
Colocation group had the following types and devices:
TensorArrayScatterV3: CPU
TensorArrayGatherV3: GPU CPU
Range: GPU CPU
TensorArrayWriteV3: CPU
TensorArraySizeV3: GPU CPU
TensorArrayReadV3: CPU
Enter: GPU CPU
TensorArrayV3: CPU
Const: GPU CPU

Colocation members and user-requested devices:
map/TensorArrayStack/range/delta (Const)
map/TensorArrayStack/range/start (Const)
map/TensorArray_1 (TensorArrayV3)
map/while/TensorArrayWrite/TensorArrayWriteV3/Enter (Enter) /device:GPU:0
map/TensorArrayStack/TensorArraySizeV3 (TensorArraySizeV3)
map/TensorArrayStack/range (Range)
map/TensorArrayStack/TensorArrayGatherV3 (TensorArrayGatherV3)
map/TensorArray (TensorArrayV3)
map/while/TensorArrayReadV3/Enter (Enter) /device:GPU:0
Const (Const) /device:GPU:0
map/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3 (TensorArrayScatterV3) /device:GPU:0
map/while/TensorArrayReadV3 (TensorArrayReadV3) /device:GPU:0
map/while/TensorArrayWrite/TensorArrayWriteV3 (TensorArrayWriteV3) /device:GPU:0

[[Node: map/TensorArray_1 = TensorArrayV3clear_after_read=true, dtype=DT_FLOAT, dynamic_size=false, element_shape=, identical_element_shapes=true, tensor_array_name=""]]

When run on my CPU, the code behaves as expected, and a simple operation such as:

[code snippet missing from the source]

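works fine on my GPU. This post seems to describe a similar problem. Does anyone have any suggestions? The answer to that post implies that map_fn should work fine on the GPU. I am running TensorFlow 1.8.0 on Python 3.6.4 on Arch Linux, with CUDA 9.0 and cuDNN 7.0 on a GeForce GTX 1050.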

Thanks!


1 Answer
User
#1 · Posted 2024-09-29 01:28:56

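The error is actually caused by a dtype mismatch: np.arange produces integers by default (int32 on some platforms, typically int64 on 64-bit Linux), but you specified a float32 return type. The error goes away if the input is created as float32: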

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    with tf.device("/gpu:0"):
        def test_func(i):
            return i
        test_range = tf.constant(np.arange(5, dtype=np.float32))
        test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
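To check for this kind of mismatch before building the graph, it can help to inspect the dtype of the NumPy array that is fed to tf.constant. The following is only a minimal sketch using NumPy alone; the helper name `as_float32` is hypothetical and not part of NumPy or TensorFlow:

```python
import numpy as np

def as_float32(values):
    """Hypothetical helper: return `values` as a float32 NumPy array."""
    return np.asarray(values, dtype=np.float32)

# np.arange with integer arguments yields an integer dtype;
# the exact width (int32 or int64) depends on the platform.
ints = np.arange(5)
print(ints.dtype)

# Casting up front makes the input match a float32 return type.
floats = as_float32(ints)
print(floats.dtype)
```

Passing `floats` (or `np.arange(5, dtype=np.float32)`, as above) to tf.constant keeps the input dtype consistent with the dtype=tf.float32 argument given to map_fn.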

I agree that the error message you received is quite confusing. If you remove the device placement, you get the "real" error message:

[error message missing from the source]
