Java TensorFlow Tensor Access
如何使用TensorFlow Java 0.2.0在张量中找到最大浮点值的索引
我对Tensorflow Java相当陌生,还不清楚NdArray和Tensor之间的关系。如果有更好的方法将我的列表转换为张量输入,请告诉我
从SavedModelBundle
中,我使用SessionRunner
检索了一个形状为[1,3]的浮点(1)张量,并向其输入张量
通缉行为如下
text=“喜欢它”
[0.12,0.2,0.68]的张量
。预测(文本)返回-->;二,
JAVA
public Integer predict(String text) {
// Returns tokenized text in a List of MAXLEN
List<Integer> token_ids = tokenize(text);
// Convert List to tensor for Input
IntDataBuffer bufferTokens = DataBuffers.ofInts(MAXLEN);
int[] primArr = new int[MAXLEN];
for (int i=0; i<MAXLEN; i++) {
primArr[i] = token_ids.get(i);
}
bufferTokens.write(primArr);
IntNdArray tokensMatrix = NdArrays.ofInts(Shape.of(1, MAXLEN));
IntNdArray vector = tokensMatrix.get(0);
vector.write(bufferTokens);
Tensor<TInt32> input = TInt32.tensorOf(tokensMatrix);
// Model.predict
Tensor output = model.session()
.runner()
.feed("serving_default_input_ids:0", input)
.fetch("StatefulPartitionedCall:0")
.run() // List<Tensor<?>>
.get(0);
// TODO - HELP NEEDED: Extract arg max from tensor
Tensor val = output.expect(TFloat32.DTYPE); // val = FLOAT (1) tensor with shape [1, 3]
Integer maxIndex = ????????
return maxIndex;
}
运行myModel.metaGraphDef().getSignatureDefMap().get("serving_default");
时的模型输入、输出信息如下
ModelInfo
inputs {
key: "input_ids"
value {
name: "serving_default_input_ids:0"
dtype: DT_INT32
tensor_shape {
dim {
size: -1
}
dim {
size: MAXLEN
}
}
}
}
outputs {
key: "dense_3"
value {
name: "StatefulPartitionedCall:0"
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 3
}
}
}
}
提前谢谢
共 (0) 个答案