擅长:python、mysql、java
<p>我最近在深度学习方面也遇到了同样的问题(斯坦福cs231n,作业1),但是当我使用</p>
<pre><code> np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))
</code></pre>
<p>有个错误</p>
<pre><code>MemoryError
</code></pre>
<p>这意味着我的内存不足(事实上,这中间产生了一个500*5000*1024的数组,太大了!)</p>
<p>为了防止这种错误,我们可以使用一个公式来简化:</p>
<p><img src="https://latex.codecogs.com/gif.latex?(a-b)%5E2&space;=&space;a%5E2&space;-&space;2ab&plus;b%5E2" title="(a-b)^2 = a^2 - 2ab+b^2"/></p>
<p>代码:</p>
<pre><code>import numpy as np
aSumSquare = np.sum(np.square(a),axis=1);
bSumSquare = np.sum(np.square(b),axis=1);
mul = np.dot(a,b.T);
dists = np.sqrt(aSumSquare[:,np.newaxis]+bSumSquare-2*mul)
</code></pre>