永久更新tensorflowjava中的变量(在推理过程中)

2024-09-30 02:32:44 发布

您现在位置:Python中文网/ 问答频道 /正文

我用pythontensorflow训练了一个模型,我想用javatensorflow进行推理。我已经将经过训练的模型/图形加载到Java中。在此之后,我想永久更新图形中的一个变量。我知道python中的tf.variable.load(value,session)函数可以用来更新变量的值。我想知道Java中是否有类似的方法。在

到目前为止,我已经尝试了以下几点。在

// g and s are loaded graphs and sessions respectively
s.runner().feed(variableName,updatedTensorValue)

但是在同一行中执行的fetch调用期间,上面的行只对updatedTensorValue使用variableName。在

^{pr2}$

上面的行没有更新值,而是尝试向图中添加相同的变量,因此引发异常。在

另一种永久更新图中变量的方法是,在所有fetch调用期间,我将始终调用feed(variableName,updatedTensorValue)方法。我将在几个实例上运行推理代码,因此我想知道这个额外的feed调用所需的额外时间。在

谢谢


Tags: and方法模型图形tffeedloadfetch
1条回答
网友
1楼 · 发布于 2024-09-30 02:32:44

在TensorFlow中做大多数事情的方法是执行一个操作。您尝试运行Assign操作是正确的,但是调用不正确,因为要分配的value不是Assign操作的“属性”,而是输入张量。(请参见原始的definition of the operation,尽管不可否认,除非您熟悉TensorFlow内部结构,否则定义可能不容易理解)。在

但是,您不需要在Java中向图中添加操作。相反,您可以做Python中^{}所做的事情——执行^{}操作,输入输入值。在

例如,考虑以下Python内置的图形:

import tensorflow as tf

var = tf.Variable(1.0, name='myvar')
init = tf.global_variables_initializer()

# Save the graph and write out the names of the operations of interest
tf.train.write_graph(tf.get_default_graph(), '/tmp', 'graph.pb', as_text=False)
print('Init all variables:         ', init.name)
print('myvar.initializer:          ', var.initializer.name)
print('myvar.initializer.inputs[1]:', var.initializer.inputs[1].name)

现在,我们复制Java中Python var.load()的行为,使用如下方法将值3.0赋给变量:

^{pr2}$

希望有帮助。在

相关问题 更多 >

    热门问题