Python中移动矩阵行的最快方法

2024-06-24 12:50:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个4x4矩阵,如下所示:

1  2  3  4
5  6  7  8
9  10 11 12
13 14 15 16

我想根据行索引的数量向左移动每一行(左循环移动)。即第0行保持原样,第1行向左移动1,第2行向左移动2,等等

所以我们得到这个:

1  2  3  4
6  7  8  5
11 12 9  10
16 13 14 15

我在Python中想出的最快的方法是:

import numpy as np
def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

我需要像这样在数千个4x4矩阵上运行这个函数,所以速度很重要(在Python中尽可能)。我不关心使用其他模块,比如numpy,我只关心速度

任何帮助都将不胜感激

谢谢大家!


Tags: 方法inimportnumpyfor数量defas
3条回答

首先改进,摆脱列表理解

我假设您的输入将始终是4x4 Ndaray。如果没有,您需要适当地修改函数(即添加np.asarray,检查维度等)。删除列表理解已经提供了很好的加速效果:

import numpy as np

a = np.arange(16).reshape(4, 4)

def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

def shift(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x

%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

请记住,这两种变体都会在适当的位置修改阵列。如果这不是您想要的,请在两个函数的开头添加x = x.copy()

从我的测试来看numpy.roll比这两个版本都慢得多

第二个改进,使用^{}

现在,当我们使用^{}时,真正的加速是:

import numba

@numba.njit
def shift_numba(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x    

%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

这比现在快了15倍。使用parallel模式不会提高性能,可能是因为阵列的大小很小


测试:展开循环

应Patrick Artner的要求,我展开了循环(很可能是4x4):

@numba.njit
def shift_numba_unrolled(x):
    x[1] = np.append(x[1, 1:], x[1, :1])
    x[2] = np.append(x[2, 2:], x[2, :2])
    x[3] = np.append(x[3, 3:], x[3, :3])
    return x

%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

展开似乎会产生相同的结果


编辑:修复了一个大问题,现在加速比要小得多

如果您不介意硬编码数组大小,在我的测试中,硬编码重排模式的速度大约是硬编码的6倍:

def rot(a):
    return a.take((0,1,2,3,5,6,7,4,10,11,8,9,15,12,13,14)).reshape(4, 4)

这项工作:

import numpy as np


def stepped_roll(arr):
    return np.array([np.roll(row, -n) for n, row in enumerate(arr)])


print(stepped_roll(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])))

我倾向于使用np.roll,因为numpy例程往往比Python更快^不幸的是,{}在这里不起作用,因为您需要每行的索引

然而,在您的例子中,操作非常简单,数据集非常小,像@jancristofterasa的答案中建议的shift()函数会更快

相关问题 更多 >