<h2>首先改进,摆脱列表理解</h2>
<p>我假设您的输入将始终是4x4 Ndaray。如果没有,您需要适当地修改函数(即添加<code>np.asarray</code>,检查维度等)。删除列表理解已经提供了很好的加速效果:</p>
<pre><code>import numpy as np
a = np.arange(16).reshape(4, 4)
def ShiftRows(x):
x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
return x
def shift(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
</code></pre>
<p>请记住,这两种变体都会在适当的位置修改阵列。如果这不是您想要的,请在两个函数的开头添加<code>x = x.copy()</code></p>
<p>从我的测试来看<code>numpy.roll</code>比这两个版本都慢得多</p>
<h2>第二个改进,使用<a href="https://numba.pydata.org/" rel="nofollow noreferrer">^{<cd4>}</a></h2>
<p>现在,当我们使用<a href="https://numba.pydata.org/" rel="nofollow noreferrer">^{<cd4>}</a>时,真正的加速是:</p>
<pre><code>import numba
@numba.njit
def shift_numba(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
</code></pre>
<p>这比现在快了15倍。使用<code>parallel</code>模式不会提高性能,可能是因为阵列的大小很小</p>
<hr/>
<h2>测试:展开循环</h2>
<p>应Patrick Artner的要求,我展开了循环(很可能是4x4):</p>
<pre><code>@numba.njit
def shift_numba_unrolled(x):
x[1] = np.append(x[1, 1:], x[1, :1])
x[2] = np.append(x[2, 2:], x[2, :2])
x[3] = np.append(x[3, 3:], x[3, :3])
return x
%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
</code></pre>
<p>展开似乎会产生相同的结果</p>
<hr/>
<p>编辑:修复了一个大问题,现在加速比要小得多</p>