我试图在Linux上通过1.4javaapi构建和评估TensorFlow图。我注意到,每次调用会话.运行()已制作完成。这种行为似乎与Python中发生的情况不匹配。我最后的问题是如何避免这种明显的行为?在
这里的示例是Python代码(也使用1.4api),它以标量张量的形式递增值。在
>>> import tensorflow as tf
>>> x = tf.get_variable("x", [], dtype=tf.float32, initializer=tf.zeros_initializer)
>>> step = tf.constant(1.0)
>>> xUpdateOp = x.assign_add(step)
>>> s = tf.Session()
>>> s.run(x.initializer)
>>> x.eval(s)
0.0
>>> s.run(xUpdateOp)
1.0
>>> x.eval(s)
1.0
>>> s.run(xUpdateOp)
2.0
>>> x.eval(s)
2.0
>>>
注意,正如预期的那样,计算x会给出它的当前值,而使用session运行xUpdateOp会导致x变大1。在
这是我尝试使用Java来构建一个增加标量张量的张量流图。javaapi中的初始化不同,因为它缺少一些Pythons方便的方法。在
^{pr2}$以上代码片段的输出
1.0
但我希望它是4.0,因为我在xUpdateOp上调用了4次run()。即使我落后1.0也不是我所期望的。在
为了获得与Python示例相同的行为,我需要如何处理这个Java示例?如何让xUpdateOp使用在上一次run()调用中计算的x值?在
我已经尝试过使用feed()函数来输入一个x值
try(Session s = new Session(g)) {
try(Tensor<Float> x1 = s.runner().fetch(xUpdateOp.name()).run().get(0).expect(Float.class)) {
s.runner().feed(xUpdateOp.name(), 0, x1);
try (Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)) {
System.out.println(result.floatValue());
}
}
}
结果
1.0
我还尝试在没有addTarget或fetch()的情况下调用run(),认为addTarget或fetch()是导致状态重置的原因。也许一旦一个会话知道要运行什么,它就可以运行多次。在
try(Session s = new Session(g)) {
s.runner().addTarget(xUpdateOp).run();
s.runner().run();
s.runner().run();
try(Tensor<Float> result = s.runner().fetch(xUpdateOp.name(), 0).run().get(0).expect(Float.class)){
System.out.println(result.floatValue());
}
}
结果
Exception in thread "main" java.lang.IllegalArgumentException: Must specify at least one target to fetch or execute.
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at org.tensorflow.examples.Example.doCounting(MandelbrotExample.java:80)
at org.tensorflow.examples.Example.main(MandelbrotExample.java:50)
ERROR: Non-zero return code '1' from command: Process exited with status 1.
How to create/initialize a Variable with Tensorflow 1.0 Java API
java tensorflow reset_default_graph
Java - train loaded tensorflow model
提前感谢您的时间!在
在您的示例中,}分配给变量的操作的输出。因此,每次
xUpdateOp
将x
作为其输入,x
是将{xUpdateOp
运行时,它首先给变量赋值0。在稍微调整一下代码,就会得到4.0:
与Python代码并行:上面的Java代码片段更像问题中的Python代码。而问题中的Java代码更像以下Python中的代码:
^{pr2}$所以
tf.assign_add(x, step)
与tf.assign_add(xVar, step)
之间的差别会很大。在前者中,AssignAdd
操作应用于Assign
操作的输出。在希望有帮助。在
相关问题 更多 >
编程相关推荐