有 Java 编程相关的问题?

你可以在下面搜索框中键入要查询的问题!

Java TensorFlow Tensor Access

如何使用TensorFlow Java 0.2.0在张量中找到最大浮点值的索引

我对Tensorflow Java相当陌生,还不清楚NdArray和Tensor之间的关系。如果有更好的方法将我的列表转换为张量输入,请告诉我

SavedModelBundle中,我使用SessionRunner检索了一个形状为[1,3]的浮点(1)张量,并向其输入张量

通缉行为如下

text=“喜欢它”

[0.12,0.2,0.68]的张量

。预测(文本)返回-->;二,


JAVA

public Integer predict(String text) {

        // Returns tokenized text in a List of MAXLEN
        List<Integer> token_ids = tokenize(text);

        // Convert List to tensor for Input
        IntDataBuffer bufferTokens = DataBuffers.ofInts(MAXLEN);
        int[] primArr = new int[MAXLEN];
        for (int i=0; i<MAXLEN; i++) {
            primArr[i] = token_ids.get(i);
        }
        bufferTokens.write(primArr);

        IntNdArray tokensMatrix = NdArrays.ofInts(Shape.of(1, MAXLEN));
        IntNdArray vector = tokensMatrix.get(0);
        vector.write(bufferTokens);

        Tensor<TInt32> input = TInt32.tensorOf(tokensMatrix);

        // Model.predict
        Tensor output = model.session()
                .runner()
                .feed("serving_default_input_ids:0", input)
                .fetch("StatefulPartitionedCall:0")
                .run() // List<Tensor<?>>
                .get(0);

        // TODO - HELP NEEDED: Extract arg max from tensor
        Tensor val = output.expect(TFloat32.DTYPE); // val = FLOAT (1) tensor with shape [1, 3]
        Integer maxIndex = ????????

        return maxIndex;
    }

运行myModel.metaGraphDef().getSignatureDefMap().get("serving_default");时的模型输入、输出信息如下

 ModelInfo
    inputs {
        key: "input_ids"
        value {
            name: "serving_default_input_ids:0"
            dtype: DT_INT32
            tensor_shape {
                dim {
                    size: -1
                }
                dim {
                    size: MAXLEN
                }
            }
        }
    }
    outputs {
        key: "dense_3"
        value {
            name: "StatefulPartitionedCall:0"
            dtype: DT_FLOAT
            tensor_shape {
                dim {
                    size: -1
                }
                dim {
                    size: 3
                }
            }
        }
    }

提前谢谢


共 (0) 个答案