How do I load a TensorFlow graph created in a non-distributed environment into a distributed environment?

Posted on 2024-09-28 20:55:31


I am building a TensorFlow graph for a single-node run. If I later want to train the same model graph in a distributed environment (partitioning the variables across multiple parameter servers and replicating the graph across n workers), how can I do that?

I found graph.as_graph_def() for exporting the GraphDef proto, and tf.import_graph_def() for importing that graph, but this does not work.

Code:

import tensorflow as tf

graph = tf.Graph()

with graph.as_default():
    x_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="xin")
    y_place_holder = tf.placeholder(dtype=tf.float32, shape=[], name="yin")

    m = tf.Variable(10.0, name="varm")
    l = tf.Variable(20.0, name="varl")

    Y = tf.multiply(m, x_place_holder, name="mulop")
    X = tf.add(l, x_place_holder, name="addop")
    cost = tf.abs(Y - X, name="cost")

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5, name="optimizer").minimize(cost)

tf.reset_default_graph()

if FLAGS.job_name == "ps":
    server.join()

elif FLAGS.job_name == "worker":
    print(FLAGS.task_index, "task index")

    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        # Import the single-node graph under the device setter; this raises the error below.
        tf.import_graph_def(
            graph.as_graph_def(),
            return_elements=["xin", "yin", "varm", "varl",
                             "mulop", "addop", "cost", "optimizer"])
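
The snippet above refers to FLAGS, cluster and server, which are defined elsewhere; a typical TF 1.x setup looks roughly like this (the hosts and ports are placeholders, not the actual cluster configuration):

import tensorflow as tf

# Command-line flags selecting this process's role in the cluster.
flags = tf.app.flags
flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")
flags.DEFINE_integer("task_index", 0, "Index of the task within its job")
FLAGS = flags.FLAGS

# Placeholder addresses; replace with the real hosts/ports of the cluster.
cluster = tf.train.ClusterSpec({
    "ps": ["localhost:2222"],
    "worker": ["localhost:2223", "localhost:2224"],
})
server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)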

Stack trace:

Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1039, in _do_call
    return fn(*args)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1017, in _run_fn
    self._extend_graph()
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1066, in _extend_graph
    self._session, graph_def.SerializeToString(), status)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/contextlib.py", line 66, in __exit__
    next(self.gen)
  File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot colocate nodes 'import/varl/read' and 'import/varl': Cannot merge devices with incompatible jobs: '/job:ps/task:1' and '/job:worker/task:1'
     [[Node: import/varl/read = Identity[T=DT_FLOAT, _class=["loc:@import/varl"], _device="/job:worker/task:1"](import/varl)]]

Or does TensorFlow provide another way to do this?


Tags: name, import, task, lib, tf, tensorflow, library, job
1 Answer
#1 · Posted on 2024-09-28 20:55:31

As of June 2017, this is not supported. To train a model in a distributed environment, you can reuse the graph-building Python code (wrapping it in a replica device setter), but you cannot reuse the generated graph itself.
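
A minimal sketch of that approach, reusing the question's graph-building code as a function and assuming a FLAGS/cluster/server setup like the one shown under the question (the build_graph helper and the session loop are illustrative):

def build_graph():
    # Same graph-building code as the single-node version in the question.
    x = tf.placeholder(tf.float32, shape=[], name="xin")
    m = tf.Variable(10.0, name="varm")
    l = tf.Variable(20.0, name="varl")
    Y = tf.multiply(m, x, name="mulop")
    X = tf.add(l, x, name="addop")
    cost = tf.abs(Y - X, name="cost")
    train_op = tf.train.GradientDescentOptimizer(0.5, name="optimizer").minimize(cost)
    return x, cost, train_op

if FLAGS.job_name == "ps":
    server.join()
elif FLAGS.job_name == "worker":
    # Re-running the Python graph-building code under the device setter lets it
    # place variables on the ps job and the remaining ops on this worker.
    # Importing a pre-built GraphDef cannot do this, because the colocation
    # constraints recorded in the GraphDef conflict with the placements the
    # device setter assigns (the error shown in the question).
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        x, cost, train_op = build_graph()

    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(FLAGS.task_index == 0)) as sess:
        loss, _ = sess.run([cost, train_op], feed_dict={x: 3.0})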
