令牌分类的Tensorflow BERT在培训和测试时将PadToken排除在准确性之外

2024-09-29 21:33:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用tensorflow的预训练BERT模型进行基于标记的分类,以自动标记句子中的因果关系

为了访问BERT,我使用huggingface:https://huggingface.co/transformers/model_doc/bert.html#tfbertfortokenclassification中的TFBertForTokenClassification接口

我用来训练的句子都根据BERT标记器转换成标记(基本上是单词到数字的映射),然后在训练前填充到一定长度,因此,当一个句子只有50个标记,而另一个句子只有30个标记时,第一个句子用50个pad标记填充,第二个句子用70个pad标记填充,以获得100的通用输入句子长度

然后,我训练我的模型预测每个标记,该标记属于哪个标记;无论是原因的一部分,还是结果的一部分

然而,在培训和评估期间,我的模型也会对PAD标记进行预测,并且它们也包含在模型的准确性中。由于PAD标记对于模型来说非常容易预测(它们总是有相同的标记,并且它们都有“无”标签,这意味着它们既不属于句子的原因也不属于句子的结果),它们确实扭曲了我的模型的准确性

例如,如果您有一个包含30个单词的句子->;30个标记,你将所有句子的长度填充到100,那么这个句子将得到70%的分数,即使模型没有正确预测任何“真实”标记。 通过这种方式,我可以很快获得90%以上的培训和验证准确率,尽管该模型在真实的pad令牌上表现不佳

我以为注意力面具可以解决这个问题,但事实似乎并非如此

输入数据集的创建如下所示:

def example_to_features(input_ids,attention_masks,token_type_ids,label_ids):
  return {"input_ids": input_ids,
          "attention_mask": attention_masks},label_ids

train_ds = tf.data.Dataset.from_tensor_slices((input_ids_train,attention_masks_train,token_ids_train,label_ids_train)).map(example_to_features).shuffle(buffer_size=1000).batch(32)

模型创建:

from transformers import TFBertForTokenClassification

num_epochs = 30

model = TFBertForTokenClassification.from_pretrained('bert-base-uncased', num_labels=3)

model.layers[-1].activation = tf.keras.activations.softmax

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

model.summary()

然后我就这样训练它:

history = model.fit(train_ds, epochs=num_epochs, validation_data=validate_ds)

到目前为止,是否有人遇到过这个问题,或者知道如何在训练和评估期间将pad代币上的预测从模型的准确性中排除


Tags: from标记模型idsinputmodeltftrain
1条回答
网友
1楼 · 发布于 2024-09-29 21:33:51

是的,这很正常

BERT[batch_size, max_seq_len = 100, hidden_size]的输出也将包括[PAD]令牌的值或嵌入。但是,您还为BERT模型提供了attention_masks,以便它不考虑这些[PAD]令牌

类似地,在将BERT结果传递到最终完全连接的层之前,您需要屏蔽这些[PAD]标记,在计算损失时屏蔽它们,以及计算精度和召回率等指标

相关问题 更多 >

    热门问题