Python-PyTorch:indexer错误:只有整数、片(`:`)、省略号(`…`)、numpy.newaxis(`None`)和整数或布尔数组是有效的索引

2024-05-06 22:47:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在与PyTorch一起处理与BERT的文本分类问题。这是我正在使用的PyTorch数据集格式,但是当我试图访问数据集的输入时,我得到了一个错误

PyTorch数据集

数据集返回一个包含以下内容的字典:idsmasktoken_type_idstargets

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"].values
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
        self.max_len = config.MAX_LEN
        self.langs = df["lang"].values
        self.train_transforms = train_transforms

    def __len__(self):
        return len(self.comment_text)

    def __getitem__(self, item):
        comment_text = str(self.comment_text[item])
        comment_text = " ".join(comment_text.split())
        lang = self.langs[item]
        
        if self.train_transforms:
            comment_text, _ = self.train_transforms(data=(comment_text, lang))['data']

        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
        )

        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        token_type_ids = inputs["token_type_ids"]

        data_loader_dict = {}
        data_loader_dict["ids"] = torch.tensor(ids, dtype=torch.long)
        data_loader_dict["mask"] = torch.tensor(mask, dtype=torch.long)
        data_loader_dict["token_type_ids"] = torch.tensor(token_type_ids, dtype=torch.long)
        data_loader_dict["targets"] = torch.tensor(self.target[item], dtype=torch.float)
        
        
        return data_loader_dict

给出错误的相关代码

在本例中,我尝试只加载1个示例,并使其成为PyTorch数据集的格式

df = pd.read_csv("dataset.csv")
df = df.head(1) # Trying with only 1 Sample
dataset = JigsawDataset(df)

ids = dataset["ids"]    # Error occurs at this line
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]

错误

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-78-4608dd623cac> in <module>
      3 dataset = JigsawDataset(df)
      4 
----> 5 ids = dataset["ids"]    # Error occurs at this line
      6 mask = dataset["mask"]
      7 token_type_ids = ["token_type_ids"]

<ipython-input-40-121d8aa71516> in __getitem__(self, item)
     13 
     14     def __getitem__(self, item):
---> 15         comment_text = str(self.comment_text[item])
     16         comment_text = " ".join(comment_text.split())
     17         lang = self.langs[item]

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

如何解决这个问题


Tags: 数据textselftokenidsdfdatatype
2条回答

我解决了这个问题

错误代码

ids = dataset["ids"]    
mask = dataset["mask"]
token_type_ids = ["token_type_ids"]

正确的代码

ids = dataset[0]["ids"]    
mask = dataset[0]["mask"]
token_type_ids = [0]["token_type_ids"]

问题是“id”、“mask”和“token\u type\u id”是字典键JigsawDataset为每个示例返回一个字典。因此,为了访问示例,我们需要在指定键之前指定索引([0]

根据the docs,panda的DataFrame对象的方法“values”返回一个numpy数组。
在代码中,将属性“self.comment\u text”设置为“df[“comment\u text”].values”返回的numpy数组(代码框1中的第3行)。
Numpy数组不接受字符串作为索引。
很难给你一个答案,我相信如果不测试它,它一定会工作,但我会先将属性“self.comment_text”设置为数据帧或其副本,而不仅仅是它所包含的值

我想改变这一点:

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"].values
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
.
.
.

为此:

class JigsawDataset:
    def __init__(self, df, train_transforms = None):
        self.comment_text = df["comment_text"]
        self.target = df["toxic"].values
        self.tokenizer = config.BERT_TOKENIZER
.
.
.

相关问题 更多 >