TensorFlow Keras model cannot be loaded into Deeplearning4J because of a custom layer

Posted 2024-09-29 00:19:00


I trained a TensorFlow Keras model in Python and saved it as trained_model.h5:

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

model = tf.keras.Sequential([
    # Frozen MobileNetV2 feature extractor loaded from TF Hub
    hub.KerasLayer("https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4",
                   output_shape=[1280],
                   trainable=False),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Dense(train_generator.num_classes, activation='softmax')
])

model.build([None, 224, 224, 3])
model.summary()

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics = ['acc'])

steps_per_epoch = np.ceil(train_generator.samples /train_generator.batch_size)
val_steps_per_epoch = np.ceil(valid_generator.samples/valid_generator.batch_size)
hist = model.fit(
    train_generator,
    epochs=3,
    verbose=1,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_generator,
    validation_steps=val_steps_per_epoch).history

# Save architecture and weights in HDF5 format
model.save('trained_model.h5')
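The step counts above round `samples / batch_size` up, so a final partial batch still counts as one step. A minimal sketch of the same arithmetic with integer ceiling division (the helper name `steps_per_epoch` is hypothetical; note that `np.ceil` returns a float, while this version returns an int):

```python
def steps_per_epoch(num_samples: int, batch_size: int) -> int:
    """Batches needed to see every sample once; the last batch may be partial."""
    # Integer ceiling division: same value as math.ceil(num_samples / batch_size), but an int
    return -(-num_samples // batch_size)
```

For example, with a hypothetical 1000 samples and a batch size of 32, 1000 / 32 = 31.25, so `steps_per_epoch(1000, 32)` gives 32: 31 full batches plus one partial batch of 8 samples.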

Then I want to load the trained model into my Java application using Deeplearning4J version 1.0.0-beta7:

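import java.io.File;

import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.api.ndarray.INDArray;

public static void main(String[] args) throws Exception
{
  // Input size the model was built with in Python: [None, 224, 224, 3]
  int height = 224, width = 224, channels = 3;

  String modelPath = new ClassPathResource("trained_model.h5").getFile().getPath();
  MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(modelPath);

  File imageFile = new ClassPathResource("image.png").getFile();
  ImageLoader imageLoader = new ImageLoader(height, width, channels);
  INDArray image = imageLoader.asRowVector(imageFile);

  INDArray output = model.output(image);
}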

However, this results in the following error:

Exception in thread "main" org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException: Unsupported keras layer type KerasLayer. Please file an issue at https://github.com/eclipse/deeplearning4j/issues.
    at org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils.getKerasLayerFromConfig(KerasLayerUtils.java:334)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.prepareLayers(KerasModel.java:218)
    at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:110)
    at org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel.<init>(KerasSequentialModel.java:57)
    at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildSequential(KerasModelBuilder.java:322)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasSequentialModelAndWeights(KerasModelImport.java:223)

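I believe this is related to the custom KerasLayer I use for the mobilenet_v2 feature vector. Is there a way to work around this so that the model can be loaded in a Java application?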
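One way to confirm which layer the importer rejects is to list the layer class names recorded in the .h5 file, since Keras stores the architecture as a JSON string in the file's `model_config` attribute. This is a minimal sketch, assuming the file was written by `model.save(...)` in HDF5 format and that `h5py` is installed; the helper names and the `KNOWN_SUPPORTED` set are hypothetical and illustrative only (a small subset of layer types in this model, not DL4J's full list of supported layers):

```python
import json

# Illustrative subset only: layer types in this model that DL4J's Keras import handles.
# This is NOT the importer's full support list.
KNOWN_SUPPORTED = {"InputLayer", "Dropout", "Dense"}


def layer_class_names(model_config_json):
    """Return the class_name of every layer in a Keras Sequential config JSON string."""
    inner = json.loads(model_config_json)["config"]
    # TF2-era Keras stores {"name": ..., "layers": [...]}; older Keras stored the list directly
    layers = inner["layers"] if isinstance(inner, dict) else inner
    return [layer["class_name"] for layer in layers]


def read_model_config(h5_path):
    """Read the JSON model config that Keras writes as an attribute of the HDF5 file."""
    import h5py  # imported lazily so the parsing helpers work without h5py installed
    with h5py.File(h5_path, "r") as f:
        raw = f.attrs["model_config"]
    # Depending on the h5py version, the attribute comes back as bytes or str
    return raw.decode("utf-8") if isinstance(raw, bytes) else raw


def unsupported_layers(model_config_json, supported=KNOWN_SUPPORTED):
    """Layer class names not found in the given set of supported names."""
    return [name for name in layer_class_names(model_config_json) if name not in supported]
```

Calling `unsupported_layers(read_model_config("trained_model.h5"))` on the model above should list `KerasLayer`, the same type named in the exception, while `Dropout` and `Dense` pass.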

