我正在使用TensorFlow Java和Python做一些实验。 我使用Python训练并保存了一个MNIST模型,现在我想用Java加载它并进行预测。我有以下方法
public void predict() {
final int IMAGE_INDEX = 10;
MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);
try (SavedModelBundle modelBundle = SavedModelBundle
.load("PATH_TO_TRAINED_MODEL/mymnist");
Tensor<TUint8> batchImages = TUint8.tensorOf(dataset.testBatch().images().get(IMAGE_INDEX))) {
FloatDataBuffer floatDataBuffer = DataBuffers.ofFloats(784);
float[] ds = new float[784];
for (int i = 0; i < batchImages.rawData().size(); i++) {
ds[i] = batchImages.rawData().getByte(i);
}
floatDataBuffer = floatDataBuffer.write(ds);
try (Tensor<TFloat32> input = TFloat32.tensorOf(Shape.of(1, 28, 28, 1), floatDataBuffer);
Tensor<TFloat32> tensor = modelBundle.session().runner()
.feed("serving_default_input", input)
.fetch("StatefulPartitionedCall")
.run().get(0).expect(TFloat32.DTYPE)) {
System.out.print(argmax(tensor.data().get(0)));
System.out.println("-" + dataset.testBatch().labels().get(IMAGE_INDEX).getByte());
}
}
}
Python中的模型如下所示
model = Sequential()
model.add(Input(shape=(28, 28, 1), dtype="float32", name="input"))
model.add(Conv2D(filters=32, kernel_size=(4, 4), input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
history = model.fit(x_train, y_train_categorical, epochs=10, validation_data=(x_test, y_test_categorical),
callbacks=[early_stopping])
输入层是float32
,但是MNIST图像是Tensor<UInt8>
,所以我必须将它们转换为Tensor<TFloat32>
,以便能够用Java提供模型。
我不确定这个转换是否是我想出的一个好主意(它让我烦透了;)
FloatDataBuffer floatDataBuffer = DataBuffers.ofFloats(784);
float[] ds = new float[784];
for (int i = 0; i < batchImages.rawData().size(); i++) {
ds[i] = batchImages.rawData().getByte(i);
}
floatDataBuffer = floatDataBuffer.write(ds);
Tensor<TFloat32> input = TFloat32.tensorOf(Shape.of(1, 28, 28, 1), floatDataBuffer);
我在Java代码中得到的结果总是不正确的,但当我在Python中对原始模型调用predict时,它工作得很好。 有人能帮我解决这个问题吗
提前谢谢
PS.我用了MnistDataset
目前没有回答
相关问题 更多 >
编程相关推荐