我想使用FunctionTransformer
,同时提供一个简单的API并隐藏额外的细节。具体来说,我希望能够提供一个Custom_Trans
类,如下所示。因此,用户应该能够使用当前失败的trans2
,而不是trans1
,因为它工作正常:
from sklearn import preprocessing
from sklearn.pipeline import Pipeline
from sklearn import model_selection
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression
import numpy as np
X, y = make_regression(n_samples=100, n_features=1, noise=0.1)
def func(X, a, b):
return X[:,a:b]
class Custom_Trans(preprocessing.FunctionTransformer):
def __init__(self, ind0, ind1):
super().__init__(
func=func,
kw_args={
"a": ind0,
"b": ind1
}
)
trans1 = preprocessing.FunctionTransformer(
func=func,
kw_args={
"a": 0,
"b": 50
}
)
trans2 = Custom_Trans(0,50)
pipe1 = Pipeline(
steps=[
('custom', trans1),
('linear', LinearRegression())
]
)
pipe2 = Pipeline(
steps=[
('custom', trans2),
('linear', LinearRegression())
]
)
print(model_selection.cross_val_score(
pipe1, X, y, cv=3,)
)
print(model_selection.cross_val_score(
pipe2, X, y, cv=3,)
)
这就是我得到的:
[0.99999331 0.99999671 0.99999772]
...sklearn/base.py:209: FutureWarning: From version 0.24, get_params will raise an
AttributeError if a parameter cannot be retrieved as an instance attribute.
Previously it would return None.
warnings.warn('From version 0.24, get_params will raise an '
...
[0.99999331 0.99999671 0.99999772]
我知道这和估计器克隆有关,但我不知道如何修复它。例如this post说
there should be no logic, not even input validation, in an estimators init. The logic should be put where the parameters are used, which is typically in fit
但是在这种情况下,我需要将参数传递给超类。无法将逻辑放在fit()
中。
我能做什么
您可以通过从BaseEstimator继承来获取“get_params”
How to pass parameters to the customize modeltransformer class
inherit from function_transformer
custom transformers
你在基地里有这个:
更改您的代码:
相关问题 更多 >
编程相关推荐