<p>I implemented numba as Mr. Locklin suggested. On my machine the result is about 4x faster.</p>
<h2>Modified Numba version</h2>
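<pre><code>import numba as nb
import numpy as np

@nb.jit
def nb_compare(sample, competition_exp_dot, preferences):
    # exp-dot score of the sample for every row of preferences
    sample_exp_dot = np.exp(preferences @ sample)
    # put the sample in column 0, next to the precomputed competitor scores
    all_competitors = np.append(sample_exp_dot.reshape(-1, 1), competition_exp_dot, 1)
    # normalize each row so that its scores sum to 1
    all_results = (all_competitors.T/all_competitors.sum(axis=1)).T
    return np_mean(all_results, 0)  # see the source for np_mean in the notes below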
</code></pre>
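<h2>Comparable NumPy version</h2>
<pre><code>import numpy as np

def np_compare(sample, competition_exp_dot, preferences):
    # exp-dot score of the sample for every row of preferences
    sample_exp_dot = np.exp(preferences @ sample)
    # put the sample in column 0, next to the precomputed competitor scores
    all_competitors = np.append(sample_exp_dot.reshape(-1, 1), competition_exp_dot, 1)
    # normalize each row so that its scores sum to 1
    all_results = (all_competitors.T/all_competitors.sum(axis=1)).T
    return all_results.mean(axis=0)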
</code></pre>
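<p>The double transpose in both versions is a row normalization: each row of <code>all_competitors</code> (the sample plus the competitors) is divided by its own sum. The NumPy sketch below checks that the more common <code>keepdims=True</code> broadcasting form gives the same result. The random <code>all_competitors</code> is a hypothetical stand-in. This is only a NumPy-side check: I have not verified that numba accepts <code>keepdims</code>, so the transpose form may still be the safer choice inside <code>nb_compare</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for all_competitors: 1000 rows, 3 positive scores each
all_competitors = np.exp(rng.random((1000, 3)))

# Form used in np_compare: transpose, divide by the row sums, transpose back
via_transpose = (all_competitors.T / all_competitors.sum(axis=1)).T

# Broadcasting form: keepdims keeps the row sums as a (1000, 1) column
via_keepdims = all_competitors / all_competitors.sum(axis=1, keepdims=True)

assert np.allclose(via_transpose, via_keepdims)
# Every row is now a distribution over the sample and its competitors
assert np.allclose(via_keepdims.sum(axis=1), 1.0)
```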
<h2>Timing comparison</h2>
<p>Setup:</p>
<pre><code>preferences = np.random.random((1000,100)).astype(np.float32)
competition = np.array([np.random.randint(0,2,100), np.random.randint(0,2,100)]).astype(np.float32)
competition_exp_dot = np.exp(preferences @ competition.T)
sample = np.random.randint(0,2,100)
</code></pre>
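<p>In this setup, <code>sample</code> comes from <code>np.random.randint</code> and so has an integer dtype, while <code>preferences</code> is float32. NumPy quietly promotes the product to float64. As far as I know, numba's matrix-product support is limited to float or complex arrays of matching dtype, so casting <code>sample</code> to float32 before passing it to <code>nb_compare</code> is probably safer. That is an assumption; check the numba documentation for your version. The sketch below only shows the NumPy side of the promotion:</p>

```python
import numpy as np

preferences = np.random.random((1000, 100)).astype(np.float32)
sample = np.random.randint(0, 2, 100)  # integer dtype, as in the setup above

# NumPy promotes float32 @ integer to float64
mixed = preferences @ sample
# Casting the sample first keeps the whole product in float32
matched = preferences @ sample.astype(np.float32)

print(sample.dtype, mixed.dtype, matched.dtype)
```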
<pre><code>%timeit np_compare(sample, competition_exp_dot, preferences)
"210 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)"
%timeit -n 10000 nb_compare(sample, competition_exp_dot, preferences)
"52.4 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)"
</code></pre>
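<p><code>%timeit</code> is an IPython magic. Outside IPython, the standard-library <code>timeit</code> module can produce comparable numbers. Below is a minimal sketch that reuses <code>np_compare</code> and the setup above. <code>best_per_call_us</code> is a hypothetical helper name, and unlike <code>%timeit</code>, it reports the fastest run rather than the mean ± std. dev.</p>

```python
import timeit

import numpy as np


def np_compare(sample, competition_exp_dot, preferences):
    sample_exp_dot = np.exp(preferences @ sample)
    all_competitors = np.append(sample_exp_dot.reshape(-1, 1), competition_exp_dot, 1)
    all_results = (all_competitors.T / all_competitors.sum(axis=1)).T
    return all_results.mean(axis=0)


# Same setup as above
preferences = np.random.random((1000, 100)).astype(np.float32)
competition = np.array([np.random.randint(0, 2, 100), np.random.randint(0, 2, 100)]).astype(np.float32)
competition_exp_dot = np.exp(preferences @ competition.T)
sample = np.random.randint(0, 2, 100)


def best_per_call_us(func, *args, number=100, repeat=5):
    """Fastest per-call time in microseconds over `repeat` runs of `number` calls."""
    times = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(times) / number * 1e6


np_us = best_per_call_us(np_compare, sample, competition_exp_dot, preferences)
print(f"np_compare: {np_us:.1f} µs per call")
```

<p>To time <code>nb_compare</code> the same way, call it once before timing. The first call triggers JIT compilation, and that one-off cost should not be counted in the per-call figure.</p>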
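<h2>Notes</h2>
<p>Numba does not support optional arguments such as <code>axis</code> for <code>np.mean</code>, and passing one raises a <code>TypingError</code>. In my numba code I therefore call the <code>np_mean</code> version below.</p>
<p>Credit to <a href="https://github.com/numba/numba/issues/1269#issuecomment-472574352" rel="nofollow noreferrer">joelrich</a></p>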
<pre><code>import numba as nb
import numpy as np

# fix to use np.mean along axis=0 (numba doesn't support optional arguments for np.mean)
# credit to: joelrich https://github.com/numba/numba/issues/1269#issuecomment-472574352
@nb.njit
def np_apply_along_axis(func1d, axis, arr):
    # apply func1d to each column (axis=0) or each row (axis=1) of a 2-D array
    assert arr.ndim == 2
    assert axis in [0, 1]
    if axis == 0:
        result = np.empty(arr.shape[1])
        for i in range(len(result)):
            result[i] = func1d(arr[:, i])
    else:
        result = np.empty(arr.shape[0])
        for i in range(len(result)):
            result[i] = func1d(arr[i, :])
    return result


@nb.njit
def np_mean(array, axis):
    return np_apply_along_axis(np.mean, axis, array)
</code></pre>
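<p>For the specific case used here (the mean along <code>axis=0</code>), a shorter alternative may be to divide a column sum by the row count. This assumes your numba version supports the <code>axis</code> argument of <code>sum()</code>; the numba docs list it among the supported array methods, but please verify for your version. <code>mean_axis0</code> is a hypothetical name. The sketch falls back to plain Python if numba is not installed, so the equivalence check still runs:</p>

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # numba not installed: use an identity decorator so the sketch still runs
    def njit(func):
        return func


@njit
def mean_axis0(arr):
    # column means of a 2-D array via sum(axis=0) instead of mean(axis=0)
    return arr.sum(axis=0) / arr.shape[0]


arr = np.random.random((1000, 3))
assert np.allclose(mean_axis0(arr), arr.mean(axis=0))
```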