在Python3中:列表的奇怪行为(iterables)

2024-10-02 10:25:22 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个关于python中iterables行为的具体问题。我的iterable是pytorch中的自定义数据集类:

import torch
from torch.utils.data import Dataset
class datasetTest(Dataset):
    def __init__(self, X):
        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, x):
        print('***********')
        print('getitem x = ', x)
        print('###########')
        y = self.X[x]
        print('getitem y = ', y)
        return y

当我初始化datasetest类的一个特定实例时,就会出现这种奇怪的行为。根据作为参数X传递的数据结构,调用list(datasetTestInstance)时,它的行为会有所不同。特别是,当经过张量作为参数是没有问题的,但是当作为参数传递dict时,它会抛出一个KeyError。原因是list(iterable)不仅调用i=0,…,len(iterable)-1,还调用i=0,…,len(iterable)。也就是说,它将迭代直到(包含)索引等于iterable的长度。显然,这个索引在任何python数据结构中都没有定义,因为最后一个元素总是索引len(datastructure)-1,而不是len(datastructure)。如果X是张量或者一个列表,不会出现错误,即使我认为应该是一个错误。即使对于索引为len(datasetTestinstance)的(不存在的)元素,它仍然会调用getitem,但它不会计算y=self.X[len(datasetTestinstance]。有人知道pytorch是否在内部优雅地处理了这个问题吗?你知道吗

当将dict作为数据传递时,它将在最后一次迭代中抛出一个错误,即x=len(datasetTestInstance)。我想这实际上是预期的行为。但为什么这种情况只发生在一个dict上,而不发生在一个list或list上呢火炬张量?你知道吗

if __name__ == "__main__":
    a = datasetTest(torch.randn(5,2))
    print(len(a))
    print('++++++++++++')
    for i in range(len(a)):
        print(i)
        print(a[i])
    print('++++++++++++')
    print(list(a))

    print('++++++++++++')
    b = datasetTest({0: 12, 1:35, 2:99, 3:27, 4:33})
    print(len(b))
    print('++++++++++++')
    for i in range(len(b)):
        print(i)
        print(b[i])
    print('++++++++++++')
    print(list(b))

如果您想更好地理解我所观察到的内容,可以尝试一下这段代码。你知道吗

我的问题是:

1.)为什么list(iterable)迭代到(包括)len(iterable)?for循环不会这样做。你知道吗

2.)如果是张量或者一个作为数据X传递的列表:为什么它在调用索引len(datasetTestInstance)的getitem方法时不抛出一个错误,因为它没有在tensor/list中定义为索引,所以实际应该超出范围?或者,换句话说,当达到索引len(datasetistinstance)然后进入getitem方法时,到底发生了什么?它显然不再调用“y=self.X[X]”(否则会有索引器),但它确实进入了getitem方法,我可以看到它从getitem方法中打印索引X。那么这种方法会发生什么呢?为什么它的行为会因是否有torch.tensor/列表还是口述?你知道吗


Tags: 方法self列表forlendef错误torch
2条回答

一堆有用的链接:

  1. [Python 3.Docs]: Data model - Emulating container types
  2. [Python 3.Docs]: Built-in Types - Iterator Types
  3. [Python 3.Docs]: Built-in Functions - iter(object[, sentinel])
  4. [SO]: Why does list ask about __len__?(所有答案)

关键的一点是list构造函数使用(iterable)参数的\uu len\ueem>((如果提供)来计算新的容器长度),但随后对其进行迭代(通过迭代器协议)。你知道吗

您的示例是以这种方式工作的(迭代了所有键,但未能找到与字典长度相等的键),因为发生了一个可怕的巧合(请记住,dict支持迭代器协议,并且这种情况发生在它的键(这是一个序列)上):

  • 您的字典只有int键(以及更多)
  • 它们的值与它们的索引相同(按顺序)

改变上述两个项目符号所表示的任何条件,都会使实际错误更具说服力。你知道吗

两个对象(dictlist(属于tensors))都支持迭代器协议。为了使事情正常工作,您应该将它包装在数据集类中,并稍微调整映射类型(使用值而不是键)。
代码(key\u func相关部分)有点复杂,但只是易于配置(如果您想更改某些内容-出于演示目的)。你知道吗

代码00.py:

#!/usr/bin/env python3

import sys
import torch
from torch.utils.data import Dataset
from random import randint


class SimpleDataset(Dataset):

    def __init__(self, x):
        self.__iter = None
        self.x = x

    def __len__(self):
        print("    __len__()")
        return len(self.x)

    def __getitem__(self, key):
        print("    __getitem__({0:}({1:s}))".format(key, key.__class__.__name__))
        try:
            val = self.x[key]
            print("    {0:}".format(val))
            return val
        except:
            print("    exc")
            raise #IndexError

    def __iter__(self):
        print("    __iter__()")
        self.__iter = iter(self.x)
        return self

    def __next__(self):
        print("    __next__()")
        if self.__iter is None:
            raise StopIteration
        val = next(self.__iter)
        if isinstance(self.x, (dict,)):  # Special handling for dictionaries
            val = self.x[val]
        return val


def key_transformer(int_key):
    return str(int_key)  # You could `return int_key` to see that it also works on your original example


def dataset_example(inner, key_func=None):
    if key_func is None:
        key_func = lambda x: x
    print("\nInner object: {0:}".format(inner))
    sd = SimpleDataset(inner)
    print("Dataset length: {0:d}".format(len(sd)))
    print("\nIterating (old fashion way):")
    for i in range(len(sd)):
        print("  {0:}: {1:}".format(key_func(i), sd[key_func(i)]))
    print("\nIterating (Python (iterator protocol) way):")
    for element in sd:
        print("  {0:}".format(element))
    print("\nTry building the list:")
    l = list(sd)
    print("  List: {0:}\n".format(l))


def main():
    dict_size = 2

    for inner, func in [
        (torch.randn(2, 2), None),
        ({key_transformer(i): randint(0, 100) for i in reversed(range(dict_size))}, key_transformer),  # Reversed the key order (since Python 3.7, dicts are ordered), to test int keys
    ]:
        dataset_example(inner, key_func=func)


if __name__ == "__main__":
    print("Python {0:s} {1:d}bit on {2:s}\n".format(" ".join(item.strip() for item in sys.version.split("\n")), 64 if sys.maxsize > 0x100000000 else 32, sys.platform))
    main()
    print("\nDone.")

输出

[cfati@CFATI-5510-0:e:\Work\Dev\StackOverflow\q059091544]> "e:\Work\Dev\VEnvs\py_064_03.07.03_test0\Scripts\python.exe" code00.py
Python 3.7.3 (v3.7.3:ef4ec6ed12, Mar 25 2019, 22:22:05) [MSC v.1916 64 bit (AMD64)] 64bit on win32


Inner object: tensor([[ 0.6626,  0.1107],
        [-0.1118,  0.6177]])
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(int))
    tensor([0.6626, 0.1107])
  0: tensor([0.6626, 0.1107])
    __getitem__(1(int))
    tensor([-0.1118,  0.6177])
  1: tensor([-0.1118,  0.6177])

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  tensor([0.6626, 0.1107])
    __next__()
  tensor([-0.1118,  0.6177])
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [tensor([0.6626, 0.1107]), tensor([-0.1118,  0.6177])]


Inner object: {'1': 86, '0': 25}
    __len__()
Dataset length: 2

Iterating (old fashion way):
    __len__()
    __getitem__(0(str))
    25
  0: 25
    __getitem__(1(str))
    86
  1: 86

Iterating (Python (iterator protocol) way):
    __iter__()
    __next__()
  86
    __next__()
  25
    __next__()

Try building the list:
    __iter__()
    __len__()
    __next__()
    __next__()
    __next__()
  List: [86, 25]


Done.

您可能还需要检查[PyTorch]: SOURCE CODE FOR TORCH.UTILS.DATA.DATASETIterableDataset)。你知道吗

这并不是pytorch特有的问题,而是一个一般的python问题。你知道吗

您正在使用list(iterable)构建一个列表,其中iterable类是实现sequence semantics的类。你知道吗

在这里看一下^{}对于序列类型的预期行为(大多数相关部分用粗体表示)

object.__getitem__(self, key)

Called to implement evaluation of self[key]. For sequence types, the accepted keys should be integers and slice objects. Note that the special interpretation of negative indexes (if the class wishes to emulate a sequence type) is up to the __getitem__() method. If key is of an inappropriate type, TypeError may be raised; if of a value outside the set of indexes for the sequence (after any special interpretation of negative values), IndexError should be raised. For mapping types, if key is missing (not in the container), KeyError should be raised.

Note: for loops expect that an IndexError will be raised for illegal indexes to allow proper detection of the end of the sequence.

这里的问题是,对于序列类型,python期望在使用无效索引调用__getitem__时使用IndexError。似乎list构造函数依赖于此行为。在您的示例中,当X是dict时,尝试访问无效键会导致__getitem__引发KeyError,而这不是预期的,因此不会被捕获并导致列表的构造失败。你知道吗


根据这些信息,你可以做如下的事情

class datasetTest:
    def __init__(self):
        self.X = {0: 12, 1:35, 2:99, 3:27, 4:33}

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if index < 0 or index >= len(self):
            raise IndexError
        return self.X[index]

d = datasetTest()
print(list(d))

我不建议在实践中这样做,因为它依赖于只包含整数键的字典X01len(X)-1,这意味着在大多数情况下,它的行为就像一个列表,所以您最好只使用一个列表。你知道吗

相关问题 更多 >

    热门问题