用另一个数组过滤numpy数组的最快方法是什么?

2024-10-03 02:46:31 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个相当大的np.arraya(10000-50000个元素,每个坐标(x,y))和另一个更大的np.arrayb(100000-200000个坐标)。我需要尽快删除a中不在b中的元素,只保留b中存在的a元素。所有坐标都是整数。例如:

a = np.array([[2,5],[6,3],[4,2],[1,4]])
b = np.array([[2,7],[4,2],[1,5],[6,3]])

期望输出:

a

>> [6,3],[4,2]

对于我提到的大小的阵列,最快的方法是什么

除了Numpy中的解决方案外,我还可以使用任何其他包或导入(例如,转换为基本Pythonlist or set,使用Pandas,等等)的解决方案


Tags: or方法numpy元素pandasnp整数解决方案
2条回答

这似乎在很大程度上取决于数组大小和“稀疏性”(可能是由于哈希表的魔力)

来自Get intersecting rows across two 2D numpy arrays的答案是so_8317022函数

外卖似乎(在我的机器上)是:

  • Pandas方法具有较大稀疏集的优势
  • 集合交集非常非常快,数组大小很小(尽管它返回的是集合,而不是numpy数组)
  • 另一个Numpy答案可以比设置较大数组大小的交集更快
from collections import defaultdict

import numpy as np
import pandas as pd
import timeit
import matplotlib.pyplot as plt


def pandas_merge(a, b):
    return pd.DataFrame(a).merge(pd.DataFrame(b)).to_numpy()


def set_intersection(a, b):
    return set(map(tuple, a.tolist())) & set(map(tuple, b.tolist()))


def so_8317022(a, b):
    nrows, ncols = a.shape
    dtype = {
        "names": ["f{}".format(i) for i in range(ncols)],
        "formats": ncols * [a.dtype],
    }
    C = np.intersect1d(a.view(dtype), b.view(dtype))
    return C.view(a.dtype).reshape(-1, ncols)


def test_fn(f, a, b):
    number, time_taken = timeit.Timer(lambda: f(a, b)).autorange()
    return number / time_taken


def test(size, max_coord):
    a = np.random.default_rng().integers(0, max_coord, size=(size, 2))
    b = np.random.default_rng().integers(0, max_coord, size=(size, 2))
    return {fn.__name__: test_fn(fn, a, b) for fn in (pandas_merge, set_intersection, so_8317022)}


series = []
datas = defaultdict(list)

for size in (100, 1000, 10000, 100000):
    for max_coord in (50, 500, 5000):
        print(size, max_coord)
        series.append((size, max_coord))
        for fn, result in test(size, max_coord).items():
            datas[fn].append(result)

print("size", "sparseness", "func", "ops/sec")
for fn, values in datas.items():
    for (size, max_coord), value in zip(series, values):
        print(size, max_coord, fn, int(value))

我机器上的结果是

^{tb1}$

不确定这是否是最快的方法,但如果将其转换为熊猫索引,则可以使用其交集方法。由于它在后台使用低级c代码,交叉步骤可能非常快,但将其转换为熊猫索引可能需要一些时间

import numpy as np
import pandas as pd

a = np.array([[2, 5], [6, 3], [4, 2], [1, 4]])
b = np.array([[2, 7], [4, 2], [1, 5], [6, 3]])

df_a = pd.DataFrame(a).set_index([0, 1])
df_b = pd.DataFrame(b).set_index([0, 1])
intersection = df_a.index.intersection(df_b.index)

结果是这样的

print(intersection.values)
[(6, 3) (4, 2)]

编辑2:

出于好奇,我对两种方法进行了比较。现在有了更大的索引列表。我将我的第一个索引方法与稍微改进的方法进行了比较,该方法不需要首先创建数据帧,但会立即创建索引,然后再与提出的数据帧合并方法进行比较

这是密码

from random import randint, seed
import time
import numpy as np
import pandas as pd

seed(0)

n_tuple = 100000
i_min = 0
i_max = 10
a = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
b = [[randint(i_min, i_max), randint(i_min, i_max)] for _ in range(n_tuple)]
np_a = np.array(a)
np_b = np.array(b)


def method0(a_array, b_array):
    index_a = pd.DataFrame(a_array).set_index([0, 1]).index
    index_b = pd.DataFrame(b_array).set_index([0, 1]).index
    return index_a.intersection(index_b).to_numpy()


def method1(a_array, b_array):
    index_a = pd.MultiIndex.from_arrays(a_array.T)
    index_b = pd.MultiIndex.from_arrays(b_array.T)
    return index_a.intersection(index_b).to_numpy()


def method2(a_array, b_array):
    df_a = pd.DataFrame(a_array)
    df_b = pd.DataFrame(b_array)
    return df_a.merge(df_b).to_numpy()


def method3(a_array, b_array):
    set_a = {(_[0], _[1]) for _ in a_array}
    set_b = {(_[0], _[1]) for _ in b_array}
    return set_a.intersection(set_b)


for cnt, intersect in enumerate([method0, method1, method2, method3]):
    t0 = time.time()
    if cnt < 3:
        intersection = intersect(np_a, np_b)
    else:
        intersection = intersect(a, b)
    print(f"method{cnt}: {time.time() - t0}")

输出如下所示:

method0: 0.1439347267150879
method1: 0.14012742042541504
method2: 4.740894317626953
method3: 0.05933070182800293

结论:数据帧合并方法(方法2)比在索引上使用交集慢50倍左右。基于多索引(method1)的版本只比method0(我的第一个建议)稍微快一点

EDIT2:@AKX的评论建议:如果您不使用numpy,而是使用纯列表和集合,那么您可以再次获得大约3倍的速度提升。但很明显,您不应该使用合并方法

相关问题 更多 >