基于anacond的随机森林交叉验证

2024-09-19 07:09:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我用泰坦尼克号的数据集来预测一个乘客是幸存还是没有使用随机森林。这是我的密码:

import numpy as np 
import pandas as pd 
from sklearn.ensemble import RandomForestClassifier
from sklearn import cross_validation
import matplotlib.pyplot as plt
%matplotlib inline

data=pd.read_csv("C:\\Users\\kabala\\Downloads\\Titanic.csv")
data.isnull().any()
data["Age"]=data1["Age"].fillna(data1["Age"].median())
data["PClass"]=data["PClass"].fillna("3rd")
data["PClass"].isnull().any()
data1.isnull().any()
pd.get_dummies(data.Sex)
# choosing the predictive variables 
x=data[["PClass","Age","Sex"]]
# the target variable is y 
y=data["Survived"]
modelrandom=RandomForestClassifier(max_depth=3)
modelrandom=cross_validation.cross_val_score(modelrandom,x,y,cv=5)

但是,我一直在犯这样的错误:

ValueError: could not convert string to float: 'female'

我不明白问题出在哪里,因为我把性特征改成了假人

谢谢:)


Tags: fromimportagedataasanysklearnvalidation
1条回答
网友
1楼 · 发布于 2024-09-19 07:09:36

pd.get_dummies返回一个数据帧,并且不执行适当的操作。因此,您实际上是在发送带有sex列的sting。你知道吗

所以你需要像X = pd.get_dummies(data[['Sex','PClass','Age']], columns=['Sex','PClass'])这样的东西,这样可以解决你的问题。我认为PClass也是一个字符串列,您需要使用伪变量,因为它填充了'3rd'。你知道吗

还有一些地方您调用data.isnull().any()对底层数据帧没有任何作用。我让他们保持原样,但仅供参考,他们可能没有按你的意思做。你知道吗

完整代码为:

import numpy as np 
import pandas as pd 
from sklearn.ensemble import RandomForestClassifier
from sklearn import cross_validation
import matplotlib.pyplot as plt
%matplotlib inline

data=pd.read_csv("C:\\Users\\kabala\\Downloads\\Titanic.csv")
data.isnull().any()   <  -Beware this is not doing anything to the data
data["Age"]=data1["Age"].fillna(data1["Age"].median())
data["PClass"]=data["PClass"].fillna("3rd")
data["PClass"].isnull().any()  <  -Beware this is not doing anything to the data
data1.isnull().any() <  -Beware this is not doing anything to the data

#********Fix for your code*******
X = pd.get_dummies(data[['Sex','PClass','Age']], columns=['Sex','PClass'])

# choosing the predictive variables 
# x=data[["PClass","Age","Sex"]]
# the target variable is y 
y=data["Survived"]
modelrandom=RandomForestClassifier(max_depth=3)
modelrandom=cross_validation.cross_val_score(modelrandom,x,y,cv=5)

相关问题 更多 >