我在一个球体上有两组点,在下面的代码示例中分别标记为“obj”和“ps”。我想确定所有的‘obj’点,这些点距离‘ps’点一定的角度距离。在
我的想法是用三维单位向量表示每个点,并将它们的点积与cos(最大分离)进行比较。使用numpy广播可以很容易地做到这一点,但在我的应用程序中,我有n_obj~500000和n_ps~50000,因此广播的内存需求太大。下面我用numba粘贴了我当前的拍摄。能否进一步优化?在
from numba import jit
import numpy as np
from sklearn.preprocessing import normalize
def gen_points(n):
"""
generate random 3D unit vectors (not uniform, but irrelevant here)
"""
vec = 2*np.random.rand(n,3)-1.
vec_norm = normalize(vec)
return vec_norm
#@jit(nopython=True)
@jit
def angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep):
"""
finds obj that are closer than maxsep to a ps
"""
nps = len(vec_ps)
nobj = len(vec_obj)
#closeobj_all = []
closeobj_all = np.empty(0)
dotprod = np.empty(nobj)
a = np.arange(nobj)
for ps in range(nps):
np.sum(vec_obj*vec_ps[ps],axis=1,out=dotprod)
#closeobj_all.extend(a[dotprod > cos_maxsep])
closeobj_all = np.append(closeobj_all, a[dotprod > cos_maxsep])
return closeobj_all
vec_obj = gen_points(50000) #in reality ~500,000
vec_ps = gen_points(5000) #in reality ~50,000
cos_maxsep = np.cos(0.003)
closeobj_all = np.unique(angdist_threshold_numba(vec_obj,vec_ps,cos_maxsep))
这是使用代码中给出的测试用例的性能:
^{pr2}$我试着用
@jit(nopython=True)
但是这失败了
NotImplementedError: Failed at nopython (nopython frontend)
(<class 'numba.ir.Expr'>, build_list(items=[]))
编辑:在numba更新到0.26之后,即使在python模式下,创建空列表也会失败。可以通过将其替换为np.空(0),然后使用np.追加(),见上文。这几乎不会改变性能。在
根据https://github.com/numba/numba/issues/858np.空()现在在nopython模式下受支持,但是我仍然不能用@jit运行它(nopython=True):
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x7ff3114a9310>
与
list.append
不同的是,永远不要在循环中调用numpy.append
!这是因为即使是附加一个元素,也需要复制整个数组。因为您只对唯一的obj
感兴趣,所以可以使用布尔数组来标记到目前为止找到的匹配项。在至于Numba,最好是写出所有的循环。例如:
另一个好处是,一旦找到与当前
obj
匹配的匹配项,我们就可以打破ps
数组的循环。在通过专门化三维空间的函数,可以获得更快的速度。此外,由于某些原因,将所有数组和相关维度传递到helper函数会导致另一个加速:
^{pr2}$我得到的20000
obj
和2000ps
的时间安排:相关问题 更多 >
编程相关推荐