我正在与PyTorch一起处理与BERT的文本分类问题。这是我正在使用的PyTorch数据集格式,但是当我试图访问数据集的输入时,我得到了一个错误
数据集返回一个包含以下内容的字典:ids
、mask
、token_type_ids
、targets
class JigsawDataset:
def __init__(self, df, train_transforms = None):
self.comment_text = df["comment_text"].values
self.target = df["toxic"].values
self.tokenizer = config.BERT_TOKENIZER
self.max_len = config.MAX_LEN
self.langs = df["lang"].values
self.train_transforms = train_transforms
def __len__(self):
return len(self.comment_text)
def __getitem__(self, item):
comment_text = str(self.comment_text[item])
comment_text = " ".join(comment_text.split())
lang = self.langs[item]
if self.train_transforms:
comment_text, _ = self.train_transforms(data=(comment_text, lang))['data']
inputs = self.tokenizer.encode_plus(
comment_text,
None,
add_special_tokens=True,
max_length=self.max_len,
pad_to_max_length=True,
truncation=True,
)
ids = inputs["input_ids"]
mask = inputs["attention_mask"]
token_type_ids = inputs["token_type_ids"]
data_loader_dict = {}
data_loader_dict["ids"] = torch.tensor(ids, dtype=torch.long)
data_loader_dict["mask"] = torch.tensor(mask, dtype=torch.long)
data_loader_dict["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long)
data_loader_dict["targets"] = torch.tensor(self.target[item], dtype=torch.float)
return data_loader_dict
在本例中,我尝试只加载1个示例,并使其成为PyTorch数据集的格式
df = pd.read_csv("dataset.csv")
df = df.head(1) # Trying with only 1 Sample
dataset = JigsawDataset(df)
ids = dataset["ids"] # Error occurs at this line
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-78-4608dd623cac> in <module>
3 dataset = JigsawDataset(df)
4
----> 5 ids = dataset["ids"] # Error occurs at this line
6 mask = dataset["mask"]
7 token_type_ids = ["token_type_ids"]
<ipython-input-40-121d8aa71516> in __getitem__(self, item)
13
14 def __getitem__(self, item):
---> 15 comment_text = str(self.comment_text[item])
16 comment_text = " ".join(comment_text.split())
17 lang = self.langs[item]
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices
如何解决这个问题
我解决了这个问题
错误代码
正确的代码
问题是“id”、“mask”和“token\u type\u id”是字典键
JigsawDataset
为每个示例返回一个字典。因此,为了访问示例,我们需要在指定键之前指定索引([0]
)根据the docs,panda的DataFrame对象的方法“values”返回一个numpy数组。
在代码中,将属性“self.comment\u text”设置为“df[“comment\u text”].values”返回的numpy数组(代码框1中的第3行)。
Numpy数组不接受字符串作为索引。
很难给你一个答案,我相信如果不测试它,它一定会工作,但我会先将属性“self.comment_text”设置为数据帧或其副本,而不仅仅是它所包含的值
我想改变这一点:
为此:
相关问题 更多 >
编程相关推荐