"Spark random forest cross-validation"

Published 2024-09-23 22:29:28


I am trying to run cross-validation on a random forest in Spark.

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.regression import LabeledPoint

# nds.sc is the SparkContext in the asker's environment
data = nds.sc.parallelize([
 LabeledPoint(0.0, [0,402,6,0]),
 LabeledPoint(0.0, [3,500,3,0]),
 LabeledPoint(1.0, [1,590,1,1]),
 LabeledPoint(1.0, [3,328,5,0]),
 LabeledPoint(1.0, [4,351,4,0]),
 LabeledPoint(0.0, [2,372,2,0]),
 LabeledPoint(0.0, [4,302,5,0]),
 LabeledPoint(1.0, [1,387,2,0]),
 LabeledPoint(1.0, [1,419,3,0]),
 LabeledPoint(0.0, [1,370,5,0]),
 LabeledPoint(0.0, [1,410,4,0]),
 LabeledPoint(0.0, [2,509,7,1]),
 LabeledPoint(0.0, [1,307,5,0]),
 LabeledPoint(0.0, [0,424,4,1]),
 LabeledPoint(0.0, [1,509,2,1]),
 LabeledPoint(1.0, [3,361,4,0]),
 ])


train=data.toDF(['label','features'])

numfolds = 2

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator()

paramGrid = ParamGridBuilder() \
    .addGrid(rf.maxDepth, [4, 8, 10]) \
    .addGrid(rf.impurity, ['entropy', 'gini']) \
    .addGrid(rf.featureSubsetStrategy, [6, 8, 10]) \
    .build()

pipeline = Pipeline(stages=[rf])

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numfolds)

model = crossval.fit(train)

I get the following error:

[…]

It seems that paramGrid is not reading my input as a list. Is there another format or a workaround? Any help would be appreciated.


2 answers

You are passing an invalid value to rf.featureSubsetStrategy. It should be a string that names the strategy; the supported values are auto, all, onethird, sqrt and log2. See RandomForestClassifier.featureSubsetStrategy.doc.
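A minimal sketch of the fix, kept in plain Python so it runs without a Spark session: the grid values for featureSubsetStrategy become strings from the list above, and a small helper rejects anything else. `check_strategies` and `SUPPORTED_STRATEGIES` are hypothetical names, not part of PySpark, and the check only accepts the five values named in this answer (later Spark versions also document numeric strings such as "0.5" for this parameter, which this strict sketch does not model).

```python
# Supported featureSubsetStrategy values, as listed in the answer above.
# (Hypothetical constant; later Spark versions accept more values.)
SUPPORTED_STRATEGIES = {"auto", "all", "onethird", "sqrt", "log2"}


def check_strategies(values):
    """Hypothetical helper: return the values as a list if every entry is a
    supported strategy string, otherwise raise ValueError."""
    bad = [v for v in values
           if not isinstance(v, str) or v not in SUPPORTED_STRATEGIES]
    if bad:
        raise ValueError("unsupported featureSubsetStrategy values: %r" % (bad,))
    return list(values)


# The question's grid used integers, which is what fails:
try:
    check_strategies([6, 8, 10])
except ValueError as err:
    print(err)

# Strategy names as strings are valid grid values instead:
strategies = check_strategies(["auto", "sqrt", "log2"])
print(strategies)
```

The checked list would then replace `[6,8,10]` in the grid, e.g. `.addGrid(rf.featureSubsetStrategy, strategies)`.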

Also, don't use data.toDF(['label','features']). It does not preserve the correct column order. Use:

data.toDF()

Or, if you want to change the names:

[…]
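To see why positional names can go wrong, here is a plain-Python sketch. It assumes, based on how older PySpark inferred a schema from plain objects (attribute names sorted alphabetically), that an RDD of LabeledPoint yields the columns in the order features, label. `FakeLabeledPoint` and `infer_columns` are hypothetical stand-ins, not PySpark classes, and the new names "target" and "vector" are placeholders.

```python
class FakeLabeledPoint(object):
    """Hypothetical stand-in for pyspark.mllib.regression.LabeledPoint."""

    def __init__(self, label, features):
        self.label = float(label)
        self.features = list(features)


def infer_columns(obj):
    # Assumption: mimics schema inference that sorts attribute names.
    return sorted(vars(obj))


point = FakeLabeledPoint(1.0, [1, 590, 1, 1])
inferred = infer_columns(point)  # ['features', 'label']

# Positional naming, like toDF(['label', 'features']):
# maps each new name to the inferred column it ends up labelling.
positional = dict(zip(["label", "features"], inferred))
# the data in 'features' is now called 'label', and vice versa

# Renaming by name keeps every column matched to its own data
# (placeholder target names).
by_name = {"label": "target", "features": "vector"}
renamed = [by_name[c] for c in inferred]
print(inferred, positional, renamed)
```

In PySpark itself, a by-name rename could be written with the real `DataFrame.withColumnRenamed(existing, new)` method, e.g. `data.toDF().withColumnRenamed('label', 'target')`.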

Finally, the label column has to be indexed, or you have to provide the required metadata. See "How can I declare a Column as a categorical feature in a DataFrame for use in ml".
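The usual way to index the label is `StringIndexer` from `pyspark.ml.feature`, typically added as a pipeline stage before `rf`, with the classifier's `labelCol` pointed at the indexer's output column. The plain-Python sketch below only imitates its default ordering on the question's 16 labels: values are treated as strings, the most frequent one gets index 0.0, and ties are broken alphabetically (as documented for recent Spark versions). `index_labels` is a hypothetical name.

```python
from collections import Counter

# Labels from the question's data, in order (10 zeros, 6 ones).
labels = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
          1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]


def index_labels(values):
    """Hypothetical imitation of StringIndexer's default frequency order:
    the most frequent string value gets 0.0, ties sorted alphabetically."""
    counts = Counter(str(v) for v in values)
    ordered = sorted(counts, key=lambda s: (-counts[s], s))
    mapping = {s: float(i) for i, s in enumerate(ordered)}
    return [mapping[str(v)] for v in values], mapping


indexed, mapping = index_labels(labels)
print(mapping)
```

For this data the indices happen to coincide with the original labels; in Spark, the point of the indexing stage is also the label metadata it attaches to its output column.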

This was my original code:

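from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import ParamGridBuilder
lr = LogisticRegression()  # assumption: lr was a LogisticRegression estimator
lr_parameter_grid_ = ParamGridBuilder().addGrid(lr.maxIter, [50, 200, 500])\
    .addGrid(lr.regParam, [0, 0.3, 1])\
    .addGrid(lr.elasticNetParam, [0, 0.3, 1]).build()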

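I got the same error. Then I kept only one parameter (maxIter), and it worked. Later I added regParam, and it worked again. Finally I added elasticNetParam, and it still worked. I don't know why it fails the first time when you put in several parameters, but when you start with one and keep adding, it works.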

Not a great permanent solution, but it worked for me.
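`ParamGridBuilder.build()` returns the Cartesian product of all the value lists added to it, so adding the parameters one at a time ends with the same set of parameter maps as adding them all at once. A plain-Python sketch of that product using `itertools.product`; `build_grid` is a hypothetical stand-in, and parameter names are plain strings here rather than PySpark `Param` objects.

```python
from itertools import product


def build_grid(grid):
    """Hypothetical stand-in for ParamGridBuilder.build(): one dict per
    combination of values (the Cartesian product of all lists)."""
    names = list(grid)
    return [dict(zip(names, combo))
            for combo in product(*(grid[n] for n in names))]


full = build_grid({"maxIter": [50, 200, 500],
                   "regParam": [0, 0.3, 1],
                   "elasticNetParam": [0, 0.3, 1]})

# Adding the parameters one at a time, as described above.
step = {"maxIter": [50, 200, 500]}
step["regParam"] = [0, 0.3, 1]
step["elasticNetParam"] = [0, 0.3, 1]
incremental = build_grid(step)

print(len(full))  # 27 parameter maps
```

CrossValidator fits every map once per fold, so the question's own grid (3 × 2 × 3 = 18 maps with numFolds = 2) means 36 fits during cross-validation, plus a final refit of the best model on the full data.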
