Efficient weighted sums over regions with numpy

Posted 2024-05-07 03:40:10


I have an index matrix that defines regions, like this:

0 0 0 0 1 1 1
0 0 0 1 1 1 1
0 0 1 1 1 1 2
0 1 1 1 1 1 2
2 2 2 2 2 2 2
3 3 3 3 3 3 3

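I have another matrix of weights with the same shape. I want to compute a weighted sum for each region. This was my first attempt (the code block was not preserved in this copy).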
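Since the original first attempt is missing, here is a minimal sketch of what a loop-based version might look like. It is an assumption, not the asker's code: the function name `weighted_centroids_loop` is hypothetical, and the uniform weights are chosen only to make the example easy to check by hand.

```python
import numpy as np

def weighted_centroids_loop(indices, weights):
    """Hypothetical loop-based version: visit every cell once and
    accumulate per-region weighted sums of row and column coordinates."""
    num_regions = int(indices.max()) + 1  # assumes labels run from 0 to max
    x_sum = np.zeros(num_regions)
    y_sum = np.zeros(num_regions)
    d_sum = np.zeros(num_regions)
    rows, cols = indices.shape
    for i in range(rows):
        for j in range(cols):
            r = indices[i, j]
            w = weights[i, j]
            x_sum[r] += w * i  # weighted row coordinate
            y_sum[r] += w * j  # weighted column coordinate
            d_sum[r] += w      # total weight of the region
    return x_sum / d_sum, y_sum / d_sum

# The index matrix from the question; uniform weights are an assumption.
indices = np.array([
    [0, 0, 0, 0, 1, 1, 1],
    [0, 0, 0, 1, 1, 1, 1],
    [0, 0, 1, 1, 1, 1, 2],
    [0, 1, 1, 1, 1, 1, 2],
    [2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3],
])
weights = np.ones(indices.shape)
x, y = weighted_centroids_loop(indices, weights)
print(x[3], y[3])  # region 3 is the whole last row
```

With uniform weights each result is simply the geometric centroid of the region, which makes the output easy to verify against the matrix above.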

Obviously, native loops in Python are not very fast. My second attempt uses masking:

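import numpy as np

# indices and weights are the two matrices described above
x, y = [], []
# row_matrix[i, j] == i and col_matrix[i, j] == j
row_matrix = np.fromfunction(lambda i, j: i, weights.shape)
col_matrix = np.fromfunction(lambda i, j: j, weights.shape)
num_regions = indices.max() + 1  # region labels are assumed to run from 0 to max

for ind in range(num_regions):
    mask = (indices == ind)  # boolean mask selecting the current region
    xSum = sum(weights[mask] * row_matrix[mask])
    ySum = sum(weights[mask] * col_matrix[mask])
    dSum = sum(weights[mask])

    # weighted mean row and column coordinate of the region
    x.append(xSum / dSum)
    y.append(ySum / dSum)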
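Written as a formula, the masked loop above computes a weighted centroid per region. The symbols are notation assumed here for illustration: I is the index matrix, w the weight matrix, r a region label, and (i, j) a row and column position.

```latex
x_r = \frac{\sum_{(i,j)\,:\,I_{ij}=r} w_{ij}\, i}{\sum_{(i,j)\,:\,I_{ij}=r} w_{ij}},
\qquad
y_r = \frac{\sum_{(i,j)\,:\,I_{ij}=r} w_{ij}\, j}{\sum_{(i,j)\,:\,I_{ij}=r} w_{ij}}
```

The numerators correspond to `xSum` and `ySum`, and the shared denominator to `dSum`.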

The question is: can I make this faster, without loops, using only array operations?

For testing, you can generate large random matrices like this:

indices = np.random.randint(0, 100, (1000, 1000))
weights = np.random.rand(1000, 1000)

On this data set, the first version takes 1.8 s and the second takes 0.9 s.


1 answer
Answer 1 · Posted 2024-05-07 03:40:10

Use np.bincount:

import numpy as np

indices = np.random.randint(0, 100, (1000, 1000))
weights = np.random.rand(1000, 1000)

def orig(indices, weights):
    x, y = [], []
    row_matrix = np.fromfunction(lambda i, j: i, weights.shape)
    col_matrix = np.fromfunction(lambda i, j: j, weights.shape)
    num_regions = indices.max()+1
    for ind in range(num_regions):
        mask = (indices == ind)
        xSum = sum(weights[mask] * row_matrix[mask])
        ySum = sum(weights[mask] * col_matrix[mask])
        dSum = sum(weights[mask])

        x.append(xSum / dSum)
        y.append(ySum / dSum)
    return x, y

def alt(indices, weights):
    indices = indices.ravel()
    h, w = weights.shape
    row_matrix, col_matrix = np.ogrid[:h, :w]
    dSum = np.bincount(indices, weights=weights.ravel())
    xSum = np.bincount(indices, weights=(weights*row_matrix).ravel())
    ySum = np.bincount(indices, weights=(weights*col_matrix).ravel())
    return xSum/dSum, ySum/dSum

expected_x, expected_y = orig(indices, weights)
result_x, result_y = alt(indices, weights)

# check that the result is the same
assert np.allclose(expected_x, result_x)
assert np.allclose(expected_y, result_y)
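To see why this works, it may help to look at `np.bincount` on a tiny input: with the `weights` argument it sums the weights that share a label, so one call replaces a whole loop over regions. The sketch below uses made-up arrays (`labels`, `values`) and also shows one way to handle labels that never occur, which the answer's `alt` does not need on the random test data.

```python
import numpy as np

# Made-up example data: one non-negative integer label per value.
labels = np.array([0, 1, 1, 2, 0])
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Sum of values per label: label 0 -> 1+5, label 1 -> 2+3, label 2 -> 4
sums = np.bincount(labels, weights=values)
print(sums)  # [6. 5. 4.]

# minlength pads the result so that labels 3 and 4, which never occur,
# still get an entry (with sum 0 and count 0).
padded = np.bincount(labels, weights=values, minlength=5)
counts = np.bincount(labels, minlength=5)

# Divide only where the group is non-empty; empty groups stay NaN
# instead of triggering a division-by-zero warning.
means = np.full(5, np.nan)
np.divide(padded, counts, out=means, where=counts > 0)
print(means)  # [3.  2.5 4.  nan nan]
```

`np.bincount` expects a one-dimensional array of non-negative integers, which is why `alt` flattens both `indices` and the weight arrays with `ravel()`.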

The answer also included a benchmark comparing the two functions; its output was not preserved in this copy.
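One way to reproduce such a comparison is with the standard `timeit` module. This is a sketch under stated assumptions, not the answer's original benchmark: the function names `masked` and `bincounted`, the smaller array size, and the fixed seed are choices made here to keep the run short. No timings are quoted because they depend on the machine.

```python
import timeit
import numpy as np

def masked(indices, weights):
    """Mask-per-region approach, mirroring the question's second attempt."""
    row_matrix, col_matrix = np.indices(weights.shape)
    x, y = [], []
    for ind in range(indices.max() + 1):
        mask = (indices == ind)
        d_sum = sum(weights[mask])
        x.append(sum(weights[mask] * row_matrix[mask]) / d_sum)
        y.append(sum(weights[mask] * col_matrix[mask]) / d_sum)
    return x, y

def bincounted(indices, weights):
    """np.bincount approach, mirroring the answer's alt function."""
    flat = indices.ravel()
    h, w = weights.shape
    row_matrix, col_matrix = np.ogrid[:h, :w]
    d_sum = np.bincount(flat, weights=weights.ravel())
    x_sum = np.bincount(flat, weights=(weights * row_matrix).ravel())
    y_sum = np.bincount(flat, weights=(weights * col_matrix).ravel())
    return x_sum / d_sum, y_sum / d_sum

# Smaller than the question's 1000x1000 test so the sketch runs quickly.
rng = np.random.default_rng(0)
indices = rng.integers(0, 50, (300, 300))
weights = rng.random((300, 300))

# Best of three single runs for each function.
t_masked = min(timeit.repeat(lambda: masked(indices, weights), number=1, repeat=3))
t_bincount = min(timeit.repeat(lambda: bincounted(indices, weights), number=1, repeat=3))
print(f"masked:   {t_masked:.4f} s")
print(f"bincount: {t_bincount:.4f} s")
```

Taking the minimum over several repeats reduces noise from other processes; increase the array size to match the question's test data for a closer comparison.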
