我正在用一种算法求解一个微分方程,但我在Python中的实现非常慢,我想知道为什么。该算法的工作原理如下
u0
)u1
,该例程在u0上进行了一些修改(移动函数并将其参数与一些基本函数合并)u2
等等,过程与u1
相同,但取决于u1
而不是u0
(等等)这应该给我一个函数,它为一些x
计算每个时间点1,2,…,N的解的值,为一些N。您可以在程序中找到注释中的单个步骤的解释。如果有什么不清楚的地方,我很乐意解释。(问题如下)
import numpy as np
import math
from scipy import interpolate, integrate, stats
# Some values that determine the Differential Equation
y = 0.024
o = 0.115
p = 0.35
#Defining the initial value, i.e. the function we start with
def u0(x):
return stats.beta.pdf(x, 2.7, 3.05)
#Determining the spatial and time grid on which I'd like the function to show me results
L = 1
N = 1000
dx = L / N
x = np.arange(0, L, dx)
t = np.linspace(0, 1, 1000)
dt = 0.001
#Simulating the Brownian motion on which the Differential equation depends
cov = np.zeros((len(t), len(t)))
B_i = np.zeros((len(t), len(t)))
for i in range(len(t)):
for j in range(i + 1):
B_i[i, j] = np.sqrt(dt)
mean = np.zeros_like(t)
B = np.dot(B_i, np.random.multivariate_normal(mean, np.identity(len(t))))
B = np.concatenate(([0], B), axis=-1)
#This is a shift operator that is used for the solution, it is part of the analytical solution and shifts the inserted parameter by another parameter a
def shift(func, a):
def shift(x):
return func(x - a)
return shift
#This is a scale operator that is used for the solution, it is part of the analytical solution and multiplies the function by some parameter
def scale(func, a=1):
return a*func
#This truncates a function (I only need results in the area of 0<x<1)
def trunc(func):
def trunc(x):
if x <= 0:
return 0
elif x >= 1:
return 0
else:
return func(x)
return trunc
#This is the quadrature routine
def quad(func, a, b):
return integrate.quad(func, a, b)
#Interpolation on the grid
def myinterpolate(func):
x = np.linspace(0, 1, 1000)
y = func(x[:])
return interpolate.interp1d(x, y, kind = 'cubic', fill_value="extrapolate")
#Modification function such that the Differential equation can be solved for the next time step
def gauss(func, t):
def gauss(x):
def pregau(z):
arg = x + t ** (1 / 2) * z
return func(arg) * math.exp(-(z**2)/2)
integration = integrate.quad(pregau, -np.inf, np.inf)[0]
return (2 * math.pi) ** (-1/2) * integration
return gauss
#Function that is used in each time step to calculate the solution for the new time step
def f_temp(ui, B):
def f_temp(x):
arg = o * np.sqrt(1 - p ** (2)) * dt
increment = o * p * B + y * dt
return shift(gauss(ui, arg), increment)(x)
return np.vectorize(f_temp)
#This is the function that gives me (for inital value u0) the solution at every time step for a specific x that I feed it
def vundl(x, u=u0):
v = [myinterpolate(u)]
vx = [myinterpolate(u)(x)[()]]
#In my case 1000, but it takes a lot of time to get to even go to i=100
for i in range(1, 1000):
vi = trunc(myinterpolate(f_temp(v[i-1], B[i] - B[i-1])))
vxi = vi(x)
v.append(vi)
vx.append(vxi)
return list(map(lambda x: x[()], vx))
print(vundl(0.01, u0))
正如您所看到的,我已经将f_temp
从循环中排除,这使它变得更快。之前我只能计算到I=10,因为它花费的时间太长,之后我可以计算到I=100。有没有办法让这个程序更快?我在编程方面不是很有经验,尤其是在Python方面,我花了很多工作来完成这项工作。我很感激你的建议
目前没有回答
相关问题 更多 >
编程相关推荐