如何编写equals方法

2024-06-24 13:40:23 发布

您现在位置:Python中文网/ 问答频道 /正文

情境:我正试图很好地理解双链接结构。到目前为止,我对这些方法掌握得很好。我希望能够为这个类创建两个对象,并检查其中的每个项是否相等。我没有任何语法错误,我得到的错误是有点混乱。这是我到目前为止的情况。你知道吗

class LinkedList:
    class Node:
        def __init__(self, val, prior=None, next=None):
            self.val = val
            self.prior = prior
            self.next  = next

    def __init__(self):
        self.head = LinkedList.Node(None) # sentinel node (never to be removed)
        self.head.prior = self.head.next = self.head # set up "circular" topology
        self.length = 0

    def append(self, value):
        n = LinkedList.Node(value, prior=self.head.prior, next=self.head)
        n.prior.next = n.next.prior = n
        self.length += 1

    def _normalize_idx(self, idx):
        nidx = idx
        if nidx < 0:
            nidx += len(self)
            if nidx < -1:
                raise IndexError  
        return nidx

    def __getitem__(self, idx):
        """Implements `x = self[idx]`"""
        nidx = self._normalize_idx(idx)
        currNode = self.head.next
        for i in range(nidx):
            currNode = currNode.next
        if nidx >= len(self):
            raise IndexError
        return currNode.val


    def __setitem__(self, idx, value):
        """Implements `self[idx] = x`"""
        nidx = self._normalize_idx(idx)
        currNode = self.head.next
        if nidx >= len(self):
            raise IndexError
        for i in range(nidx):
            currNode = currNode.next
        currNode.val = value

    def __iter__(self):
        """Supports iteration (via `iter(self)`)"""
        cursor = self.head.next
        while cursor is not self.head:
            yield cursor.val
            cursor = cursor.next

    def __len__(self):
        """Implements `len(self)`"""
        return self.length

    def __eq__(self, other):
        currNode = self.head.next
        currNode2 = other.head.next
        for currNode, currNode2 in zip(self, other):
            if currNode.val != currNode2.val:
                return False
        return True

测试:

from unittest import TestCase
tc = TestCase()
lst = LinkedList()
lst2 = LinkedList()

tc.assertEqual(lst, lst2)

lst2.append(100)
tc.assertNotEqual(lst, lst2)

当我测试这段代码时,我得到了一个断言错误,它说“[] == [100]”,我不确定为什么我的代码会将其识别为相等,而我希望它实际检查节点中的特定值。你知道吗


Tags: selflenreturnifvaluedefvalhead
2条回答

zip只能到达最短的列表。你想要itertools.zip_longest,你不想要.val(你的迭代器已经返回了实际值)。试试这个:

def __eq__(self, other):
    for val1, val2 in zip_longest(self, other):
        if val1 != val2:
            return False
    return True

或者更好?你知道吗

def __eq__(self, other):
    return all(val1 == val2 for val1, val2 in zip_longest(self, other))

编辑

我喜欢@BrenBarn建议先检查长度。这里有一个更有效的答案:

def __eq__(self, other):
    return len(self) == len(other) and all(
        val1 == val2 for val1, val2 in zip(self, other))

zip(self.other)只提供两个列表中较短的元素。它将丢弃较长列表中多余的部分。因此,对于[] == [100]zip不提供任何元素,您的代码不检查任何内容就返回True。你知道吗

你可以在开始的时候做一个检查,看看列表是否有不同的长度。如果他们这样做了,他们就不可能平等。你知道吗

相关问题 更多 >