spark-user mailing list archives

From sethah <>
Subject Re: Spark MLlib Decision Tree Node Accuracy
Date Wed, 09 Sep 2015 22:05:51 GMT
If you are able to traverse the tree, then you can extract the id of the leaf
node for each feature vector. This is like a modified predict method where
it returns the leaf node assigned to the data point instead of the
prediction for that leaf node. The following example code should work: 

import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins)

// Walk the tree like predict does, but return the leaf node itself
// instead of the prediction stored at that leaf.
def predictImpl(node: Node, features: Vector): Node = {
  if (node.isLeaf) {
    node
  } else {
    if (node.split.get.featureType == Continuous) {
      if (features(node.split.get.feature) <= node.split.get.threshold) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    } else {
      if (node.split.get.categories.contains(features(node.split.get.feature))) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    }
  }
}

// Pair each point's leaf node id with (prediction, label).
val nodeIDAndPredsAndLabels = testData.map { lp =>
  val node = predictImpl(model.topNode, lp.features)
  (node.id, (node.predict.predict, lp.label))
}

From here, you should be able to perform analysis of the accuracy of each
leaf node.
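
As a rough sketch of that analysis, the pairs can be grouped by leaf node id and
reduced to a fraction of correct predictions per leaf. The helper name
accuracyByLeaf is hypothetical (it is not part of MLlib), and the sketch assumes
Spark 1.3 or later, where pair-RDD operations are available without extra
implicit imports:

```scala
import org.apache.spark.rdd.RDD

// Hypothetical helper (not part of MLlib): for each leaf node id, return
// (fraction of points where prediction == label, number of points).
def accuracyByLeaf(pairs: RDD[(Int, (Double, Double))]): RDD[(Int, (Double, Long))] = {
  pairs
    // Turn each (prediction, label) into (1 if correct else 0, 1).
    .mapValues { case (pred, label) => (if (pred == label) 1L else 0L, 1L) }
    // Sum the correct counts and the totals per leaf id.
    .reduceByKey { case ((c1, n1), (c2, n2)) => (c1 + c2, n1 + n2) }
    // Convert the sums into an accuracy, keeping the count for context.
    .mapValues { case (correct, total) => (correct.toDouble / total, total) }
}
```

Calling accuracyByLeaf(nodeIDAndPredsAndLabels).collect() should then give one
(accuracy, count) entry per leaf id. Keeping the count alongside the accuracy
helps spot leaves that received only a handful of test points.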

Note that in the new Spark ML library a predictNodeIndex is implemented
(which is being converted to a predictImpl method) similar to the
implementation above. Hopefully that code helps.
