使用numba快速计算余弦相似度发现了这个gist。在
import numba
@numba.jit(target='cpu', nopython=True)
def fast_cosine(u, v):
m = u.shape[0]
udotv = 0
u_norm = 0
v_norm = 0
for i in range(m):
if (np.isnan(u[i])) or (np.isnan(v[i])):
continue
udotv += u[i] * v[i]
u_norm += u[i] * u[i]
v_norm += v[i] * v[i]
u_norm = np.sqrt(u_norm)
v_norm = np.sqrt(v_norm)
if (u_norm == 0) or (v_norm == 0):
ratio = 1.0
else:
ratio = udotv / (u_norm * v_norm)
return ratio
结果看起来很有希望(500ns与200us相比,我的机器中没有jit decorator)。在
我想用numba将向量u
和候选矩阵M
之间的计算并行化,即每行的余弦。在
示例:
^{pr2}$一种方法是用第二个输入重写一个矩阵。但是如果我尝试迭代矩阵的行,就会得到一个NotImplementedError
。试着用切片。在
我想用vectorize
但我不能让它工作。在
解决方案稍微重写一下:
另一种方法:用numba生成一个通用的UFunc
相关问题 更多 >
编程相关推荐