sklearn shuffle train_test_split不洗牌标签和输入以匹配

2024-06-28 15:38:11 发布

您现在位置:Python中文网/ 问答频道 /正文

我相信这段代码导致我的X和Y数据不一致,因为它们的索引号不同。它们是否应该相同,以便模型知道哪个输入与哪个标签相关

x_train, x_valid, y_train, y_valid = train_test_split(Normalise_Data(data), labels, test_size=0.2,shuffle=True)

这是此函数的输入和标签的终端输出。 索引是否应该不对应

x_train
Out[94]: 
         0         1         2     3     ...      4605      4606      4607      4608
114  0.999399  0.000000  0.000000   0.0  ...  0.000025  0.000048  0.000016  0.000038
44   0.995420  0.000000  0.000000   0.0  ...  0.000066  0.000103  0.000058  0.000040
160  0.999492  0.000000  0.000000   0.0  ...  0.000021  0.000024  0.000044  0.000028
293  0.999893  0.000000  0.000250   0.0  ...  0.000002  0.000007  0.000014  0.000003
129  0.999458  0.000885  0.000976   0.0  ...  0.000005  0.000034  0.000044  0.000048
..        ...       ...       ...   ...  ...       ...       ...       ...       ...
176  0.999750  0.000041  0.000000   0.0  ...  0.000032  0.000039  0.000034  0.000029
241  0.999832  0.000000  0.000000   0.0  ...  0.000005  0.000005  0.000017  0.000003
283  0.999927  0.000170  0.000094   0.0  ...  0.000007  0.000009  0.000010  0.000012
405  0.998595  0.000000  0.000000   0.0  ...  0.000051  0.000087  0.000019  0.000031
267  0.999899  0.000000  0.000254   0.0  ...  0.000011  0.000016  0.000015  0.000020 




y_train
Out[95]: 
567     0
44      0
884     0
1902    0
676     0
       ..
1003    0
1475    0
1826    0
302     1
1718    0
Name: prediction, Length: 427, dtype: int64


Tags: 数据代码模型testdatasizelabelstrain
1条回答
网友
1楼 · 发布于 2024-06-28 15:38:11

train_test_split将允许您使用pd.DataFramepd.Serieses,但它不使用索引来决定与什么相关的内容-它只是偏离了内容呈现的顺序:

In [5]: X = pd.DataFrame(np.random.random((5,5)), index=list('ABCDE'))

In [6]: y = pd.Series(np.random.random(5), index=list('FGHIJ'))

In [7]: train_test_split(X, y)
Out[7]:
[          0         1         2         3         4
 A  0.353250  0.859230  0.055278  0.871435  0.827556
 B  0.906734  0.244356  0.082618  0.614280  0.200890
 E  0.285790  0.483524  0.206643  0.881300  0.085348,

           0         1         2         3         4
 D  0.437108  0.883394  0.468495  0.329983  0.685234
 C  0.387929  0.889313  0.728260  0.049744  0.819579,

 F    0.720916
 G    0.072408
 J    0.674973
 dtype: float64,

 I    0.452183
 H    0.202770
 dtype: float64]

只需将输入更改为Normalize_Data(data).sort_index()labels.sort_index(),就可以很容易地解决这个问题

相关问题 更多 >