如何通过插补填充泰坦尼克号年龄列中的NaN值

2024-10-01 07:45:39 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在进行泰坦尼克号卡格尔号竞赛,目前我正在尝试估算缺失的Age

其思想是计算训练集上每个[Pclass, Sex]组的平均Age,然后使用此信息替换训练集和测试集上的NaN

这就是我到目前为止所做的:

meanAgeTrain = train.groupby(['Pclass', 'Sex'])['Age'].transform('mean')
    
for df in [train, test]:
    df['Age'] = df['Age'].fillna(meanAgeTrain)

问题是,这仍然会在测试集中留下一些NaN值,同时在训练集中消除所有NaN。我想这与指数有关

我需要的是:

  1. 计算训练集中每个P类/性别组的平均值
  2. 将训练集中的所有NaN值映射到正确的平均值
  3. 将测试集中的所有NaN值映射到正确的平均值(按Pclass/Sex查找,而不是基于索引)

如何使用熊猫正确地做到这一点

编辑:

谢谢你的建议。@Reza的那个很有效,但我不是100%理解。所以我正试图想出我自己的解决办法

这是可行的,但我对熊猫还不熟悉,我想知道是否有更简单的方法来实现它

trainMeans = self.train.groupby(['Pclass', 'Sex'])['Age'].mean().reset_index()

def f(x):
    if x["Age"] == x["Age"]:  # not NaN
        return x["Age"]
    return trainMeans.loc[(trainMeans["Pclass"] == x["Pclass"]) & (trainMeans["Sex"] == x["Sex"])]["Age"].values[0]

 self.train['Age'] = self.train.apply(f, axis=1)
 self.test['Age'] = self.test.apply(f, axis=1)

尤其是函数中的if在我看来不是一个最佳实践。我需要一种方法将函数仅应用于NaN ages

编辑2

事实证明,重置索引会使事情变得更加复杂和缓慢,因为在分组后,索引已经是我想要用作映射键的东西了。这会更快更容易:

trainMeans = self.train.groupby(['Pclass', 'Sex'])['Age'].mean()

def f(x):
    if not np.isnan(x["Age"]):  # not NaN
        return x["Age"]
    return trainMeans[x["Pclass"], x["Sex"]]

self.train['Age'] = self.train.apply(f, axis=1)
self.test['Age'] = self.test.apply(f, axis=1)

这可以进一步简化吗


Tags: testselfdfagereturntrainnanmean
2条回答
  • 您将看到两种填充方法,groupby fillna with meanrandom forest Recessiver,彼此相差不到一年的1/100
    • 有关统计比较,请参见答案的底部

用平均值填充nan值

import pandas as pd
import seaborn as sns

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1, 'female': 0})

# Populate Age_Fill
df['Age_Fill'] = df['age'].groupby([df['pclass'], df['sex']]).apply(lambda x: x.fillna(x.mean()))

# series with filled ages
groupby_result = df.Age_Fill[df.age.isnull()]

# display(df[df.age.isnull()].head())
 survived  pclass     sex  age  sibsp  parch     fare embarked   class    who  adult_male deck  embark_town alive  alone  Age_Fill
        0       3    male  NaN      0      0   8.4583        Q   Third    man        True  NaN   Queenstown    no   True  26.50759
        1       2    male  NaN      0      0  13.0000        S  Second    man        True  NaN  Southampton   yes   True  30.74071
        1       3  female  NaN      0      0   7.2250        C   Third  woman       False  NaN    Cherbourg   yes   True  21.75000
        0       3    male  NaN      0      0   7.2250        C   Third    man        True  NaN    Cherbourg    no   True  26.50759
        1       3  female  NaN      0      0   7.8792        Q   Third  woman       False  NaN   Queenstown   yes   True  21.75000

从RandomForestRegressionor中填充nan值

  • ^{}
  • Kaggle: Titanic
    • 年龄似乎是一个很有前途的特征。因此,简单地用中值/均值/模式填充空值是没有意义的
    • 根据这里的结果,我不认为这有多大区别
from sklearn.ensemble import RandomForestRegressor
import pandas as pd
import seaborn as sns

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1, 'female': 0})

# split data
train = df.loc[(df.age.notnull())]  # known age values
test = df.loc[(df.age.isnull())]  # all nan age values

# select age column
y = train.values[:, 3]

# select pclass and sex
X = train.values[:, [1, 2]]

# create RandomForestRegressor model
rfr = RandomForestRegressor(n_estimators=2000, n_jobs=-1)

# Fit a model
rfr.fit(X, y)

# Use the fitted model to predict the missing values
predictedAges = rfr.predict(test.values[:, [1, 2]])

# create predicted age column
df['pred_age'] = df.age

# fill column
df.loc[(df.pred_age.isnull()), 'pred_age'] = predictedAges 

# display(df[df.age.isnull()].head())
 survived  pclass  sex  age  sibsp  parch     fare embarked   class    who  adult_male deck  embark_town alive  alone  pred_age
        0       3    1  NaN      0      0   8.4583        Q   Third    man        True  NaN   Queenstown    no   True  26.49935
        1       2    1  NaN      0      0  13.0000        S  Second    man        True  NaN  Southampton   yes   True  30.73126
        1       3    0  NaN      0      0   7.2250        C   Third  woman       False  NaN    Cherbourg   yes   True  21.76513
        0       3    1  NaN      0      0   7.2250        C   Third    man        True  NaN    Cherbourg    no   True  26.49935
        1       3    0  NaN      0      0   7.8792        Q   Third  woman       False  NaN   Queenstown   yes   True  21.76513

groupby与rfr的比较

print(predictedAges - groupby_result).describe())

count    177.00000
mean       0.00362
std        0.01877
min       -0.04167
25%        0.01121
50%        0.01121
75%        0.01131
max        0.02969
Name: Age_Fill, dtype: float64

# comparison dataframe
comp = pd.DataFrame({'rfr': predictedAges.tolist(), 'gb': groupby_result.tolist()})
comp['diff'] = comp.rfr - comp.gb

# display(comp)
      rfr        gb     diff
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 28.75266  28.72297  0.02969
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 28.75266  28.72297  0.02969
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 30.69903  30.74071 -0.04167
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 30.69903  30.74071 -0.04167
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 41.24592  41.28139 -0.03547
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 41.24592  41.28139 -0.03547
 26.51880  26.50759  0.01121
 34.63090  34.61176  0.01913
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131
 26.51880  26.50759  0.01121
 26.51880  26.50759  0.01121
 21.76131  21.75000  0.01131

计算随机训练集上的平均值

  • 本例计算随机训练集的平均值,然后填充训练集和测试集中的nan
  • 当两个数据帧都有匹配的索引且填充列相同时,使用^{}从另一个数据帧填充数据帧列中缺少的值。
    • Pclass/Sex,而不是基于索引pclasssex被设置为索引,这就是.fillna的工作方式
  • 在本例中,train是数据的67%,而test是数据的33%。
    • test_sizetrain_size可以根据需要设置,如^{}
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

# load dataset
df = sns.load_dataset('titanic')

# map sex to a numeric type
df.sex = df.sex.map({'male': 1, 'female': 0})

# randomly split the dataframe into a train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

# select columns for X and y
X = df[['pclass', 'sex']]
y = df['age']

# create a dataframe of train (X, y) and test (X, y)
train = pd.concat([X_train, y_train], axis=1).reset_index(drop=True)
test = pd.concat([X_test, y_test], axis=1).reset_index(drop=True)

# calculate means for train
train_means = train.groupby(['pclass', 'sex']).agg({'age': 'mean'})

# display train_means, a multi-index dataframe
                 age
pclass sex          
1      0    34.66667
       1    41.38710
2      0    27.90217
       1    30.50000
3      0    21.56338
       1    26.87163

# fill nan values in train
train = train.set_index(['pclass', 'sex']).age.fillna(train_means.age).reset_index()

# fill nan values in test
test = test.set_index(['pclass', 'sex']).age.fillna(train_means.age).reset_index()

您可以首先为Age创建映射:

cols = ['Pclass', 'Sex']
age_class_sex = train.groupby(cols)['Age'].mean().reset_index()

然后将其与测试和训练单独合并,以便解决索引问题

train['Age'] = train['Age'].fillna(train[cols].reset_index().merge(age_class_sex, how='left', on=cols).set_index('index')['Age'])
test['Age'] = test['Age'].fillna(test[cols].reset_index().merge(age_class_sex, how='left', on=cols).set_index('index')['Age'])

相关问题 更多 >