TypeError: '<' not supported between instances of 'function' and 'str'

Posted 2024-09-27 21:30:59


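I have built a Sequential model with a custom f1 score metric. I pass the metric when compiling the model and then save the model in *.hdf5 format. Whenever I load the model for testing using the custom_objects argument: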

model = load_model('app/model/test_model.hdf5', custom_objects={'f1':f1})

Keras throws the following error:

TypeError: '<' not supported between instances of 'function' and 'str'

Note: if I do not include the f1 metric when compiling, no error is raised and testing runs fine.

Training method

from metrics import f1

...

# GRU with glove embeddings and two dense layers
model = Sequential()
model.add(Embedding(len(word_index) + 1,
                    100,
                    weights=[embedding_matrix],
                    input_length=max_len,
                    trainable=False))
model.add(SpatialDropout1D(0.3))
model.add(GRU(100, dropout=0.3, recurrent_dropout=0.3, return_sequences=True))
model.add(GRU(100, dropout=0.3, recurrent_dropout=0.3))

model.add(Dense(1024, activation='relu'))
#model.add(Dropout(0.8))

model.add(Dense(1024, activation='relu'))
#model.add(Dropout(0.8))

model.add(Dense(2))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc', f1])

# Fit the model with early stopping callback
earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=3, verbose=0, mode='auto')
model.fit(xtrain_pad, y=ytrain_enc, batch_size=512, epochs=100, 
        verbose=1, validation_data=(xvalid_pad, yvalid_enc), callbacks=[earlystop])

model.save('app/model/test_model.hdf5')

Test method

from metrics import f1

...

model = load_model('app/model/test_model.hdf5', custom_objects={'f1':f1}) 
print(model.summary())

model.evaluate(xtest_pad, ytest_enc) # <-- error happens

Custom f1 metric

from keras import backend as K

def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        """Recall metric.

        Only computes a batch-wise average of recall.

        Computes the recall, a metric for multi-label classification of
        how many relevant items are selected.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        """Precision metric.

        Only computes a batch-wise average of precision.

        Computes the precision, a metric for multi-label classification of
        how many selected items are relevant.
        """
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision 
    
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))
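
To make the arithmetic of the metric above concrete, here is a minimal plain-Python sketch of the same batch-wise computation on flat lists. The name `batch_f1`, the `eps` default (chosen to match the usual Keras backend epsilon of 1e-7), and the sample values are assumptions for illustration only; the real metric operates on tensors through the Keras backend.

```python
def batch_f1(y_true, y_pred, eps=1e-7):
    """Batch-wise F1 on flat lists: y_true holds 0/1 labels, y_pred holds scores in [0, 1]."""
    def clip01(x):
        # Same role as K.clip(x, 0, 1)
        return min(max(x, 0.0), 1.0)

    # Counts follow the three K.sum(K.round(K.clip(...))) expressions of the metric
    true_positives = sum(round(clip01(t * p)) for t, p in zip(y_true, y_pred))
    possible_positives = sum(round(clip01(t)) for t in y_true)
    predicted_positives = sum(round(clip01(p)) for p in y_pred)

    precision = true_positives / (predicted_positives + eps)
    recall = true_positives / (possible_positives + eps)
    return 2 * (precision * recall) / (precision + recall + eps)


# Hypothetical batch: 3 actual positives, 3 predicted positives, 2 of them correct
score = batch_f1([1, 0, 1, 1], [0.9, 0.8, 0.2, 0.7])
print(round(score, 4))
```

With two of three predicted positives correct and two of three actual positives found, precision and recall are both about 2/3, so the F1 score is also about 2/3.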

EDIT

Testing

Preprocessing of the data used to evaluate the model

normalized_dataset = pd.read_pickle(DATA['preprocessed_test_path'])

lbl_enc = preprocessing.LabelEncoder()
y = lbl_enc.fit_transform(normalized_dataset.label.values)

xtest = normalized_dataset.preprocessed_tweets.values
ytest_enc = np_utils.to_categorical(y)

token = text.Tokenizer(num_words=None)
max_len = 70

token.fit_on_texts(list(xtest))
xtest_seq = token.texts_to_sequences(xtest)
xtest_pad = sequence.pad_sequences(xtest_seq, maxlen=max_len)

EDIT2

Here is the full traceback of the error:

Traceback (most recent call last):
  File "app/main.py", line 67, in <module>
    main()
  File "app/main.py", line 64, in main
    test(embedding_dict)
  File "/Users/justauser/Desktop/sentiment-analysis/app/test.py", line 50, in test
    model.evaluate(xtest_pad, ytest_enc)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1389, in evaluate
    tmp_logs = self.test_function(iterator)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 725, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3196, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in user code:

    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:1233 test_function  *
        return step_function(self, iterator)
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:1224 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:1219 run_step  **
        with ops.control_dependencies(_minimum_control_deps(outputs)):
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:2793 _minimum_control_deps
        outputs = nest.flatten(outputs, expand_composites=True)
    /Users/justauser/Desktop/sentiment-analysis/env/lib/python3.8/site-packages/tensorflow/python/util/nest.py:341 flatten
        return _pywrap_utils.Flatten(structure, expand_composites)

    TypeError: '<' not supported between instances of 'function' and 'str'
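
The last frames of the traceback above are in `nest.flatten`, which flattens dictionaries in sorted key order. One plausible reading, not confirmed anywhere in this thread, is that after loading, the step outputs are keyed partly by strings and partly by the metric function object itself, and sorting those keys fails. The plain-Python sketch below reproduces that kind of failure; the `outputs` dictionary and its values are hypothetical stand-ins, not actual Keras internals.

```python
def f1(y_true, y_pred):
    """Stand-in for the custom metric; only the fact that it is a function matters here."""
    return 0.0


# Hypothetical results dictionary whose keys mix strings with a function object
outputs = {"loss": 0.42, "acc": 0.9, f1: 0.8}

message = None
try:
    # Sorting the keys forces a '<' comparison between a function and a str
    sorted(outputs)
except TypeError as exc:
    message = str(exc)

print(message)
```

If that reading is right, it would also explain why compiling again with an explicit metrics list makes the error go away: the metric is then registered under a string name.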


3 answers

This solved my problem:

The in-memory model and the model restored with load_model give the same scores

scores = model.evaluate(X, y, verbose=2, batch_size = batch_size)

After load_model(), compiling the model again with the custom metric should make it work.

So, after loading the model from disk with

model = load_model('app/model/test_model.hdf5', custom_objects={'f1':f1}) 

make sure to compile it with the metrics you are interested in:

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc', f1])
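
A variant of the two steps above, as a minimal sketch for the evaluation-only case: `load_model` accepts `compile=False`, which skips restoring the saved loss, optimizer, and metrics, so the model can be compiled afresh right afterwards. The helper name `load_for_evaluation` and its parameters are hypothetical, and whether this avoids the error in every TensorFlow version is an assumption that was not tested in this thread.

```python
def load_for_evaluation(model_path, metric_fn):
    """Load a saved Keras model and compile it afresh, for evaluation only."""
    # Imported lazily so that defining this helper does not itself require TensorFlow
    from tensorflow.keras.models import load_model

    # compile=False skips restoring the saved loss, optimizer, and metrics
    model = load_model(model_path, compile=False)

    # Same settings as in the training script; this creates a fresh optimizer,
    # so any saved optimizer state is not carried over
    model.compile(loss="categorical_crossentropy", optimizer="adam",
                  metrics=["acc", metric_fn])
    return model
```

It would be called as `model = load_for_evaluation('app/model/test_model.hdf5', f1)`, followed by `model.evaluate(xtest_pad, ytest_enc)`. Because the optimizer is created anew, this sketch is not meant for resuming training.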

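As @Zaccharie Ramzi pointed out, the accepted answer is not suitable if you want to load the model to resume training, because the compile step resets the optimizer state (it is fine if you only want to evaluate or test the model). If you want to load the model to resume training, you can work around this by recompiling the loaded model with its own loaded loss and optimizer: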

model = load_model('app/model/test_model.hdf5', custom_objects={'f1':f1}) 
model.compile(loss=model.loss, optimizer=model.optimizer, metrics=['acc', f1])

For reference, see the GitHub issue where this solution was first posted.
