TensorFlow 2.0:在镜像策略下创建局部变量

2024-10-03 02:37:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在实现内存转换器,我需要在对模型的调用之间保持memory。我已经尝试使用tf.Variable来完成这个任务,它在一个GPU上就可以完美地工作

但是在多个GPU上的MirroredStrategy下,这种方法失败了,因为MirroredStrategy希望同步写入多个副本上的变量。这在我的例子中是不需要的,我需要在每个“塔”上创建一组个人内存变量,就像TransformerXL实现中一样

我想我可以使用with tf.device():来创建这些变量,但我不确定如何在build方法中获取当前副本的设备


Tags: 方法内存buildgpudevicetfwith副本