下面的python代码将["hello", "world"]
传递到通用语句编码器,并返回一个浮点数数组,表示它们的编码表示
import tensorflow as tf
import tensorflow_hub as hub
module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))
这段代码可以工作,但我现在想使用JavaAPI做同样的事情。我已成功加载模块,但无法将输入传递到模型并提取输出。以下是我到目前为止得到的信息:
import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
public final class NaiveBayesClassifier
{
public static void main(String[] args)
{
new NaiveBayesClassifier().run();
}
protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
{
return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
}
public void run()
{
try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
{
Graph graph = module.graph();
try (Session session = new Session(graph, ConfigProto.newBuilder().
setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
setAllowSoftPlacement(true).
build().toByteArray()))
{
Tensor<String> input = Tensors.create(new byte[][]
{
"hello".getBytes(StandardCharsets.UTF_8),
"world".getBytes(StandardCharsets.UTF_8)
});
List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
addTarget("???").run();
}
}
catch (IOException e)
{
e.printStackTrace();
}
}
}
我使用https://stackoverflow.com/a/51952478/14731扫描模型以查找可能的输入/输出节点。我相信输入节点是“为默认输入提供服务”,但我不知道输出节点是什么。更重要的是,在通过Keras调用python中的代码时,我不必指定这些值中的任何一个,那么有没有一种方法可以使用javaapi进行同样的操作呢
更新:多亏了roywei,我现在可以确认输入节点是serving_default_input
,输出节点是StatefulPartitionedCall_1
,但当我将这些名称插入上述代码时,我得到:
2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
[[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
[[StatefulPartitionedCall_1/StatefulPartitionedCall]]
at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
at libtensorflow@1.15.0/org.tensorflow.Session.access$100(Session.java:48)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)
也就是说,我仍然无法调用该模型。我错过了什么
您可以使用Deep Java Library加载TF模型
详情见https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url
我是在roywei pointed me in the right direction之后发现的
SavedModuleBundle.session()
而不是构建自己的实例。这是因为加载程序初始化图形变量李>ConfigProto
传递给Session
构造函数,而是将其传递给SavedModelBundle
加载程序李>fetch()
而不是addTarget()
来检索输出张量李>以下是工作代码:
有两种方法可以获取名称:
1)使用Java:
您可以从保存的模型包中存储的
org.tensorflow.proto.framework.MetaGraphDef
中读取输入和输出名称下面是一个关于如何提取信息的示例:
https://github.com/awslabs/djl/blob/master/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java#L149
2)使用python:
在tensorflow python中加载保存的模型并打印名称
我建议看一下Deep Java Library,它会自动处理输入、输出名称。 它支持TensorFlow 2.1.0,允许您加载Keras模型以及TF Hub保存的模型。请看一下文件here和here
如果您在加载模型时遇到问题,请随时打开issue
相关问题 更多 >
编程相关推荐