How do I call a model from TensorFlow Java?

Published 2024-10-01 22:39:18


The following Python code passes ["hello", "world"] to the Universal Sentence Encoder and returns an array of floats representing their encoded representations:

import tensorflow as tf
import tensorflow_hub as hub

module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))

This code works, but now I want to do the same thing using the Java API. I have successfully loaded the module, but I can't figure out how to pass inputs into the model and extract the output. Here is what I have so far:

import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            Graph graph = module.graph();
            try (Session session = new Session(graph, ConfigProto.newBuilder().
                setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                setAllowSoftPlacement(true).
                build().toByteArray()))
            {
                Tensor<String> input = Tensors.create(new byte[][]
                    {
                        "hello".getBytes(StandardCharsets.UTF_8),
                        "world".getBytes(StandardCharsets.UTF_8)
                    });
                List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
                    addTarget("???").run();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
}
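For reference, the byte[][] fed to Tensors.create above is just the UTF-8 encoding of each string in the batch: TF Java 1.x represents a 1-D DT_STRING tensor as one byte array per element. The conversion can be sketched independently of TensorFlow (the class and method names below are illustrative, not part of any API):

```java
import java.nio.charset.StandardCharsets;

public class StringTensorBytes {
    // Converts a batch of strings into the byte[][] form that
    // Tensors.create(byte[][]) expects for a 1-D DT_STRING tensor.
    static byte[][] toUtf8Bytes(String... values) {
        byte[][] result = new byte[values.length][];
        for (int i = 0; i < values.length; i++)
            result[i] = values[i].getBytes(StandardCharsets.UTF_8);
        return result;
    }

    public static void main(String[] args) {
        byte[][] batch = toUtf8Bytes("hello", "world");
        System.out.println(batch.length);                                 // 2
        System.out.println(new String(batch[0], StandardCharsets.UTF_8)); // hello
    }
}
```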

I used https://stackoverflow.com/a/51952478/14731 to scan the model for possible input/output nodes. I believe the input node is "serving_default_inputs", but I can't figure out the output node. More importantly, I don't have to specify any of these values when invoking the code in Python through Keras, so is there a way to do the same using the Java API?

Update: Thanks to roywei, I can now confirm that the input node is serving_default_input and the output node is StatefulPartitionedCall_1, but when I plug these names into the above code I get:

2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
     [[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall]]
    at libtensorflow@1.15.0/org.tensorflow.Session.run(Native Method)
    at libtensorflow@1.15.0/org.tensorflow.Session.access$100(Session.java:48)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at libtensorflow@1.15.0/org.tensorflow.Session$Runner.run(Session.java:276)

In other words, I still can't invoke the model. What am I missing?


3 Answers

You can load the TF model using Deep Java Library:

System.setProperty("ai.djl.repository.zoo.location", "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/1.tar.gz?artifact_id=encoder");

Criteria<NDList, NDList> criteria =
        Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optArtifactId("ai.djl.localmodelzoo:encoder")
                .build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);

See https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url for details.

I figured this out after roywei pointed me in the right direction:

  • I needed to use SavedModelBundle.session() instead of constructing my own instance. This is because the loader initializes the graph variables.
  • Instead of passing a ConfigProto to the Session constructor, I passed it to the SavedModelBundle loader.
  • I needed to use fetch() instead of addTarget() to retrieve the output tensor.

Here is the working code:

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            try (Tensor<String> input = Tensors.create(new byte[][]
                {
                    "hello".getBytes(StandardCharsets.UTF_8),
                    "world".getBytes(StandardCharsets.UTF_8)
                }))
            {
                MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
                Map<String, Shape> nameToInput = getInputToShape(metadata);
                String firstInput = nameToInput.keySet().iterator().next();

                Map<String, Shape> nameToOutput = getOutputToShape(metadata);
                String firstOutput = nameToOutput.keySet().iterator().next();

                System.out.println("input: " + firstInput);
                System.out.println("output: " + firstOutput);
                System.out.println();

                List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                    fetch(firstOutput).run();
                for (Tensor<?> tensor : result)
                {
                    // Size the output array from the tensor's actual shape
                    // (e.g. [2, 512] for two 512-dimensional embeddings)
                    long[] shape = tensor.shape();
                    float[][] array = new float[(int) shape[0]][(int) shape[1]];
                    tensor.copyTo(array);
                    System.out.println(Arrays.deepToString(array));
                }
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Loads a graph from a file.
     *
     * @param source the directory containing the model to load
     * @param tags   the model variant(s) to load
     * @return the graph
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if an error occurs while reading the file
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        // https://stackoverflow.com/a/43526228/14731
        try
        {
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature, or null
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.get(nameToSignature.keySet().iterator().next());
    }

    /**
     * @param metadata the graph metadata
     * @return the serving signature, falling back to the first signature
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        return metadata.getSignatureDefOrDefault("serving_default", getFirstSignature(metadata));
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output name to its shape
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getOutputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input name to its shape
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getInputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }
}
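The getInputToShape/getOutputToShape helpers above boil down to splitting the proto's dimension list into a first dimension plus the rest, which is the argument form Shape.make expects. That core transformation can be sketched without any TensorFlow types, using plain longs with -1 standing for an unknown dimension as in TensorShapeProto (the class and method names here are illustrative only):

```java
import java.util.Arrays;
import java.util.List;

public class ShapeSplit {
    // First dimension of the shape (often -1, i.e. an unknown batch size).
    static long first(List<Long> dims) {
        return dims.get(0);
    }

    // Remaining dimensions as a long[], mirroring the varargs
    // accepted by Shape.make(first, remaining...).
    static long[] remaining(List<Long> dims) {
        return dims.stream().skip(1).mapToLong(Long::longValue).toArray();
    }

    public static void main(String[] args) {
        // A typical embedding output: unknown batch size, 512 features.
        List<Long> dims = Arrays.asList(-1L, 512L);
        System.out.println(first(dims));                      // -1
        System.out.println(Arrays.toString(remaining(dims))); // [512]
    }
}
```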

There are two ways to get the names:

1) Using Java:

You can read the input and output names from the org.tensorflow.proto.framework.MetaGraphDef stored in the saved model bundle.

Here is an example of how to extract that information:

https://github.com/awslabs/djl/blob/master/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java#L149

2) Using Python:

Load the saved model in TensorFlow Python and print the names:

loaded = tf.saved_model.load("path/to/model/")
print(list(loaded.signatures.keys()))
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)

I suggest taking a look at Deep Java Library, which handles input and output names automatically. It supports TensorFlow 2.1.0 and lets you load Keras models as well as TF Hub saved models. Take a look at the documentation here and here.

If you run into problems loading the model, feel free to open an issue.
