从SciKit FunctionTransformer继承

2024-09-30 02:27:07 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用FunctionTransformer,同时提供一个简单的API并隐藏额外的细节。具体来说,我希望能够提供一个Custom_Trans类,如下所示。因此,用户应该能够使用当前失败的trans2,而不是trans1,因为它工作正常:

from sklearn import preprocessing 
from sklearn.pipeline import Pipeline
from sklearn import model_selection
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np

X, y = make_regression(n_samples=100, n_features=1, noise=0.1)

def func(X, a, b):
    return X[:,a:b]

class Custom_Trans(preprocessing.FunctionTransformer):
    def __init__(self, ind0, ind1):
        super().__init__(
            func=func,
            kw_args={
                "a": ind0,
                "b": ind1
            }
        )

trans1 = preprocessing.FunctionTransformer( 
    func=func,
    kw_args={
        "a": 0,
        "b": 50
    }
)

trans2 = Custom_Trans(0,50)

pipe1 = Pipeline(
    steps=[
           ('custom', trans1),
           ('linear', LinearRegression())
         ]
)

pipe2 = Pipeline(
    steps=[
           ('custom', trans2),
           ('linear', LinearRegression())
          ]
)

print(model_selection.cross_val_score(
    pipe1, X, y, cv=3,)
)

print(model_selection.cross_val_score(
    pipe2, X, y, cv=3,)
)

这就是我得到的:

[0.99999331 0.99999671 0.99999772]
...sklearn/base.py:209: FutureWarning: From version 0.24, get_params will raise an
AttributeError if a parameter cannot be retrieved as an instance attribute. 
Previously it would return None.
warnings.warn('From version 0.24, get_params will raise an '
...
[0.99999331 0.99999671 0.99999772]

我知道这和估计器克隆有关,但我不知道如何修复它。例如this post

there should be no logic, not even input validation, in an estimators init. The logic should be put where the parameters are used, which is typically in fit

但是在这种情况下,我需要将参数传递给超类。无法将逻辑放在fit()中。 我能做什么


Tags: fromimportantransmodelpipelinecustomsklearn
1条回答
网友
1楼 · 发布于 2024-09-30 02:27:07

您可以通过从BaseEstimator继承来获取“get_params”

class FunctionTransformer(BaseEstimator, TransformerMixin)

How to pass parameters to the customize modeltransformer class

inherit from function_transformer

custom transformers

你在基地里有这个:

def get_params(self, deep=True):
        """
        Get parameters for this estimator.
        Parameters
             
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns

更改您的代码:

trans1 = dict(
    functiontransformer__kw_args=[
        {'ind0': None},
        {'ind0': [1]}
    ]
)

class Custom_Trans(preprocessing.FunctionTransformer): 
    def __init__(self, ind0, ind1, deep=True): 
        super().__init__( func=func, kw_args={ "a": ind0, "b": ind1 } ) 
        self.ind0 = ind0
        self.ind1 = ind1
        self.deep = True 

相关问题 更多 >

    热门问题