为什么当Python Tensorflow会话没有重置状态时,Java Tensorflow会话似乎会重置状态?

2024-06-25 23:45:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在Linux上通过1.4javaapi构建和评估TensorFlow图。我注意到,每次调用会话.运行()已制作完成。这种行为似乎与Python中发生的情况不匹配。我最后的问题是如何避免这种明显的行为?在

Python示例

这里的示例是Python代码(也使用1.4api),它以标量张量的形式递增值。在

>>> import tensorflow as tf
>>> x = tf.get_variable("x", [], dtype=tf.float32, initializer=tf.zeros_initializer)
>>> step = tf.constant(1.0)
>>> xUpdateOp = x.assign_add(step)
>>> s = tf.Session()
>>> s.run(x.initializer)
>>> x.eval(s)
0.0
>>> s.run(xUpdateOp)
1.0
>>> x.eval(s)
1.0
>>> s.run(xUpdateOp)
2.0
>>> x.eval(s)
2.0
>>>

注意,正如预期的那样,计算x会给出它的当前值,而使用session运行xUpdateOp会导致x变大1。在

Java示例

这是我尝试使用Java来构建一个增加标量张量的张量流图。javaapi中的初始化不同,因为它缺少一些Pythons方便的方法。在

^{pr2}$

以上代码片段的输出

1.0

但我希望它是4.0,因为我在xUpdateOp上调用了4次run()。即使我落后1.0也不是我所期望的。在

问题

为了获得与Python示例相同的行为,我需要如何处理这个Java示例?如何让xUpdateOp使用在上一次run()调用中计算的x值?在

我已经试过了

我已经尝试过使用feed()函数来输入一个x值

try(Session s = new Session(g)) {
    try(Tensor<Float> x1 = s.runner().fetch(xUpdateOp.name()).run().get(0).expect(Float.class)) {
        s.runner().feed(xUpdateOp.name(), 0, x1);
        try (Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
            System.out.println(result.floatValue());
        }
    }
}

结果

1.0

我还尝试在没有addTarget或fetch()的情况下调用run(),认为addTarget或fetch()是导致状态重置的原因。也许一旦一个会话知道要运行什么,它就可以运行多次。在

try(Session s = new Session(g)) {
    s.runner().addTarget(xUpdateOp).run();
    s.runner().run();
    s.runner().run();

    try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
        System.out.println(result.floatValue());
    }
}

结果

Exception in thread "main" java.lang.IllegalArgumentException: Must specify at least one target to fetch or execute.
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at org.tensorflow.examples.Example.doCounting(MandelbrotExample.java:80)
at org.tensorflow.examples.Example.main(MandelbrotExample.java:50)
ERROR: Non-zero return code '1' from command: Process exited with status 1.

有点相关的问题

How to create/initialize a Variable with Tensorflow 1.0 Java API

java tensorflow reset_default_graph

Java - train loaded tensorflow model

提前感谢您的时间!在


Tags: runorg示例getsessiontftensorflowfetch
1条回答
网友
1楼 · 发布于 2024-06-25 23:45:22

在您的示例中,xUpdateOpx作为其输入,x是将{}分配给变量的操作的输出。因此,每次xUpdateOp运行时,它首先给变量赋值0。在

稍微调整一下代码,就会得到4.0:

# Changed addInput(x) to addInput(xVar)
Operation xUpdateOp =
    g.opBuilder("AssignAdd", "x_get_x_plus_step").addInput(xVar).addInput(step).build();

try (Session s = new Session(g)) {
  # Initialize the variable once
  s.runner().addTarget(x.op()).run();
  s.runner().addTarget(xUpdateOp).run();
  s.runner().addTarget(xUpdateOp).run();
  s.runner().addTarget(xUpdateOp).run();

  try (Tensor<Float> result =
       s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
    System.out.println(result.floatValue());
  }                     
}

与Python代码并行:上面的Java代码片段更像问题中的Python代码。而问题中的Java代码更像以下Python中的代码:

^{pr2}$

所以tf.assign_add(x, step)tf.assign_add(xVar, step)之间的差别会很大。在前者中,AssignAdd操作应用于Assign操作的输出。在

希望有帮助。在

相关问题 更多 >