Deeplearning4j: Expected model class name Model (found Functional) (InvalidKerasConfigurationException)

Posted 2024-09-29 00:19:23


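I am trying to import a machine learning model trained in Python into Java (a Maven project) using Deeplearning4j. The model is a functional model built with tf.keras. But whenever I do exactly what the documentation tells me to, I get the error shown below. For completeness, I have added my code.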

Python model:

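import tensorflow as tf
from tensorflow.keras.layers import (Input, Masking, Bidirectional, LSTM,
                                     RepeatVector, Dropout, TimeDistributed, Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model

# timesteps (the sequence length) is defined elsewhere in my script

# Create a MirroredStrategy.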
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

class lstm_bottleneck(tf.keras.layers.Layer):
    def __init__(self, lstm_units, time_steps, **kwargs):
        # initialize the base Layer first, then attach attributes and sub-layers
        super(lstm_bottleneck, self).__init__(**kwargs)
        self.lstm_units = lstm_units
        self.time_steps = time_steps
        self.lstm_layer = Bidirectional(LSTM(lstm_units, return_sequences=False))
        self.repeat_layer = RepeatVector(time_steps)
    
    def call(self, inputs):
        # just call the two initialized layers
        return self.repeat_layer(self.lstm_layer(inputs))
    
    def compute_mask(self, inputs, mask=None):
        # return the input_mask directly
        return mask

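    def get_config(self):
        # include the constructor arguments so the layer can be re-created from its config
        cfg = super().get_config()
        cfg.update({"lstm_units": self.lstm_units, "time_steps": self.time_steps})
        return cfg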

with strategy.scope():
    
  inp1 = Input(shape=(timesteps, 7), name="inp1")
  mask1 = Masking(mask_value=-1.)(inp1)

  enc = Bidirectional(LSTM(55, activation = 'tanh', return_sequences = True, dropout = 0.1, kernel_regularizer=l2(0.01)))(mask1)
  enc = Dropout(0.2)(enc)
  enc = Bidirectional(LSTM(50, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(enc)
  enc = Dropout(0.1)(enc)

  decode = lstm_bottleneck(lstm_units=45, time_steps=timesteps)(enc)

  decode = Bidirectional(LSTM(50, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(decode)
  decode = Dropout(0.2)(decode)
  decode = Bidirectional(LSTM(55, activation = 'tanh', return_sequences = True, kernel_regularizer=l2(0.01)))(decode)
  decode = TimeDistributed(Dense(6, activation="softmax"), name="dec1")(decode)
      
  new_model = Model(inputs=inp1, outputs = decode)
  new_model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005), metrics=['categorical_accuracy'])
  plot_model(new_model, to_file='model.png')
  new_model.summary()

pom.xml:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>sampleProject</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-modelimport</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native-platform</artifactId>
            <version>1.0.0-beta6</version>
        </dependency>
    </dependencies>

</project>

Java code:

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.nd4j.linalg.io.ClassPathResource;

// Loading the model
String fullModel = new ClassPathResource("val_loss_model.h5").getFile().getPath();
thisModel = KerasModelImport.importKerasModelAndWeights(fullModel);

Error:

SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException: Expected model class name Model (found Functional). For more information, see http://deeplearning4j.org/docs/latest/keras-import-overview
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:133)
    at org.deeplearning4j.nn.modelimport.keras.KerasModel.<init>(KerasModel.java:96)
    at org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder.buildModel(KerasModelBuilder.java:307)
    at org.deeplearning4j.nn.modelimport.keras.KerasModelImport.importKerasModelAndWeights(KerasModelImport.java:172)
    at MachineLearningModel.<init>(MachineLearningModel.java:21)
    at SimulatedAnnealing.Optimize(SimulatedAnnealing.java:8)
    at Main.main(Main.java:33)
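
As far as I can tell, the name in the message ("Model" vs. "Functional") refers to the top-level class_name field of the model configuration that Keras stores in the saved file. Below is a minimal sketch for checking that field. The helper names model_class_name and read_model_config_from_h5 are hypothetical, the sample JSON is a stand-in rather than the contents of my real file, and the second helper assumes the model was saved in HDF5 format with a model_config attribute readable through h5py.

```python
import json


def model_class_name(model_config_json):
    """Return the top-level class_name from a Keras model config JSON string."""
    config = json.loads(model_config_json)
    return config["class_name"]


def read_model_config_from_h5(path):
    """Read the raw model_config attribute from a Keras HDF5 file (needs h5py)."""
    import h5py  # imported lazily so model_class_name works without h5py installed

    with h5py.File(path, "r") as f:
        raw = f.attrs["model_config"]
    # Depending on the h5py version, the attribute comes back as bytes or str.
    return raw.decode("utf-8") if isinstance(raw, bytes) else raw


# Stand-in config string (an assumption, not my real file), shaped like what Keras writes.
sample = '{"class_name": "Functional", "config": {"name": "model"}}'
print(model_class_name(sample))
```

For the real file, the call would be model_class_name(read_model_config_from_h5("val_loss_model.h5")), which shows the class name the importer is comparing against.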
