找到同一集合的两个分区之间的所有不同交点的一种简单而有效的方法

2024-06-25 23:22:45 发布

您现在位置:Python中文网/ 问答频道 /正文

我需要找到同一集合的两个分区之间的所有不同交点。例如,如果我们有相同集合的以下两个分区

x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]]

要求的结果是

[1],[2],[3],[4,5],[6,7],[8,9,10]

具体地说,我们计算x和y的每个子集之间的笛卡尔积,对于这些积中的每一个,我们相应地将元素分类到新的子集中,如果它们是否属于其相关子集的交集

做这件事的最佳方式是什么?提前谢谢


当前答案的性能比较:

import numpy as np

def partitioning(alist, indices):
    return [alist[i:j] for i, j in zip([0]+indices, indices+[None])]

total = 1000
sample1 = np.sort(np.random.choice(total, int(total/10), replace=False))
sample2 = np.sort(np.random.choice(total, int(total/2), replace=False))

a = partitioning(np.arange(total), list(sample1))
b = partitioning(np.arange(total), list(sample2))

def partition_decomposition_product_1(x, y):
    out = []
    for sublist1 in x:
        d = {}
        for val in sublist1:
            for i, sublist2 in enumerate(y):
                if val in sublist2:
                    d.setdefault(i, []).append(val)
        out.extend(d.values())
    return out

def partition_decomposition_product_2(x, y):
    all_s = []
    for sx in x:
        for sy in y:
            ss = list(filter(lambda x:x in sx, sy))
            if ss:
                all_s.append(ss)
    return all_s

def partition_decomposition_product_3(x, y):
    return [np.intersect1d(i,j) for i in x for j in y]

并使用%timeit测量执行时间

%timeit partition_decomposition_product_1(a, b)
%timeit partition_decomposition_product_2(a, b)
%timeit partition_decomposition_product_3(a, b)

我们发现

2.16 s ± 246 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
620 ms ± 84.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.03 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

因此,第二种解决方案是最快的


Tags: inloopforreturndefnpproductlist
3条回答

两个列表是同一集合的分区这一事实与算法选择无关。这归结为迭代两个列表并获得每个组合之间的交集(您可以在函数开头添加该断言,以确保它们是同一集合的分区,使用this answer有效地展平列表)。考虑到这一点,此函数通过使用this answer计算列表交点来完成任务:

def func2(x, y):
    # check that they partition the same set 
    checkx = sorted([item for sublist in x for item in sublist])
    checky = sorted([item for sublist in y for item in sublist])
    assert checkx == checky

    # get all intersections
    all_s = []
    for sx in x:
        for sy in y:
            ss = list(filter(lambda x:x in sx, sy))
            if ss:
                all_s.append(ss)
    return all_s

然后使用this time comparison method,我们可以看到这个新函数比原始实现快约100倍

我可能会错过一些细节,但似乎有点太简单了:

[np.intersect1d(a,b) for a in x for b in y]

输出:

[array([1]),
 array([2]),
 array([3]),
 array([4, 5]),
 array([6, 7]),
 array([ 8,  9, 10])]

上面包括重复项,例如x=[[1,2,3],[1,4,5]]y=[[1,6,7]]将给出[[1],[1]]


如果要查找唯一的交点,请执行以下操作:

[list(i) for i in {tuple(np.intersect1d(a,b)) for a in x for b in y}]

输出:

[[8, 9, 10], [6, 7], [1], [4, 5], [2], [3]]

我不确定我是否正确理解您,但此脚本生成了您在问题中的结果:

x = [[1, 2], [3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 3, 6, 7], [2, 4, 5, 8, 9, 10]]

out = []
for sublist1 in x:
    d = {}
    for val in sublist1:
        for i, sublist2 in enumerate(y):
            if val in sublist2:
                d.setdefault(i, []).append(val)
    out.extend(d.values())

print(out)

印刷品:

[[1], [2], [3], [4, 5], [6, 7], [8, 9, 10]]

相关问题 更多 >