我试图为值函数迭代编写一个程序,我想利用numba库中的nopython模式。下面的代码并没有真正起到任何作用(我想从头开始理解我在哪里犯了错误)。它应该只返回我在函数中创建的矩阵。函数中的输入是优化所需的,稍后我将进行优化。然而,我面临的错误(请参阅下文)
我试着使用@njit和@jit,其中包括我为每个输入使用的特定类型。然而,两者都不起作用
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d
import time
from datetime import timedelta
from numba import jit, njit, int32, float64
from gridlookup import gridlookup
beta = 0.99322
sigma = 1.5
enum = 2
egrid = np.array([0.1, 1.0])
pie = np.array([[0.5, 0.5],
[0.075, 0.925]])
blow = -2.0; bhigh = 4.0; bnum = 10
bgrid = np.logspace(np.log(blow + -1.0*blow + 1.0)/np.log(10.0), np.log(bhigh + -1.0*blow + 1.0)/np.log(10.0), bnum)
bgrid = bgrid + np.ones(np.shape(bgrid))*(blow - 1.0)
mubgnum = 1000
mubgrid = np.linspace(blow, bhigh, mubgnum)
v0 = np.array(np.zeros((enum,bnum)))
@njit
def vfini(bnum, enum, bgrid, egrid, v0):
## calculate the initial value function.
for i in range(bnum):
bval = bgrid[i]
for m in range(enum):
eval0 = egrid[m]
yval = 0.025*bval + eval0
v0[m,i] = (yval**(1-sigma))/(1.0-sigma)
return v0
v0 = vfini(bnum, enum, bgrid, egrid, v0)
@jit([(int32, float64[:], float64[:,:], float64, float64, int32,
float64[:], int32, float64[:], float64, float64[:,:],float64)],nopython=True)
def huggettqegm(enum, egrid, pie, beta, sigma, bnum,
bgrid, mubgnum, mubgrid, precision, v0, q):
v_I = np.array(np.zeros((enum,bnum)))
g_I = np.array(np.zeros((enum,bnum)))
tv_I = np.array(np.zeros((enum,bnum)))
tg_I = np.array(np.zeros((enum,bnum)))
return v_I, g_I
q = qlow
v, g = huggettqegm(enum, egrid, pie, beta, sigma, bnum, bgrid, mubgnum, mubgrid, precision, v0, q)
下面是我运行上面的简单代码时的错误消息:
TypingError: Invalid use of Function() with argument(s) of type(s): (array(float64, 2d, C)) * parameterized In definition 0: TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence raised from C:\Anaconda3\lib\site-packages\numba\typing\npydecl.py:460 In definition 1: TypingError: array(float64, 2d, C) not allowed in a homogeneous sequence raised from C:\Anaconda3\lib\site-packages\numba\typing\npydecl.py:460 This error is usually caused by passing an argument of a type that is unsupported by the named function. [1] During: resolving callee type: Function() [2] During: typing of call at C:/Users/Jung Hwan Kim/Dropbox/StudentDebtCrisis/Program/Python/July012019/practice.py (124)
我想我可以使用numpy来创建数组并在njit模式下使用它们。或者我误解了njit的基本用法。。我非常感谢你的帮助。非常感谢
目前没有回答
相关问题 更多 >
编程相关推荐