如何嵌套LabelKFold?

2024-09-30 10:42:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个有大约300个点和32个不同标签的数据集,我想通过使用网格搜索和LabelKFold验证绘制其学习曲线来评估LinearSVR模型。在

我的代码是这样的:

import numpy as np
from sklearn import preprocessing
from sklearn.svm import LinearSVR
from sklearn.pipeline import Pipeline
from sklearn.cross_validation import LabelKFold
from sklearn.grid_search import GridSearchCV
from sklearn.learning_curve import learning_curve
    ...
#get data (x, y, labels)
    ...
C_space = np.logspace(-3, 3, 10)
epsilon_space = np.logspace(-3, 3, 10)  

svr_estimator = Pipeline([
    ("scale", preprocessing.StandardScaler()),
    ("svr", LinearSVR),
])

search_params = dict(
    svr__C = C_space,
    svr__epsilon = epsilon_space
)

kfold = LabelKFold(labels, 5)

svr_search = GridSearchCV(svr_estimator, param_grid = search_params, cv = ???)

train_space = np.linspace(.5, 1, 10)
train_sizes, train_scores, valid_scores = learning_curve(svr_search, x, y, train_sizes = train_space, cv = ???, n_jobs = 4)
    ...
#plot learning curve

我的问题是如何为网格搜索和学习曲线设置cv属性,以便它将我的原始集分解为不共享任何用于计算学习曲线的标签的训练集和测试集。然后从这些训练集中,进一步将它们分成训练集和测试集,而不共享网格搜索的标签?在

本质上,如何运行嵌套LabelKFold?在


我是为这个问题创建悬赏的用户,使用sklearn提供的数据编写了以下可复制的示例。在

^{pr2}$

Tags: fromimport网格searchnptrainspace标签
1条回答
网友
1楼 · 发布于 2024-09-30 10:42:34

从您的问题中,您正在查找数据上的LabelKFold分数,同时在外部LabelKFold的每次迭代中网格搜索管道的参数,同时再次使用LabelKFold。虽然我无法实现开箱即用,但只需要一个循环:

outer_cv = LabelKFold(labels=strata, n_folds=3)
strata = np.array(strata)
scores = []
for outer_train, outer_test in outer_cv:
    print "Outer set. Train:", set(strata[outer_train]), "\tTest:", set(strata[outer_test])
    inner_cv = LabelKFold(labels=strata[outer_train], n_folds=3)
    print "\tInner:"
    for inner_train, inner_test in inner_cv:
        print "\t\tTrain:", set(strata[outer_train][inner_train]), "\tTest:", set(strata[outer_train][inner_test])
    clf = GridSearchCV(estimator=toy_rf, param_grid=tuned_par, scoring=roc_auc_scorer, cv= inner_cv, n_jobs=1)
    clf.fit(X[outer_train],Z[outer_train])
    scores.append(clf.score(X[outer_test], Z[outer_test]))

运行代码时,第一次迭代生成:

^{pr2}$

因此,很容易验证它是否按预期执行。您的交叉验证分数在列表scores中,您可以轻松地处理它们。我使用了您在上一段代码中定义的变量,例如strata。在

相关问题 更多 >

    热门问题