<p>A <strong>NumPy implementation</strong> can use <a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.argwhere.html" rel="nofollow noreferrer">numpy.argwhere</a> to find the index of the value, <a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.ix_.html" rel="nofollow noreferrer">numpy.ix_</a> to build an index grid over the neighbouring rows and columns, and finally the <a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ravel.html" rel="nofollow noreferrer">numpy.ndarray.ravel</a> method to flatten the resulting array:</p>
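<pre class="lang-py prettyprint-override"><code>import numpy as np

n = 5
grid = np.arange(n**2).reshape(n, n)[::-1]

def celdas_vecinas_np(grid, v, n):
    # Row/column index of the first cell equal to v
    x, y = np.argwhere(grid == v)[0]
    # Indices of the 3 neighbouring rows and columns; % n wraps around the edges
    idx = np.arange(x-1, x+2) % n
    idy = np.arange(y-1, y+2) % n
    # np.ix_ selects the full 3x3 block; ravel flattens it to 9 values
    return grid[np.ix_(idx, idy)].ravel()

celdas_vecinas_np(grid, 24, n)
# array([ 3,  4,  0, 23, 24, 20, 18, 19, 15])
</code></pre>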
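<p>To see why <code>numpy.ix_</code> is needed, the following standalone sketch rebuilds the same <code>grid</code> and prints the shapes of the open mesh that <code>np.ix_</code> returns. It then compares the result with plain fancy indexing, which pairs the row and column indices element by element and therefore returns only three values instead of the 3x3 block:</p>
<pre class="lang-py prettyprint-override"><code>import numpy as np

n = 5
grid = np.arange(n**2).reshape(n, n)[::-1]

# The value 24 sits in the top-right corner: row 0, column 4
x, y = np.argwhere(grid == 24)[0]

# % n wraps -1 to 4 and 5 to 0, so the neighbourhood is periodic
idx = np.arange(x - 1, x + 2) % n  # rows [4, 0, 1]
idy = np.arange(y - 1, y + 2) % n  # columns [3, 4, 0]

# np.ix_ returns an open mesh: a column of row indices and a row of column indices
rows, cols = np.ix_(idx, idy)
print(rows.shape, cols.shape)  # (3, 1) (1, 3)

# Broadcasting the two index arrays selects the whole 3x3 block
block = grid[rows, cols]
print(block)
# [[ 3  4  0]
#  [23 24 20]
#  [18 19 15]]

# Plain fancy indexing pairs idx[i] with idy[i] and yields only the "diagonal"
print(grid[idx, idy])  # [ 3 24 15]
</code></pre>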
<p>For the <strong>Numba implementation</strong>, on the other hand, we cannot use <code>numpy.argwhere</code>, but we can use <a href="https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html" rel="nofollow noreferrer">numpy.where</a> to get the indices. Once we have them, all that is left is to loop over the right ranges:</p>
<pre class="lang-py prettyprint-override"><code>from numba import njit

@njit
def celdas_vecinas_numba(grid, v, n):
    # np.where returns one index array per axis; take the first match
    x, y = np.where(grid == v)
    x, y = x[0], y[0]
    result = []
    # Visit the 3x3 neighbourhood; % n wraps indices that fall off the edges
    for ix in range(x-1, x+2):
        for iy in range(y-1, y+2):
            result.append(grid[ix%n, iy%n])
    return result

celdas_vecinas_numba(grid, 24, n)
# [3, 4, 0, 23, 24, 20, 18, 19, 15]
</code></pre>
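<p>A possible variant, sketched under a few assumptions: instead of appending to a list, it fills a preallocated NumPy array, so it returns an array just as the NumPy version does. The name <code>celdas_vecinas_numba_arr</code> is hypothetical, the grid values are assumed to be integers that fit in <code>int64</code>, and the <code>try</code>/<code>except</code> fallback (which runs the function as plain Python when Numba is not installed) is only there so the sketch runs anywhere. Whether this variant is faster than the list version has not been measured here:</p>
<pre class="lang-py prettyprint-override"><code>import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba installed, run the same code as plain Python
    def njit(func):
        return func


@njit
def celdas_vecinas_numba_arr(grid, v, n):
    # Hypothetical variant: same search and loops, but the 9 neighbours are
    # written into a preallocated int64 array instead of a growing list
    xs, ys = np.where(grid == v)
    x, y = xs[0], ys[0]
    out = np.empty(9, dtype=np.int64)
    k = 0
    for ix in range(x - 1, x + 2):
        for iy in range(y - 1, y + 2):
            out[k] = grid[ix % n, iy % n]  # % n wraps indices past the edges
            k += 1
    return out


n = 5
grid = np.arange(n**2).reshape(n, n)[::-1]
print(celdas_vecinas_numba_arr(grid, 24, n))  # [ 3  4  0 23 24 20 18 19 15]
</code></pre>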
<p><strong>Performance comparison</strong>: even with a grid this small, the Numba version already runs about 20 times faster on my local machine:</p>
<pre><code>%timeit celdas_vecinas_np(grid, 24, 5)
38 µs ± 1.62 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit celdas_vecinas_numba(grid, 24, n)
1.81 µs ± 93.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
</code></pre>
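<p>Outside IPython, where <code>%timeit</code> is not available, a comparable measurement can be made with the standard-library <code>timeit</code> module. The sketch below uses a hypothetical helper, <code>best_time_per_call</code>. It makes one warm-up call first, which for an <code>@njit</code> function triggers JIT compilation and keeps compile time out of the measurement. It also reports the fastest run rather than the mean and standard deviation that <code>%timeit</code> prints, so its numbers will not match the figures above exactly:</p>
<pre class="lang-py prettyprint-override"><code>import timeit

import numpy as np

n = 5
grid = np.arange(n**2).reshape(n, n)[::-1]


def celdas_vecinas_np(grid, v, n):
    x, y = np.argwhere(grid == v)[0]
    idx = np.arange(x - 1, x + 2) % n
    idy = np.arange(y - 1, y + 2) % n
    return grid[np.ix_(idx, idy)].ravel()


def best_time_per_call(func, *args, number=1_000, repeat=7):
    """Return the fastest average time per call, in seconds, over `repeat` runs."""
    # Warm-up call: for an @njit function this triggers JIT compilation,
    # so compile time is not included in the measurement below
    func(*args)
    totals = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(totals) / number


t = best_time_per_call(celdas_vecinas_np, grid, 24, n)
print(f"celdas_vecinas_np: {t * 1e6:.2f} µs per call")
</code></pre>
<p>Passing <code>celdas_vecinas_numba</code> with the same arguments times the Numba version in the same way.</p>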