如何在PyTorch中找到对文本分类模型的标签结果负责的(最重要的)负责词/标记/嵌入

2024-05-19 00:00:06 发布

您现在位置:Python中文网/ 问答频道 /正文

让我们假设我有一个这样的模型:

class BERT_Subject_Classifier(nn.Module):

    def __init__(self,out_classes,hidden1=128,hidden2=32,dropout_val=0.2):
      super(BERT_Subject_Classifier, self).__init__()

      self.hidden1 = hidden1
      self.hidden2 = hidden2
      self.dropout_val = dropout_val
      self.logits = logit
      self.bert = AutoModel.from_pretrained('bert-base-uncased')
      self.out_classes = out_classes
      self.unfreeze_n = unfreeze_n # make the last n layers trainable
      
      self.dropout = nn.Dropout(self.dropout_val)
      self.relu =  nn.ReLU()
      self.fc1 = nn.Linear(768,self.hidden1)
      self.fc2 = nn.Linear(self.hidden1,self.hidden2)
      self.fc3 = nn.Linear(self.hidden2,self.out_classes)

    def forward(self, sent_id, mask):
      _, cls_hs = self.bert(sent_id, attention_mask=mask)
      x = self.fc1(cls_hs)
      x = self.relu(x)
      x = self.dropout(x)
      x = self.fc2(x)
      x = self.dropout(x)
      return self.fc3(x)

我训练我的模型,对于一个新的数据点x = ['My Name is Slim Shady'],我得到的标签结果是3

我的问题是,我如何检查句子中的哪些单词负责分类?我的意思是它可以是任何单词的集合。是否有一个库或方法来检查功能?正如本文和{}中的{a1}所示,您可以获得模型关注的图像区域。我怎样才能完成这篇课文


Tags: 模型selfmaskvalnnoutclassesdropout
1条回答
网友
1楼 · 发布于 2024-05-19 00:00:06

当然。证明哪些单词影响最大的一种方法是综合梯度法。对于PyTorch,您可以使用的一个软件包是Captum。我想在这个页面上找到一个很好的例子:https://captum.ai/tutorials/IMDB_TorchText_Interpret

对于Tensorflow,您可以使用的一个包是Seldon。我想在这一页上找到一个很好的例子: https://docs.seldon.io/projects/alibi/en/stable/examples/integrated_gradients_imdb.html

相关问题 更多 >

    热门问题