如何在numba中加速dict查找任务?

2024-09-29 23:17:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个python函数,想用numba来加速它。最耗时的部分是在set/dict中搜索元组。有人能告诉我们如何解决这个问题吗?在

idNeg = np.array([[1,2,3], [4,5,6], ..., [7,8,9]])
validSet = {(1,2,3):True, (5,4,3):True, ..., (2,5,3):True}
@jit
def CalcNeg(idNeg, validSet):
    l = len(idNeg)
    for j in xrange(l):
        corplc = np.random.choice([0, 2])
        idNeg[j, corplc] = random.randrange(0, VE.shape[0])
        while validSet.has_key((idNeg[j, 0], idNeg[j, 1], idNeg[j, 2])):
            idNeg[j, corplc] = random.randrange(0, VE.shape[0])
    return idNeg

我试过这样做,但是速度没有改变,与没有@jit的代码相比。在


Tags: 函数truenpverandomdictjitshape
1条回答
网友
1楼 · 发布于 2024-09-29 23:17:25

我不是100%清楚所需输入(内容或典型形状)的确切性质,但是在Numba中获得良好性能提升的关键是能够在nopython模式下(与python对象模式相反)jit函数。最初的函数使用了数据结构,特别是dict,目前不支持。在

同样,我不知道具体的用例或者下面的修改是否有效,但是我使用了validSetdict并将它的keys转换为一个实集对象,其中key,value对中的值是True。在

例如:

import numpy as np
import numba as nb
import random

# Original function
def CalcNeg(idNeg, validSet, N):
    l = len(idNeg)
    for j in xrange(l):
        corplc = np.random.choice([0, 2])
        idNeg[j, corplc] = random.randrange(0, N)
        while validSet.has_key((idNeg[j, 0], idNeg[j, 1], idNeg[j, 2])):
            idNeg[j, corplc] = random.randrange(0, N)
    return idNeg

# Modified version, compiled in nopython mode (njit)
@nb.njit
def CalcNeg2(idNeg, validSet, N):
    l = len(idNeg)
    c = np.array([0,2])
    for j in xrange(l):
        corplc = np.random.choice(c)
        idNeg[j, corplc] = random.randrange(0, N)
        #while validSet.has_key((idNeg[j, 0], idNeg[j, 1], idNeg[j, 2])):
        while (idNeg[j, 0], idNeg[j, 1], idNeg[j, 2]) in validSet:
            idNeg[j, corplc] = random.randrange(0, N)
    return idNeg

# Some test data
N = 40
M = 2000
idNeg = np.random.random_integers(0, N, size=(M,3))
tmp = np.random.random_integers(0, N, size=(M,3))
validSet = {tuple(tmp[k,:]): True for k in xrange(tmp.shape[0])}

# convert validSet to real python set for keys with value == True
_validSet = {k for k,v in validSet.iteritems() if v is True}

现在,使用%timeit魔法,从ipython笔记本电脑中获得一些计时:

^{pr2}$

在我的机器上是18倍的加速。我用的是Numba 0.25。请注意,切换到原始python函数中的一个集合会产生一个小的差异,但更像是25%的加速。在

如果测试数据不切实际,或者将dict转换为集合不合适,请告诉我。如果没有更多的细节,很难判断如何解决这个问题。在

相关问题 更多 >

    热门问题