当我试图让TensorFlow的map_fn
在我的GPU上运行时,遇到了一个奇怪的问题。下面是一个最小的中断示例:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
这将导致错误:
InvalidArgumentError: Cannot assign a device for operation 'map/TensorArray_1': Could not satisfy explicit device specification '' because the node was colocated with a group of nodes that required incompatible device '/device:GPU:0' Colocation Debug Info: Colocation group had the following types and devices: TensorArrayScatterV3: CPU TensorArrayGatherV3: GPU CPU Range: GPU CPU TensorArrayWriteV3: CPU TensorArraySizeV3: GPU CPU TensorArrayReadV3: CPU Enter: GPU CPU TensorArrayV3: CPU Const: GPU CPU
Colocation members and user-requested devices:
map/TensorArrayStack/range/delta (Const)
map/TensorArrayStack/range/start (Const) map/TensorArray_1 (TensorArrayV3) map/while/TensorArrayWrite/TensorArrayWriteV3/Enter (Enter) /device:GPU:0 map/TensorArrayStack/TensorArraySizeV3 (TensorArraySizeV3) map/TensorArrayStack/range (Range)
map/TensorArrayStack/TensorArrayGatherV3 (TensorArrayGatherV3)
map/TensorArray (TensorArrayV3) map/while/TensorArrayReadV3/Enter (Enter) /device:GPU:0 Const (Const) /device:GPU:0
map/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3 (TensorArrayScatterV3) /device:GPU:0 map/while/TensorArrayReadV3 (TensorArrayReadV3) /device:GPU:0
map/while/TensorArrayWrite/TensorArrayWriteV3 (TensorArrayWriteV3) /device:GPU:0[[Node: map/TensorArray_1 = TensorArrayV3clear_after_read=true, dtype=DT_FLOAT, dynamic_size=false, element_shape=, identical_element_shapes=true, tensor_array_name=""]]
在我的CPU上运行时,代码的行为与预期的一样,并且执行一些简单的操作,例如:
^{pr2}$在我的GPU上工作得很好。This post似乎描述了一个类似的问题。有人有什么建议吗?这篇文章的答案意味着map_fn
应该可以在GPU上正常工作。我在ArchLinux上的Python3.6.4上运行TensorFlow的1.8.0版本,在GeForce GTX 1050上运行CUDA版本9.0和cuDNN版本7.0。在
谢谢!在
错误实际上是由于
np.arange
在默认情况下生成int32
s,但您指定了float32
返回类型。错误消失了我同意你收到的错误信息相当混乱。通过删除设备位置,您将收到“真实”错误消息:
^{pr2}$相关问题 更多 >
编程相关推荐