提高优化问题的速度

2024-05-06 01:58:20 发布

您现在位置:Python中文网/ 问答频道 /正文

我有下面的应用程序,优化以下pb。代码可以工作,但我发现它有点慢。有没有什么改进性能的方法(不用编写c代码)来更好地使用python、numpy和scipy?看来插值函数是最耗时的部分。在

from scipy.optimize import leastsq
from scipy.interpolate import interp1d
import timeit


class Bond(object):
  def __init__(self, years, cpn):
    self.years = years 
    self.coupon = cpn
    self.cashflows = [(0.0, -1.0)]
    self.cashflows.extend([(float(i),self.coupon) for i in range(1,self.years)])
    self.cashflows.append((float(self.years), 1.0 + self.coupon))

  def pv(self, market):
    return sum([cf[1] * market.df(cf[0]) for cf in self.cashflows])

class Market(object):
  def __init__(self, instruments):
    self.instruments = sorted(
        instruments, key=lambda instrument : instrument.cashflows[-1][0])
    self.knots = [0.0]
    self.knots.extend([inst.cashflows[-1][0] for inst in self.instruments])
    self.dfs = [1.0]
    self.dfs.extend([1.0] * len(self.instruments))
    self.interp = interp1d(self.knots, self.dfs)

  def df(self, day):
    return self.interp(day)

  def calibrate(self):
    leastsq(self.__target, self.dfs[1:])

  def __target(self, x):
    self.dfs[1:] = x
    self.interp = interp1d(self.knots, self.dfs)
    return [bond.pv(self) for bond in self.instruments]


def main():
  instruments = [Bond(i, 0.02) for i in xrange(1, numberOfInstruments + 1)]
  market = Market(instruments)
  market.calibrate()
  print('CALIBRATED')

numberOfTimes = 10
numberOfInstruments = 50
print('%.2f' % float(timeit.timeit(main, number=numberOfTimes)/numberOfTimes))

Tags: inimportselffordefscipymarketcoupon
2条回答

您应该尝试将求和和和与插值例程的调用矢量化。例如,像这样:

import numpy as np

class Bond(object):
  def __init__(self, years, cpn):
    self.years = years
    self.coupon = cpn

    self.cashflows = np.zeros((self.years + 1, 2))
    self.cashflows[:,0] = np.arange(self.years + 1)
    self.cashflows[:,1] = self.coupon
    self.cashflows[0,:] = 0, -1
    self.cashflows[-1,:] = self.years, 1.0 + self.coupon

  def pv(self, market):
    return (self.cashflows[:,1] * market.df(self.cashflows[:,0])).sum()

这似乎能提高10倍的加速。您还可以用类似的方法将Market中的knots和{}列表替换为数组。在

重新校准需要时间的原因是leastsq必须再次验证它是否真的处于局部最小值。这需要对目标函数进行数值微分,这需要时间,因为有很多自由变量。优化问题相当容易,所以只需几个步骤就可以收敛,这意味着验证最小值所需的时间几乎与解决问题所需的时间相同。在

@pv的答案很可能是正确的,但是this answer提供了一个简单的方法来确定,并表明您是否可以做进一步的工作。在

相关问题 更多 >