比较两个以numpy矩阵为值的字典

2024-06-26 14:51:19 发布

您现在位置:Python中文网/ 问答频道 /正文

我想断言两个Python字典是相等的(这意味着:键的数量相等,并且从键到值的每个映射都是相等的;顺序并不重要)。一个简单的方法是assert A==B,但是,如果字典的值是numpy arrays,则这不起作用。如果两个字典相等,如何编写一个函数来检查一般情况?在

>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

编辑我知道应该检查numpy矩阵是否与.all()相等。我要找的是一种通用的检查方法,而不必检查isinstance(np.ndarray)。这有可能吗?在

没有numpy数组的相关主题:


Tags: 方法函数numpy数量字典顺序npassert
3条回答

考虑一下这个代码

>>> import numpy as np
>>> np.identity(5)
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2.,  1.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1.],
       [ 1.,  1.,  2.,  1.,  1.],
       [ 1.,  1.,  1.,  2.,  1.],
       [ 1.,  1.,  1.,  1.,  2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)
>>> 

注意比较的结果是一个矩阵,而不是一个布尔值。Dict比较将使用valuescmp方法比较值,这意味着在比较矩阵值时,Dict比较将得到一个复合结果。你想做的是使用 numpy.all将复合数组结果折叠为标量布尔结果

^{2}$

您需要编写自己的函数来比较这些字典,测试值类型以确定它们是否是矩阵,然后使用numpy.all进行比较,否则使用==。当然,如果你想的话,你也可以开始编写dict的子类化和重载cmp。在

我将回答隐藏在你问题标题和前半部分的半个问题,因为坦率地说,这是一个需要解决的更常见的问题,而现有的答案并没有很好地解决这个问题。这个问题是“如何比较numpy数组的两个dicts以获得相等性?在

问题的第一部分是“从远处”检查dicts:看看它们的键是一样的。如果所有键都相同,则第二部分将比较每个对应的值。在

现在微妙的问题是很多numpy数组不是整数值,而且double-precision is imprecise。因此,除非您有整数值(或其他非浮点型)数组,否则您可能需要检查这些值是否几乎相同,即在机器精度范围内。所以在本例中,您不会使用^{}(它检查精确的数值相等性),而是使用^{}(它对两个数组之间的相对和绝对误差使用有限的公差)。在

问题的前一个半部分很简单:检查dicts的键是否一致,并使用生成器理解来比较每个值(并在理解之外使用all来验证每个项是否相同):

import numpy as np

# some dummy data

# these are equal exactly
dct1 = {'a': np.array([2, 3, 4])}
dct2 = {'a': np.array([2, 3, 4])}

# these are equal _roughly_
dct3 = {'b': np.array([42.0, 0.2])}
dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])}  # still 0.2, right?

def compare_exact(first, second):
    """Return whether two dicts of arrays are exactly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.array_equal(first[key], second[key]) for key in first)

def compare_approximate(first, second):
    """Return whether two dicts of arrays are roughly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.allclose(first[key], second[key]) for key in first)

# let's try them:
print(compare_exact(dct1, dct2))  # True
print(compare_exact(dct3, dct4))  # False
print(compare_approximate(dct3, dct4))  # True

正如您在上面的例子中所看到的,整数数组比较精确,并且根据您正在做什么(或者如果幸运的话),它甚至可以用于float。但是如果你的浮动是任何一种算术的结果(比如线性变换?)你一定要用近似的支票。有关后一个选项的完整描述,请参见the docs of ^{}(及其元素方面的朋友^{}),特别是rtol和{}关键字参数。在

相关问题 更多 >