Spark 1.3.1中的平均随机森林预测

2024-10-03 19:30:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图在Spark 1.3.1中计算随机森林预测的平均值,因为所有树的预测概率只有在将来的版本中才可用。你知道吗

到目前为止,我所能做的就是使用以下函数:

def calculaProbs(dados, modelRF):
    trees = modelRF._java_model.trees()
    nTrees = modelRF.numTrees()
    nPontos = dados.count()
    predictions = np.zeros(nPontos)
    for i in range(nTrees):
        dtm = DecisionTreeModel(trees[i])
        predictions += np.array(dtm.predict(dados.map(lambda x: x.features)).collect())
    predictions = predictions/nTrees
    return predictions

正如预期的那样,这段代码运行得太慢,因为我正在从每棵树收集预测并将它们添加到驱动程序中。 我放不下$dtm.predit公司()$在这个版本的Spark的Map操作中。下面是文档中的注释:“注意:在Python中,predict当前不能在RDD转换或操作中使用。直接在RDD上调用predict。”

有什么改进性能的方法吗?如何从2个RDD中添加值而不将它们的值收集到向量中?你知道吗


Tags: 版本np森林概率treespredictspark平均值