如何在Python中正确使用njit/jit?

2024-10-03 11:17:09 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图为值函数迭代编写一个程序,我想利用numba库中的nopython模式。下面的代码并没有真正起到任何作用(我想从头开始理解我在哪里犯了错误)。它应该只返回我在函数中创建的矩阵。函数中的输入是优化所需的,稍后我将进行优化。然而,我面临的错误(请参阅下文)

我试着使用@njit和@jit,其中包括我为每个输入使用的特定类型。然而,两者都不起作用

import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d
import time
from datetime import timedelta
from numba import jit, njit, int32, float64
from gridlookup import gridlookup


beta = 0.99322 
sigma = 1.5
enum = 2
egrid = np.array([0.1, 1.0])
pie = np.array([[0.5, 0.5], 
                [0.075, 0.925]])
blow = -2.0; bhigh = 4.0; bnum = 10
bgrid = np.logspace(np.log(blow + -1.0*blow + 1.0)/np.log(10.0), np.log(bhigh + -1.0*blow + 1.0)/np.log(10.0), bnum)
bgrid = bgrid + np.ones(np.shape(bgrid))*(blow - 1.0)
mubgnum = 1000
mubgrid = np.linspace(blow, bhigh, mubgnum)

v0 = np.array(np.zeros((enum,bnum)))

@njit
def vfini(bnum, enum, bgrid, egrid, v0):
    ## calculate the initial value function.
    for i in range(bnum):
        bval = bgrid[i]
        for m in range(enum):
            eval0 = egrid[m]
            yval = 0.025*bval + eval0
            v0[m,i] = (yval**(1-sigma))/(1.0-sigma)

    return v0

v0 = vfini(bnum, enum, bgrid, egrid, v0)

@jit([(int32, float64[:], float64[:,:], float64, float64, int32,
        float64[:], int32, float64[:], float64, float64[:,:],float64)],nopython=True)
def huggettqegm(enum, egrid, pie, beta, sigma, bnum, 
                bgrid, mubgnum, mubgrid, precision, v0, q):

    v_I = np.array(np.zeros((enum,bnum)))
    g_I = np.array(np.zeros((enum,bnum)))
    tv_I = np.array(np.zeros((enum,bnum))) 
    tg_I = np.array(np.zeros((enum,bnum)))

    return v_I, g_I

q = qlow

v, g = huggettqegm(enum, egrid, pie, beta, sigma, bnum, bgrid, mubgnum, mubgrid, precision, v0, q)

下面是我运行上面的简单代码时的错误消息:

TypingError: Invalid use of Function() with argument(s) of type(s): (array(float64, 2d, C)) * parameterized In definition 0: TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence raised from C:\Anaconda3\lib\site-packages\numba\typing\npydecl.py:460 In definition 1: TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence raised from C:\Anaconda3\lib\site-packages\numba\typing\npydecl.py:460 This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: resolving callee type: Function() [2] During: typing of call at C:/Users/Jung Hwan Kim/Dropbox/StudentDebtCrisis/Program/Python/July012019/practice.py (124)

我想我可以使用numpy来创建数组并在njit模式下使用它们。或者我误解了njit的基本用法。。我非常感谢你的帮助。非常感谢


Tags: fromimportnpzerosenumarraysigmafloat64