spark-user mailing list archives

From sethah <shen...@us.ibm.com>
Subject Re: Spark MLlib Decision Tree Node Accuracy
Date Wed, 09 Sep 2015 22:05:51 GMT
If you are able to traverse the tree, you can extract the ID of the leaf node
that each feature vector falls into. This works like a modified predict
method that returns the leaf node assigned to the data point instead of that
leaf's prediction. The following example code should work:

import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.linalg.Vector

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "gini"
val maxDepth = 5
val maxBins = 32

val model = DecisionTree.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, impurity, maxDepth, maxBins)

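// Walk down the tree from `node` and return the leaf that `features` lands in.
// Same routing as predict, but it returns the Node instead of its prediction.
def predictImpl(node: Node, features: Vector): Node = {
  if (node.isLeaf) {
    node
  } else {
    val split = node.split.get
    if (split.featureType == Continuous) {
      // Continuous feature: go left when the value is <= the split threshold.
      if (features(split.feature) <= split.threshold) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    } else {
      // Categorical feature: go left when the value is one of the split's categories.
      if (split.categories.contains(features(split.feature))) {
        predictImpl(node.leftNode.get, features)
      } else {
        predictImpl(node.rightNode.get, features)
      }
    }
  }
}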

// For every labeled point: (leaf node id, (leaf prediction, true label))
val nodeIDAndPredsAndLabels = data.map { lp =>
  val node = predictImpl(model.topNode, lp.features)
  (node.id, (node.predict.predict, lp.label))
}

From here, you should be able to analyze the accuracy of each leaf node.

Note that the new Spark ML library already implements a predictNodeIndex
method (which is being converted to a predictImpl method) similar to the
implementation above. Hopefully that code helps.
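The routing rules in predictImpl can also be checked without a SparkContext. The sketch below uses hypothetical, simplified types (`SimpleNode`, `ContinuousSplit`, `CategoricalSplit` and `LeafLookup` are illustrative names, not MLlib classes), but applies the same rules as the code above: a continuous split sends a value <= threshold to the left child, and a categorical split sends a value found in the split's category set to the left child.

```scala
// Hypothetical, simplified stand-ins for MLlib's Split/Node (not the real classes).
sealed trait SplitRule
final case class ContinuousSplit(feature: Int, threshold: Double) extends SplitRule
final case class CategoricalSplit(feature: Int, categories: Set[Double]) extends SplitRule

final case class SimpleNode(
    id: Int,
    prediction: Double,
    split: Option[SplitRule] = None, // None marks a leaf
    left: Option[SimpleNode] = None,
    right: Option[SimpleNode] = None)

object LeafLookup {
  // Walk from `node` down to the leaf that `features` falls into.
  @annotation.tailrec
  def findLeaf(node: SimpleNode, features: Array[Double]): SimpleNode =
    node.split match {
      case None => node
      case Some(ContinuousSplit(f, t)) =>
        findLeaf(if (features(f) <= t) node.left.get else node.right.get, features)
      case Some(CategoricalSplit(f, cats)) =>
        findLeaf(if (cats.contains(features(f))) node.left.get else node.right.get, features)
    }

  def main(args: Array[String]): Unit = {
    // Hypothetical tree: root splits on continuous feature 0, its right child on categorical feature 1.
    val tree = SimpleNode(1, 0.0, Some(ContinuousSplit(0, 0.5)),
      Some(SimpleNode(2, 0.0)),
      Some(SimpleNode(3, 1.0, Some(CategoricalSplit(1, Set(2.0))),
        Some(SimpleNode(6, 1.0)), Some(SimpleNode(7, 0.0)))))
    val leaf = findLeaf(tree, Array(0.9, 2.0))
    println(s"leaf ${leaf.id} predicts ${leaf.prediction}")
  }
}
```

Making the recursion tail-recursive keeps deep trees from growing the call stack; the MLlib version above recurses the same way, just without the annotation.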



--
View this message in context: http://apache-spark-user-list.1001560.n3.nabble.com/Spark-MLlib-Decision-Tree-Node-Accuracy-tp24561p24629.html
Sent from the Apache Spark User List mailing list archive at Nabble.com.


