[ https://issues.apache.org/jira/browse/SPARK-19449?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Sean Owen resolved SPARK-19449.
-------------------------------
Resolution: Not A Problem
> Inconsistent results between ml package RandomForestClassificationModel and mllib package RandomForestModel
> -----------------------------------------------------------------------------------------------------------
>
> Key: SPARK-19449
> URL: https://issues.apache.org/jira/browse/SPARK-19449
> Project: Spark
> Issue Type: Bug
> Components: ML, MLlib
> Affects Versions: 2.1.0
> Reporter: Aseem Bansal
>
> I wrote some code to convert the ml package RandomForestClassificationModel to the mllib package RandomForestModel. This was needed because we need to make predictions with latency on the order of milliseconds. I found that the results are inconsistent even though the underlying DecisionTreeModels are exactly the same, so the two implementations behave differently, which should not be the case.
> The code below reproduces the issue. It can be run as a simple Java application as long as the Spark dependencies are set up properly.
> {noformat}
> import org.apache.spark.ml.Transformer;
> import org.apache.spark.ml.classification.*;
> import org.apache.spark.ml.linalg.*;
> import org.apache.spark.ml.regression.RandomForestRegressionModel;
> import org.apache.spark.mllib.linalg.DenseVector;
> import org.apache.spark.mllib.linalg.Vector;
> import org.apache.spark.mllib.tree.configuration.Algo;
> import org.apache.spark.mllib.tree.model.DecisionTreeModel;
> import org.apache.spark.mllib.tree.model.RandomForestModel;
> import org.apache.spark.sql.Dataset;
> import org.apache.spark.sql.Row;
> import org.apache.spark.sql.RowFactory;
> import org.apache.spark.sql.SparkSession;
> import org.apache.spark.sql.types.DataTypes;
> import org.apache.spark.sql.types.Metadata;
> import org.apache.spark.sql.types.StructField;
> import org.apache.spark.sql.types.StructType;
> import scala.Enumeration;
> import java.util.ArrayList;
> import java.util.List;
> import java.util.Random;
> abstract class Predictor {
>     abstract double predict(Vector vector);
> }
> public class MainConvertModels {
>     public static final int seed = 42;
>     public static void main(String[] args) {
>         int numRows = 1000;
>         int numFeatures = 3;
>         int numClasses = 2;
>         double trainFraction = 0.8;
>         double testFraction = 0.2;
>         SparkSession spark = SparkSession.builder()
>                 .appName("conversion app")
>                 .master("local")
>                 .getOrCreate();
>         Dataset<Row> data = getDummyData(spark, numRows, numFeatures, numClasses);
>         Dataset<Row>[] splits = data.randomSplit(new double[]{trainFraction, testFraction}, seed);
>         Dataset<Row> trainingData = splits[0];
>         Dataset<Row> testData = splits[1];
>         testData.cache();
>         List<Double> labels = getLabels(testData);
>         List<DenseVector> features = getFeatures(testData);
>         DecisionTreeClassifier classifier1 = new DecisionTreeClassifier();
>         DecisionTreeClassificationModel model1 = classifier1.fit(trainingData);
>         final DecisionTreeModel convertedModel1 = convertDecisionTreeModel(model1, Algo.Classification());
>         RandomForestClassifier classifier = new RandomForestClassifier();
>         RandomForestClassificationModel model2 = classifier.fit(trainingData);
>         final RandomForestModel convertedModel2 = convertRandomForestModel(model2);
>         System.out.println(
>                 "****** DecisionTreeClassifier\n" +
>                 "** Original **" + getInfo(model1, testData) + "\n" +
>                 "** New **" + getInfo(new Predictor() {
>                     double predict(Vector vector) { return convertedModel1.predict(vector); }
>                 }, labels, features) + "\n" +
>                 "\n" +
>                 "****** RandomForestClassifier\n" +
>                 "** Original **" + getInfo(model2, testData) + "\n" +
>                 "** New **" + getInfo(new Predictor() {
>                     double predict(Vector vector) { return convertedModel2.predict(vector); }
>                 }, labels, features) + "\n" +
>                 "\n");
>     }
>     static Dataset<Row> getDummyData(SparkSession spark, int numberRows, int numberFeatures, int labelUpperBound) {
>         StructType schema = new StructType(new StructField[]{
>                 new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
>                 new StructField("features", new VectorUDT(), false, Metadata.empty())
>         });
>         double[][] vectors = prepareData(numberRows, numberFeatures);
>         Random random = new Random(seed);
>         List<Row> dataTest = new ArrayList<>();
>         for (double[] vector : vectors) {
>             // Labels are drawn at random in [0, labelUpperBound), independently of the features
>             double label = (double) random.nextInt(labelUpperBound);
>             dataTest.add(RowFactory.create(label, Vectors.dense(vector)));
>         }
>         return spark.createDataFrame(dataTest, schema);
>     }
>     static double[][] prepareData(int numRows, int numFeatures) {
>         Random random = new Random(seed);
>         double[][] result = new double[numRows][numFeatures];
>         for (int row = 0; row < numRows; row++) {
>             for (int feature = 0; feature < numFeatures; feature++) {
>                 result[row][feature] = random.nextDouble();
>             }
>         }
>         return result;
>     }
>     static String getInfo(Predictor predictor,
>                           List<Double> labels,
>                           List<DenseVector> features) {
>         Long startTime = System.currentTimeMillis();
>         List<Double> predictions = new ArrayList<>();
>         for (DenseVector feature : features) {
>             predictions.add(predictor.predict(feature));
>         }
>         return getInfo(startTime, labels, predictions);
>     }
>     static List<Double> getLabels(Dataset<Row> testData) {
>         List<Double> labels = new ArrayList<>();
>         for (Row row : testData.collectAsList()) {
>             labels.add(row.getDouble(0));
>         }
>         return labels;
>     }
>     static List<DenseVector> getFeatures(Dataset<Row> testData) {
>         List<DenseVector> features = new ArrayList<>();
>         for (Row row : testData.collectAsList()) {
>             features.add(new DenseVector(((org.apache.spark.ml.linalg.Vector) row.get(1)).toArray()));
>         }
>         return features;
>     }
>     static String getInfo(Transformer model, Dataset<Row> testData) {
>         Dataset<Row> predictions = model.transform(testData);
>         predictions.cache();
>         Dataset<Row> correctPredictions = predictions.filter("label == prediction");
>         correctPredictions.cache();
>         Dataset<Row> incorrectPredictions = predictions.filter("label != prediction");
>         incorrectPredictions.cache();
>         Long truePositives = correctPredictions.filter("prediction == 1.0").count();
>         Long trueNegatives = correctPredictions.filter("prediction == 0.0").count();
>         Long falsePositives = incorrectPredictions.filter("prediction == 1.0").count();
>         Long falseNegatives = incorrectPredictions.filter("prediction == 0.0").count();
>         return getInfo(null, truePositives, trueNegatives, falsePositives, falseNegatives);
>     }
>     static String getInfo(Long startTime, List<Double> labels, List<Double> predictions) {
>         Long endTime = System.currentTimeMillis();
>         if (labels.size() != predictions.size()) {
>             throw new RuntimeException("labels size is " + labels.size() +
>                     " but predictions size is " + predictions.size());
>         }
>         Long truePositives = 0L;
>         Long trueNegatives = 0L;
>         Long falsePositives = 0L;
>         Long falseNegatives = 0L;
>         for (int i = 0; i < labels.size(); i++) {
>             double label = labels.get(i);
>             double prediction = predictions.get(i);
>             if (label == prediction) {
>                 if (prediction == 1.0) {
>                     truePositives += 1;
>                 } else {
>                     trueNegatives += 1;
>                 }
>             } else {
>                 if (prediction == 1.0) {
>                     falsePositives += 1;
>                 } else {
>                     falseNegatives += 1;
>                 }
>             }
>         }
>         return getInfo(endTime - startTime, truePositives, trueNegatives, falsePositives, falseNegatives);
>     }
>     static double ratio(Long numerator, Long denominator) {
>         if (numerator == 0 || denominator == 0) {
>             return 0;
>         }
>         return ((double) numerator) / denominator;
>     }
>     static String getInfo(Long timeTakenMilliseconds, Long truePositives, Long trueNegatives,
>                           Long falsePositives, Long falseNegatives) {
>         Long testDataCount = truePositives + trueNegatives + falsePositives + falseNegatives;
>         double accuracy = ratio(truePositives + trueNegatives, testDataCount);
>         double precision = ratio(truePositives, truePositives + falsePositives);
>         double recall = ratio(truePositives, truePositives + falseNegatives);
>         String last = "";
>         if (timeTakenMilliseconds != null) {
>             last = ", Average time taken (ms) " + ratio(timeTakenMilliseconds, testDataCount);
>         }
>         return (
>                 "true positives " + truePositives
>                 + ", true negatives " + trueNegatives
>                 + ", false positives " + falsePositives
>                 + ", false negatives " + falseNegatives
>                 + ", total " + testDataCount
>                 + "\n\t accuracy " + accuracy
>                 + ", precision " + precision
>                 + ", recall " + recall
>                 + last
>         );
>     }
>     static DecisionTreeModel convertDecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel model,
>                                                       Enumeration.Value algo) {
>         return new DecisionTreeModel(model.rootNode().toOld(1), algo);
>     }
>     static RandomForestModel convertRandomForestModel(org.apache.spark.ml.tree.TreeEnsembleModel model) {
>         Enumeration.Value algo;
>         if (model instanceof RandomForestRegressionModel) {
>             algo = Algo.Regression();
>         } else {
>             algo = Algo.Classification();
>         }
>         Object[] decisionTreeModels = model.trees();
>         DecisionTreeModel[] convertedDecisionTreeModels = new DecisionTreeModel[decisionTreeModels.length];
>         for (int i = 0; i < decisionTreeModels.length; i++) {
>             org.apache.spark.ml.tree.DecisionTreeModel originalModel =
>                     (org.apache.spark.ml.tree.DecisionTreeModel) decisionTreeModels[i];
>             DecisionTreeModel convertedModel = convertDecisionTreeModel(originalModel, algo);
>             convertedDecisionTreeModels[i] = convertedModel;
>         }
>         RandomForestModel result = new RandomForestModel(algo, convertedDecisionTreeModels);
>         return result;
>     }
> }
> {noformat}
> The output is shown below. "Original" refers to the ml package version and "New" refers to the mllib package version.
> - I converted the ml Decision Tree to an mllib Decision Tree and gave both versions the same input. I received exactly the same output.
> - Then I converted the ml Random Forest to an mllib Random Forest, giving both the same underlying Decision Trees (using the previous conversion method). I gave both versions the same input but received different output.
> {noformat}
> ****** DecisionTreeClassifier
> ** Original **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
> accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709
> ** New **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
> accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709, Average time taken (ms) 0.001558572146807441
> ****** RandomForestClassifier
> ** Original **true positives 3940, true negatives 5915, false positives 3950, false negatives 6085, total 19890
> accuracy 0.49547511312217196, precision 0.49936628643852976, recall 0.39301745635910224
> ** New **true positives 2461, true negatives 7350, false positives 2515, false negatives 7564, total 19890
> accuracy 0.4932629462041227, precision 0.4945739549839228, recall 0.2454862842892768, Average time taken (ms) 0.01085972850678733
> {noformat}
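A plausible explanation for the random forest mismatch, consistent with the "Not A Problem" resolution, is that the two packages combine the trees differently. Assuming the Spark 2.x implementations, the ml RandomForestClassificationModel adds up each tree's normalized leaf class distribution and predicts the argmax (soft voting), while the mllib RandomForestModel gives each tree one vote for its own predicted label (hard, majority voting). Identical trees can therefore still yield different forest predictions. The self-contained sketch below illustrates this on made-up leaf class counts. VotingDemo, hardVote and softVote are hypothetical names, not Spark APIs, and breaking ties toward the lowest class index is a simplification.

```java
import java.util.Arrays;

// Hypothetical illustration, not Spark code: two ways of combining
// the leaf class counts that one input reaches in each tree of a forest.
public class VotingDemo {

    // Index of the largest value; ties go to the lowest index (a simplification).
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Hard (majority) voting: each tree casts one vote for its own argmax class.
    static double hardVote(double[][] leafClassCounts, int numClasses) {
        double[] votes = new double[numClasses];
        for (double[] counts : leafClassCounts) {
            votes[argmax(counts)] += 1.0;
        }
        return argmax(votes);
    }

    // Soft voting: each tree contributes its normalized class distribution.
    static double softVote(double[][] leafClassCounts, int numClasses) {
        double[] votes = new double[numClasses];
        for (double[] counts : leafClassCounts) {
            double total = Arrays.stream(counts).sum();
            if (total == 0) {
                continue; // an empty leaf contributes nothing
            }
            for (int c = 0; c < numClasses; c++) {
                votes[c] += counts[c] / total;
            }
        }
        return argmax(votes);
    }

    public static void main(String[] args) {
        // Made-up leaf class counts [class 0, class 1] reached by one input in three trees.
        double[][] leafClassCounts = {
            {6, 4},  // tree 1 predicts class 0, P(class 1) = 0.4
            {6, 4},  // tree 2 predicts class 0, P(class 1) = 0.4
            {0, 10}  // tree 3 predicts class 1, P(class 1) = 1.0
        };
        System.out.println("hard vote: " + hardVote(leafClassCounts, 2));
        System.out.println("soft vote: " + softVote(leafClassCounts, 2));
    }
}
```

Here the hard vote returns 0.0 (two of the three trees predict class 0), while the soft vote returns 1.0 (summed distributions 1.2 for class 0 versus 1.8 for class 1), which is the same kind of disagreement seen in the RandomForestClassifier output above.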
--
This message was sent by Atlassian JIRA
(v6.3.15#6346)