sklearn中StratifiedKFold与StratifiedShuffleSplit的区别

2024-09-30 20:37:00 发布

您现在位置:Python中文网/ 问答频道 /正文

从标题上看,我想知道

StratifiedKFold参数为shuffle=True

StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

以及

StratifiedShuffleSplit

StratifiedShuffleSplit(n_splits=10, test_size=’default’, train_size=None, random_state=0)

使用StratifiedShuffleSplit有什么好处


Tags: testnonetruedefault标题参数sizetrain
3条回答

图示: output examples of KFold, StratifiedKFold, StratifiedShuffleSplit (如何在此窗口中显示此图片?)

上面的图形表示基于Ken Syme的代码:

from sklearn.model_selection import KFold, StratifiedKFold, StratifiedShuffleSplit
SEED = 43
SPLIT = 3

X_train = [0,1,2,3,4,5,6,7,8]
y_train = [0,0,0,0,0,0,1,1,1]   # note 6,7,8 are labelled class '1'

print("KFold, shuffle=False (default)")
kf = KFold(n_splits=SPLIT, random_state=SEED)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("KFold, shuffle=True")
kf = KFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in kf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedKFold, shuffle=False (default)")
skf = StratifiedKFold(n_splits=SPLIT, random_state=SEED)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("StratifiedKFold, shuffle=True")
skf = StratifiedKFold(n_splits=SPLIT, shuffle=True, random_state=SEED)
for train_index, test_index in skf.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedShuffleSplit")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=3)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

print("\nStratifiedShuffleSplit (can customise test_size)")
sss = StratifiedShuffleSplit(n_splits=SPLIT, random_state=SEED, test_size=2)
for train_index, test_index in sss.split(X_train, y_train):
    print("TRAIN:", train_index, "TEST:", test_index)

在Kfold中,即使使用shuffle,每个测试集也不应该重叠。使用KFolds和shuffle,数据在开始时被shuffle一次,然后被分割成所需的分割数。测试数据总是其中的一部分,列车数据是其余部分。

在ShuffleSplit中,数据每次都会被洗牌,然后被分割。这意味着测试集可能在拆分之间重叠。

有关区别的示例,请参见此块。注意ShuffleSplit测试集中元素的重叠。

splits = 5

tx = range(10)
ty = [0] * 5 + [1] * 5

from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from sklearn import datasets

kfold = StratifiedKFold(n_splits=splits, shuffle=True, random_state=42)
shufflesplit = StratifiedShuffleSplit(n_splits=splits, random_state=42, test_size=2)

print("KFold")
for train_index, test_index in kfold.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

print("Shuffle Split")
for train_index, test_index in shufflesplit.split(tx, ty):
    print("TRAIN:", train_index, "TEST:", test_index)

输出:

KFold
TRAIN: [0 2 3 4 5 6 7 9] TEST: [1 8]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 8 9] TEST: [2 7]
TRAIN: [1 2 3 4 6 7 8 9] TEST: [0 5]
TRAIN: [0 1 2 4 5 6 7 8] TEST: [3 9]
Shuffle Split
TRAIN: [8 4 1 0 6 5 7 2] TEST: [3 9]
TRAIN: [7 0 3 9 4 5 1 6] TEST: [8 2]
TRAIN: [1 2 5 6 4 8 9 0] TEST: [3 7]
TRAIN: [4 6 7 8 3 5 1 2] TEST: [9 0]
TRAIN: [7 2 6 5 4 3 0 9] TEST: [1 8]

至于何时使用它们,我倾向于使用kfold进行交叉验证,我使用ShuffleSplit进行火车/测试集的拆分,拆分为2。但我相信这两种情况都有其他的用例。

@Ken Syme已经有了一个很好的答案。我只是想补充一点。

  • StratifiedKFoldKFold的变体。首先,StratifiedKFold洗牌数据,然后将数据分成n_splits部分并完成。 现在,它将使用每个部分作为测试集。注意它只会在拆分前对数据进行一次洗牌。

使用shuffle = True,数据将被random_state洗牌。否则, 数据被np.random(作为默认值)洗牌。 例如,使用n_splits = 4,数据有3个类(label)用于y(因变量)。4个测试集覆盖所有数据,没有任何重叠。

enter image description here

  • 另一方面,StratifiedShuffleSplitShuffleSplit的变体。 首先,StratifiedShuffleSplit洗牌您的数据,然后它还将数据分成n_splits部分。但是,还没有完成。在这一步之后,StratifiedShuffleSplit选择一个部分作为测试集。 然后它重复相同的进程n_splits - 1其他时间,以获得n_splits - 1其他测试集。看看下面的图片,数据是一样的,但是这次,4个测试集并没有覆盖所有的数据,即测试集之间有重叠。

enter image description here

因此,这里的区别是StratifiedKFold只进行一次洗牌和拆分,因此测试集不会重叠,而StratifiedShuffleSplit在每次拆分前都进行洗牌,并且它会进行n_splits次拆分,测试集可以重叠

  • 注意:这两种方法使用“分层折叠”(这就是“分层”出现在两个名称中的原因)。这意味着每个部分都保留与原始数据相同百分比的每个类(标签)的样本。你可以在cross_validation documents上阅读更多内容

相关问题 更多 >