Speeding up a simple multidimensional counter

Posted 2024-10-01 17:37:49


This is the slow code:

def doCounts(maskA1, maskA2, maskA3, counts, maskB):
    counts[0, maskB & maskA1] += 1
    counts[1, maskB & maskA2] += 1
    counts[2, maskB & maskA3] += 1

Is there a way to do this in a single pass, or otherwise make it faster?
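For readers unfamiliar with the idiom, here is a minimal sketch of what one of those lines does. The array sizes and mask values are made up for illustration, and it assumes the masks are 1-D boolean arrays aligned with the second axis of `counts`.

```python
import numpy as np

# Hypothetical small inputs: 3 counter rows, 5 columns.
counts = np.zeros((3, 5), dtype=np.int64)
maskA1 = np.array([True, True, False, False, True])
maskB = np.array([True, False, False, True, True])

# Row 0 is incremented only where both masks are True (columns 0 and 4).
counts[0, maskB & maskA1] += 1

print(counts[0])  # [1 0 0 0 1]
```

Rows 1 and 2 are left untouched here; the original function repeats the same update for them with `maskA2` and `maskA3`.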


1 answer
Anonymous user
#1 · Posted 2024-10-01 17:37:49

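Vectorizing this may be difficult or impossible. The hint is that the advanced index on the second dimension, e.g. maskB & maskA1, can supply an arbitrary number of True values for each row. As a result, you can't isolate an m x n array to index.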
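To make that point concrete, here is a small sketch. The names `data`, `mask`, `per_row` and `selected` and all values are hypothetical, chosen only for illustration. A 2-D boolean mask can pick a different number of elements from each row, so the selection comes back as a flat 1-D array rather than a rectangular block.

```python
import numpy as np

# Hypothetical 2 x 4 data and a boolean mask with uneven True counts per row.
data = np.arange(8).reshape(2, 4)
mask = np.array([[True, False, True, True],
                 [False, True, False, False]])

per_row = mask.sum(axis=1)  # how many elements each row contributes
selected = data[mask]       # boolean indexing flattens the selection

print(per_row)   # [3 1]
print(selected)  # [0 2 3 5]
```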

A naive for loop compiled with Numba appears to improve performance by the factor shown in the timings below (roughly 7x):

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

import numpy as np
from numba import njit

@njit
def doCounts(maskA1, maskA2, maskA3, counts, maskB):
    # Combine each A-mask with maskB once, up front.
    mask1, mask2, mask3 = maskB & maskA1, maskB & maskA2, maskB & maskA3
    # Walk the columns; increment row k wherever the k-th combined mask is True.
    for j in range(counts.shape[1]):
        if mask1[j]:
            counts[0, j] += 1
        if mask2[j]:
            counts[1, j] += 1
        if mask3[j]:
            counts[2, j] += 1
    return counts

def doCounts_original(maskA1, maskA2, maskA3, counts, maskB):
    counts[0, maskB & maskA1] += 1
    counts[1, maskB & maskA2] += 1
    counts[2, maskB & maskA3] += 1
    return counts

n = 100
np.random.seed(0)
m1, m2, m3, mB = (np.random.randint(0, 2, n**3).astype(bool) for _ in range(4))
counts = np.random.randint(0, 100, (3, n**3))

# Both functions update counts in place, so give each call its own copy.
assert np.array_equal(doCounts(m1, m2, m3, counts.copy(), mB),
                      doCounts_original(m1, m2, m3, counts.copy(), mB))

%timeit doCounts(m1, m2, m3, counts, mB)           # 5.36 ms
%timeit doCounts_original(m1, m2, m3, counts, mB)  # 40.2 ms
