Spark之获取GBT二分类函数的概率值
在Spark中,GBT(Gradient Boost Trees,提升树)函数用于实现机器学习中的提升树算法,目前仅支持二分类算法。笔者在实际工作中需要获得其预测的概率值,无奈该函数没有相应的方法。 经过笔者几天的奋斗,终于找到了解决之道。下面将分享在Spark中如何获取GBT二分类函数的概率值的思路。 首先,查看GBT函数的Scala源代码,其中的predict函数如下: 其中的prediction值是我们计算概率值所需要的,prediction的值为_treePredictions(向量)与_treeWeights(向量)的点积,numTrees为GBTClassifier所使用的树的数量。_treePredictions为每棵决策树的预测值组成的向量,_treeWeights为每颗树的权重组成的向量。 那么该如何计算_treePredictions与_treeWeights呢? _treeWeights的计算可直接调用GBTClassifier中的treeWeights方法,输出的数据类型为list。以下为例子:(该例子的数据集和GBT测试代码见附录) 接下来讲述_treeP...