使用Tensorflow Java错误预测结果

2024-09-30 00:34:48 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用TensorFlow Java和Python做一些实验。 我使用Python训练并保存了一个MNIST模型,现在我想用Java加载它并进行预测。我有以下方法

public void predict() {
    final int IMAGE_INDEX = 10;
    MnistDataset dataset = MnistDataset.create(0, TRAINING_IMAGES_ARCHIVE, TRAINING_LABELS_ARCHIVE,
            TEST_IMAGES_ARCHIVE, TEST_LABELS_ARCHIVE);

    try (SavedModelBundle modelBundle = SavedModelBundle
            .load("PATH_TO_TRAINED_MODEL/mymnist");
         Tensor<TUint8> batchImages = TUint8.tensorOf(dataset.testBatch().images().get(IMAGE_INDEX))) {
        FloatDataBuffer floatDataBuffer = DataBuffers.ofFloats(784);
        float[] ds = new float[784];
        for (int i = 0; i < batchImages.rawData().size(); i++) {
            ds[i] = batchImages.rawData().getByte(i);
        }
        floatDataBuffer = floatDataBuffer.write(ds);

        try (Tensor<TFloat32> input = TFloat32.tensorOf(Shape.of(1, 28, 28, 1), floatDataBuffer);
             Tensor<TFloat32> tensor = modelBundle.session().runner()
                     .feed("serving_default_input", input)
                     .fetch("StatefulPartitionedCall")
                     .run().get(0).expect(TFloat32.DTYPE)) {
            System.out.print(argmax(tensor.data().get(0)));
            System.out.println("-" + dataset.testBatch().labels().get(IMAGE_INDEX).getByte());

        }
    }
}

Python中的模型如下所示

model = Sequential()
model.add(Input(shape=(28, 28, 1), dtype="float32", name="input"))
model.add(Conv2D(filters=32, kernel_size=(4, 4), input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
early_stopping = EarlyStopping(monitor='val_loss', patience=2)
history = model.fit(x_train, y_train_categorical, epochs=10, validation_data=(x_test, y_test_categorical),
                    callbacks=[early_stopping])

输入层是float32,但是MNIST图像是Tensor<UInt8>,所以我必须将它们转换为Tensor<TFloat32>,以便能够用Java提供模型。 我不确定这个转换是否是我想出的一个好主意(它让我烦透了;)

    FloatDataBuffer floatDataBuffer = DataBuffers.ofFloats(784);
    float[] ds = new float[784];
    for (int i = 0; i < batchImages.rawData().size(); i++) {
        ds[i] = batchImages.rawData().getByte(i);
    }
    floatDataBuffer = floatDataBuffer.write(ds);

    Tensor<TFloat32> input = TFloat32.tensorOf(Shape.of(1, 28, 28, 1), floatDataBuffer);

我在Java代码中得到的结果总是不正确的,但当我在Python中对原始模型调用predict时,它工作得很好。 有人能帮我解决这个问题吗

提前谢谢

PS.我用了MnistDataset


Tags: 模型addinputsizegetmodeldsjava

热门问题