spark-user mailing list archives

From AaronLee <...@wish.com.INVALID>
Subject Re: Spark ml how to extract split points from trained decision tree mode
Date Fri, 12 Jun 2020 05:16:07 GMT
@srowen, you are totally right: the model was not trained correctly. That is
strange, though, because the dataset I used has about 50M rows (48,174,858,
see `df2.count` below), a binary label with roughly 20% positives, and a
single feature in the feature vector. I don't understand why it was not
trained correctly.


```
scala> df2.count
res56: Long = 48174858

scala> df2.show
+--------------------+-----+
|            features|label|
+--------------------+-----+
|              [14.0]|  1.0|
|               [2.0]|  0.0|
|               [2.0]|  0.0|
|               [1.0]|  1.0|
|[0.9700000286102295]|  1.0|
|[1.9600000381469727]|  0.0|
|[0.9900000095367432]|  0.0|
|[11.739999771118164]|  1.0|
|               [1.0]|  0.0|
|[0.9800000190734863]|  0.0|
|               [5.0]|  0.0|
| [5.940000057220459]|  1.0|
|              [11.0]|  0.0|
|               [4.0]|  0.0|
|               [1.0]|  1.0|
|[1.9700000286102295]|  0.0|
| [6.989999771118164]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9700000286102295]|  0.0|
|[0.9900000095367432]|  0.0|
+--------------------+-----+
only showing top 20 rows


scala> df2.printSchema
root
 |-- features: vector (nullable = true)
 |-- label: double (nullable = true)

scala> val dt = new DecisionTreeClassifier().setLabelCol("label").setFeaturesCol("features").setMaxBins(10)
dt: org.apache.spark.ml.classification.DecisionTreeClassifier = dtc_2b6b6e170840

scala>  val dtm = dt.fit(df2)
dtm: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_2b6b6e170840) of depth 0 with 1 nodes

scala> val df3 = dtm.transform(df2)
df3: org.apache.spark.sql.DataFrame = [features: vector, label: double ... 3 more fields]

scala>  df3.show(100,false)
+--------------------+-----+----------------------+----------------------------------------+----------+
|features            |label|rawPrediction         |probability                             |prediction|
+--------------------+-----+----------------------+----------------------------------------+----------+
|[14.0]              |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[2.0]               |0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.0]               |1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9700000286102295]|1.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[1.9600000381469727]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
|[0.9900000095367432]|0.0  |[3.872715E7,9447708.0]|[0.8038871645454565,0.19611283545454353]|0.0       |
....
```
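The numbers in the transcript are consistent with a single-leaf tree. In Spark ML's `DecisionTreeClassificationModel`, `rawPrediction` is the vector of per-class training counts at the leaf a row reaches, and `probability` is that vector normalized to sum to 1. A minimal plain-Scala sketch (values copied from the output above; `RootLeafCheck` and `normalize` are hypothetical names) checks that the two counts add up exactly to `df2.count`, so every training row ended up in the root node, and that class 0.0 holds about 80.4% of them, which is why every row is predicted as 0.0.

```scala
object RootLeafCheck {
  // Copied from the spark-shell output: df2.count
  val rowCount: Long = 48174858L
  // Copied from the rawPrediction column: per-class counts at the single (root) leaf
  val rawPrediction: Array[Double] = Array(3.872715e7, 9447708.0)

  /** Normalizes raw class counts into a probability vector that sums to 1. */
  def normalize(raw: Array[Double]): Array[Double] = {
    val total = raw.sum
    raw.map(_ / total)
  }

  def main(args: Array[String]): Unit = {
    // If the counts sum to the row count, all rows fell into the same leaf.
    val total = rawPrediction.sum
    println(s"sum of leaf counts = ${total.toLong}, df2.count = $rowCount")
    // Should closely match the probability column shown in df3.show.
    val probs = normalize(rawPrediction)
    println(probs.mkString("[", ",", "]"))
    // The majority class at the leaf becomes the prediction for every row.
    println(s"predicted class = ${probs.indexOf(probs.max).toDouble}")
  }
}
```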




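On the question in the subject line: once a model does contain splits, the thresholds can be read by walking `dtm.rootNode`. A minimal sketch, assuming the public node and split classes in `org.apache.spark.ml.tree` (Spark 2.x/3.x); `TreeSplits`, `continuousSplits` and `describe` are hypothetical helper names. For a continuous split, rows with `features(featureIndex) <= threshold` go to the left child.

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, InternalNode, LeafNode, Node}

object TreeSplits {
  /** Collects (featureIndex, threshold) for every continuous split, in pre-order. */
  def continuousSplits(node: Node): Seq[(Int, Double)] = node match {
    case n: InternalNode =>
      val here: Seq[(Int, Double)] = n.split match {
        case s: ContinuousSplit  => Seq((s.featureIndex, s.threshold))
        case _: CategoricalSplit => Seq.empty // categorical splits have category sets, not a threshold
      }
      here ++ continuousSplits(n.leftChild) ++ continuousSplits(n.rightChild)
    case _: LeafNode => Seq.empty // leaves carry a prediction but no split
  }

  /** Prints the tree size and each learned threshold. */
  def describe(model: DecisionTreeClassificationModel): Unit = {
    println(s"depth=${model.depth}, numNodes=${model.numNodes}")
    continuousSplits(model.rootNode).foreach { case (f, t) =>
      println(s"feature $f <= $t")
    }
  }
}
```

For the depth-0 model above, `TreeSplits.describe(dtm)` would report one node and no thresholds, because the root is a `LeafNode`. `dtm.toDebugString` prints the same tree structure as text.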
--
Sent from: http://apache-spark-user-list.1001560.n3.nabble.com/


