有没有更干净的方法来实现多模型的曲线拟合?

2024-09-24 22:21:37 发布

您现在位置:Python中文网/ 问答频道 /正文

在我的项目中,我有多个预定义的函数族来拟合曲线。让我们看看最简单的:

def polyfit3(x, b0, b1, b2, b3):
    return b0+b1*x+b2*x**2+b3*x**3

def polyfit2(x, b0, b1, b2):
    return b0+b1*x+b2*x**2

def polyfit1(x, b0, b1):
    return b0+b1*x

注意: 我知道在这种特殊情况下,^{}会是一个更好的选择

(更简单)的函数,使配件看起来像这样:

from scipy.optimize import curve_fit
try:
    from lmfit import Model
    _has_lmfit = True
except ImportError:
    _has_lmfit = False

def f(x, y, order=3):
    if _has_lmfit:
        if order == 3:
            fitModel = Model(polyfit3)
            params = fitModel.make_params(b0=0, b1=1, b2=1, b3=1)
            result = fitModel.fit(y, x=x, params=params)
        elif order == 2:
            fitModel = Model(polyfit2)
            params = fitModel.make_params(b0=0, b1=1, b2=1)
            result = fitModel.fit(y, x=x, params=params)
        elif order == 1:
            fitModel = Model(polyfit1)
            params = fitModel.make_params(b0=0, b1=1)
            result = fitModel.fit(y, x=x, params=params)
        else:
            raise ValueError('Order is out of range, please select from [1, 3].')
    else:
        if order == 3:
            popt, pcov = curve_fit(polyfit3, x, y)
            _function = polyfit3
        elif order == 2:
            popt, pcov = curve_fit(polyfit2, x, y)
            _function = polyfit2
        elif order == 1:
            popt, pcov = curve_fit(polyfit1, x, y)
            _function = polyfit1
        else:
            raise ValueError('Order is out of range, please select from [1, 3].')
    # more code there.. mostly working with the optimized parameters, plotting, etc.

我的问题是这很快变得很难看,我一遍又一遍地重复我自己。有没有更好的办法?你知道吗

编辑:

我试过这个:

def poly_fit(x, *args):
    return sum(b*x**i for i, b in enumerate(args))

...

fitModel = Model(poly_fit)
fitModel.make_params(**{f'b{i}': 1 for i in range(order+1)})

但不幸的是,lmfit抛出了一个错误:

ValueError: varargs '*args' is not supported

Tags: frommodelreturndeforderparamsb0b2
2条回答

我通过为polyfit函数创建全局配置来重写代码。这是if更具python风格的版本。你知道吗

polyfits = {
    1: {
        'f': polyfit1,
        'params': ['b0', 'b1'],
        'vals'  : [  0,    1], 
    },
    2: {
        'f': polyfit2,
        'params': ['b0', 'b1', 'b2'],
        'vals'  : [   0,    1,   1,], 
    },
    3: {
        'f': polyfit3,
        'params': ['b0', 'b1', 'b2', 'b3'],
        'vals'  : [   0,    1,    1,    1], 
    },

}

def f(x, y, order=3):
    if order not in polyfits.keys():
        raise ValueError('Order is out of range, please select from {}.'.format(','.join(map(str, polyfits.keys()))))
    _function = polyfits[order]['f']
    if _has_lmfit:
        fitModel = Model(_function)
        params = dict(zip(polyfits[order]['params'], polyfits[order]['vals']))
        params = fitModel.make_params(**params)
        result = fitModel.fit(y, x=x, params=params)
    else:
        popt, pcov = curve_fit(_function, x, y)

我相信您发布了非常简化的代码版本(因为您当前的版本可以比我上面的代码更有效地最小化)。你知道吗

我认为lmfit.models.PolynomialModel()正是你想要的。该模型将多项式次数n作为参数,并使用名为c0c1cn(最多处理n=7)的系数:

from lmfit.models import PolynomialModel

def f(x, y, degree=3):
    fitModel = PolynomialModel(degree=degree)
    params = fitModel.make_params(c0=0, c1=1, c2=1, c3=0, 
                                  c4=0, c5=0, c6=0, c7=0)
    # or if you prefer to do it the hard way:
    params = fitModel.make_params(**{'c%d'%i:0 for i in range(degree+1)})

    return fitModel.fit(y, x=x, params=params)

请注意,这里可以过度指定系数。也就是说,如果degree=3,那么对fitModel.make_params(c0=0, ..., c7=0)的调用实际上不会为c4c5c6c7生成参数。你知道吗

PolynomialModel会产生一个TypeErrorifdegree > 7,所以我没有考虑你的显式测试。你知道吗

我希望这能让您开始学习,但看起来您可能也想包括其他模型函数。在这种情况下,我所做的是制作一本类名词典:

from lmfit.models import LinearModel, PolynomialModel, GaussianModel, ....

KnownModels = {'linear': LinearModel, 'polynomial': PolynomialModel, 
              'gaussian': GaussianModel, ...}

然后用它来构建模型:

modelchoice = 'linear' # probably really came from user selection in a GUI

if modelchoice in KnownModels:
    model = KnownModels[modelchoice]()
else:
    raise ValueError("unknown model '%s'" % modelchoice)

params = model.make_params(....) # <- might know and store what the parameter names are
.....

相关问题 更多 >