Posted 2024-09-23 06:30:03
User
I am trying to train a LogisticRegressionWithSGD model, but it fails with the error
org.apache.spark.SparkException: Input validation failed.
It does succeed if I give it binary labels (0 and 1 instead of 0, 1 and 2).
Input example:
Code:

```python
model = LogisticRegressionWithSGD.train(parsed_data)
```
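Should logistic regression in Spark only be used for binary classification?

Although it is not clear from the documentation (you have to dig into the source code to find out), LogisticRegressionWithSGD only works with binary labels; for multinomial regression you should use LogisticRegressionWithLBFGS: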
```python
from pyspark import SparkContext
from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel, LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

# Reuse the shell's SparkContext if one already exists
sc = SparkContext.getOrCreate()

# Labels 0.0, 1.0 and 2.0 -> three classes
parsed_data = [LabeledPoint(0.0, [4.6, 3.6, 1.0, 0.2]),
               LabeledPoint(0.0, [5.7, 4.4, 1.5, 0.4]),
               LabeledPoint(1.0, [6.7, 3.1, 4.4, 1.4]),
               LabeledPoint(0.0, [4.8, 3.4, 1.6, 0.2]),
               LabeledPoint(2.0, [4.4, 3.2, 1.3, 0.2])]

model = LogisticRegressionWithSGD.train(sc.parallelize(parsed_data))
# gives error:
# org.apache.spark.SparkException: Input validation failed.

model = LogisticRegressionWithLBFGS.train(sc.parallelize(parsed_data), numClasses=3)
# works OK
```
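To see why the label 2.0 trips the SGD trainer but not the LBFGS one, here is a plain-Python sketch of the kind of label check involved. It assumes the checks behave like MLlib's binary and multi-class label validators (SGD: every label must be 0 or 1; LBFGS with numClasses = k: every label must be a whole number in [0, k)). The function names are hypothetical and this is not Spark's actual code; it only needs the standard library.

```python
# Hypothetical helpers that mimic (by assumption) MLlib's label validation;
# they are illustrations, not part of pyspark.

def binary_labels_ok(labels):
    """True if every label is exactly 0.0 or 1.0 (what the SGD trainer accepts)."""
    return all(label in (0.0, 1.0) for label in labels)


def multiclass_labels_ok(labels, num_classes):
    """True if every label is a whole number in [0, num_classes)."""
    return all(
        label == int(label) and 0 <= label < num_classes
        for label in labels
    )


def infer_num_classes(labels):
    """Smallest numClasses that covers the labels 0 .. max(label)."""
    return int(max(labels)) + 1


# Labels taken from the example data above
labels = [0.0, 0.0, 1.0, 0.0, 2.0]
print(binary_labels_ok(labels))            # False: the binary check rejects 2.0
k = infer_num_classes(labels)
print(k, multiclass_labels_ok(labels, k))  # 3 True
print(multiclass_labels_ok(labels, 2))     # False: 2.0 is outside [0, 2)
```

In practice you could compute the class count from your labels this way and pass it as `numClasses` to `LogisticRegressionWithLBFGS.train`, which keeps the labels 0, 1, ..., k-1 contiguous as the multi-class trainer expects.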