桶迭代器不返回正确大小的批次

2024-10-02 08:17:51 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在PyTorch中实现一个简单的LSTM语言模型,并想查看torchtext提供的BucketIterator。在

结果是返回的批处理的大小与我整个语料库的大小相同,所以在初始化它的过程中,我一定是做错了什么。在

我已经让BPTTIterator工作了,但是由于我希望能够训练成批完整的句子,我认为BucketIterator应该是一种方法。在

我使用下面的设置,我的语料库是一个简单的txt文件,每行都包含句子。在

field = Field(use_vocab=True, batch_first=True)
corpus = PennTreebank('project_2_data/train_lines.txt', field)
field.build_vocab(corpus)

iterator = BucketIterator(corpus,
                          batch_size=64,
                          repeat=False,
                          sort_key=lambda x: len(x.text),
                          sort_within_batch=True,
                          )

我希望来自这个迭代器的批处理具有(batch_size, max_len)的形状,但是它将整个语料库追加到形状(1, corpus_size)的1个张量中。在

我的设置中遗漏了什么?在

编辑:似乎PennTreebank对象与BucketIterator不兼容(它只包含1Example,如这里所述http://mlexplained.com/2018/02/15/language-modeling-tutorial-in-torchtext-practical-torchtext-part-2/)。使用一个只有1FieldTabularDataset可以使它正常工作。在

如果有人知道如何以更优雅的方式使用填充语句批处理语言建模,我很乐意听到!在


Tags: txt语言truefieldsizelenbatchcorpus

热门问题