TypeError in torch.argmax() when trying to find the token with the highest "start" score

Posted 2024-10-04 11:22:24


I want to run this code for question answering with Hugging Face Transformers:

import torch
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer

#Model
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

#Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question = '''Why was the student group called "the Methodists?"'''

paragraph = ''' The movement which would become The United Methodist Church began in the mid-18th century within the Church of England.
            A small group of students, including John Wesley, Charles Wesley and George Whitefield, met on the Oxford University campus.
            They focused on Bible study, methodical study of scripture and living a holy life.
            Other students mocked them, saying they were the "Holy Club" and "the Methodists", being methodical and exceptionally detailed in their Bible study, opinions and disciplined lifestyle.
            Eventually, the so-called Methodists started individual societies or classes for members of the Church of England who wanted to live a more religious life. '''
            
encoding = tokenizer.encode_plus(text=question,text_pair=paragraph)

inputs = encoding['input_ids']  #Token IDs
sentence_embedding = encoding['token_type_ids']  #Segment (token type) IDs
tokens = tokenizer.convert_ids_to_tokens(inputs) #input tokens

start_scores, end_scores = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(start_scores)

But I get this error on the last line:

Exception has occurred: TypeError
argmax(): argument 'input' (position 1) must be Tensor, not str
  File "D:\bert\QuestionAnswering.py", line 33, in <module>
    start_index = torch.argmax(start_scores)

I don't know what's wrong. Can anyone help me?


2 answers

Hugging Face Transformers provides a simple, high-level way to run the model, as shown below:

from transformers import pipeline

nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)
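print(nlp(question=question, context=paragraph, top_k=5))

top_k lets you return several of the highest-scoring answers (older versions of transformers named this parameter topk; it is now deprecated in favor of top_k).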

BertForQuestionAnswering returns a QuestionAnsweringModelOutput object.

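Because you unpack the output of BertForQuestionAnswering into start_scores, end_scores, the returned QuestionAnsweringModelOutput object is treated like a dictionary: iterating over it yields its keys, so you get the tuple of strings ('start_logits', 'end_logits'). start_scores is therefore the string 'start_logits', which causes the type-mismatch error in torch.argmax().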
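To see why the unpacking produces strings, here is a minimal sketch that needs no model download. It assumes only that the output object behaves like an OrderedDict with attribute access; FakeQAOutput is a hypothetical stand-in written for this illustration, not the real transformers class, and the logits are made-up numbers.

```python
from collections import OrderedDict


class FakeQAOutput(OrderedDict):
    """Hypothetical stand-in for QuestionAnsweringModelOutput (not the real class)."""

    def __getattr__(self, name):
        # Mimic attribute access such as outputs.start_logits.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


# Made-up scores, purely for illustration.
outputs = FakeQAOutput(start_logits=[0.1, 2.5, 0.3], end_logits=[0.2, 0.4, 3.1])

# Tuple unpacking iterates over the mapping, which yields its KEYS, not its values.
start_scores, end_scores = outputs
print(type(start_scores).__name__, start_scores)  # str start_logits

# Accessing the field by name returns the actual scores.
print(outputs.start_logits)  # [0.1, 2.5, 0.3]
```

Alternatively, passing return_dict=False to the model call makes it return a plain tuple, in which case the original start_scores, end_scores unpacking works as intended.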

The following should work:

outputs = model(input_ids=torch.tensor([inputs]), token_type_ids=torch.tensor([sentence_embedding]))

start_index = torch.argmax(outputs.start_logits)
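Once the start index is found, the usual next step is to take the end index the same way and join the tokens in between. The sketch below does this in plain Python so it runs without torch; extract_answer is a hypothetical helper, and the tokens and logits are invented for illustration. With the real model you would presumably call it as extract_answer(tokens, outputs.start_logits[0].tolist(), outputs.end_logits[0].tolist()). It picks start and end independently by argmax and does not exclude question tokens or limit answer length, which the pipeline handles more carefully.

```python
def extract_answer(tokens, start_logits, end_logits):
    """Hypothetical helper: greedy span extraction from start/end scores."""
    # Plain-Python argmax: index of the largest score.
    start = max(range(len(start_logits)), key=start_logits.__getitem__)
    end = max(range(len(end_logits)), key=end_logits.__getitem__)
    if end < start:
        return ""  # no valid span
    pieces = []
    for tok in tokens[start:end + 1]:
        # WordPiece continuation tokens start with "##"; glue them to the previous word.
        if tok.startswith("##") and pieces:
            pieces[-1] += tok[2:]
        else:
            pieces.append(tok)
    return " ".join(pieces)


# Invented example data in the shape BERT produces for a question/paragraph pair.
tokens = ["[CLS]", "why", "was", "[SEP]", "being", "method", "##ical", "[SEP]"]
start_logits = [0.0, 0.1, 0.2, 0.0, 5.0, 1.0, 0.5, 0.0]
end_logits = [0.0, 0.0, 0.1, 0.0, 0.2, 1.0, 4.0, 0.0]

answer = extract_answer(tokens, start_logits, end_logits)
print(answer)  # being methodical
```

The tokenizer's convert_tokens_to_string method can be used instead of the manual "##" merging loop.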
