擅长:python、mysql、java
<p>问题的出现是因为您正在将python列表传递给numpy函数。如果将numpy数组作为参数传递,numpy函数的速度要快得多。在</p>
<pre><code>#Create numpy numbers
nptest = np.random.uniform(size=(10000, 10))
#Create a native python list
listtest = list(nptest)
#Compare performance
%timeit np.min(nptest, axis=0)
%timeit np.min(listtest, axis=0)
</code></pre>
<p>输出</p>
^{pr2}$
<p>编辑:增加了如何在网格上计算成本函数的示例。在</p>
<p>下面计算网格上的二次成本函数,然后沿第一个轴取最小值。尤其是,<code>np.meshgrid</code>是你的朋友。在</p>
<pre><code>def cost_function(x, y):
return x ** 2 + y ** 2
x = linspace(-1, 1)
y = linspace(-1, 1)
def eval_python(x, y):
matrix = [cost_function(_x, _y) for _x in x for _y in y]
return np.min(matrix, axis=0)
def eval_numpy(x, y):
xx, yy = np.meshgrid(x, y)
matrix = cost_function(xx, yy)
return np.min(matrix, axis=0)
%timeit eval_python(x, y)
%timeit eval_numpy(x, y)
</code></pre>
<p>输出
100个回路,最好每回路3:13.9ms
10000个环路,最好每环路3:136µs</p>
<p>最后,如果您不能将问题转换成这个表单,您可以预先分配内存,然后填充每个元素。在</p>
<pre><code>matrix = np.empty((num_x, num_y))
for i in range(num_x):
for j in range(num_y):
matrix[i, j] = cost_function(i, j)
</code></pre>