spark-user mailing list archives

From Adamantios Corais <adamantios.cor...@gmail.com>
Subject Re: Grid Search using Spark MLLib Pipelines
Date Fri, 12 Aug 2016 18:24:13 GMT
Great.

I like your second solution. But how can I make sure that cvModel holds 
the best model overall (as opposed to the last one that was tried out 
by the grid search)?

In addition, do you have an idea how to collect the average error of 
each parameter combination in the grid (here 1x1x1)?
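
For reference, here is a minimal sketch of one way to look at both questions. It assumes that `CrossValidatorModel` exposes `avgMetrics` in the same order as `getEstimatorParamMaps`, and that `bestModel` is the model refit with the best-scoring combination. The helper names `gridResults` and `bestIndex` are made up for illustration. Note that the values are evaluator metrics (areaUnderROC by default for BinaryClassificationEvaluator, where larger is better), not error rates.

```scala
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.CrossValidatorModel

// Hypothetical helper: pair each parameter combination with the metric
// averaged over the folds (assumes both arrays share the same ordering).
def gridResults(model: CrossValidatorModel): Array[(ParamMap, Double)] =
  model.getEstimatorParamMaps.zip(model.avgMetrics)

// Hypothetical helper: position of the winning combination, given whether
// the evaluator treats larger metric values as better.
def bestIndex(metrics: Array[Double], largerIsBetter: Boolean): Int =
  if (largerIsBetter) metrics.indexOf(metrics.max)
  else metrics.indexOf(metrics.min)
```

With the cvModel from the quoted code, `gridResults(cvModel).foreach { case (params, metric) => println(s"$metric <- $params") }` would print one line per combination, and `bestIndex(cvModel.avgMetrics, largerIsBetter = true)` would give the position of the combination that bestModel should correspond to under the assumptions above.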



On 12/08/2016 08:59 PM, Bryan Cutler wrote:
> You will need to cast bestModel to include the MLWritable trait.  The 
> class Model does not mix it in by default.  For instance:
>
> cvModel.bestModel.asInstanceOf[MLWritable].save("/my/path")
>
> Alternatively, you could save the CV model directly, which takes care 
> of this
>
> cvModel.save("/my/path")
>
> On Fri, Aug 12, 2016 at 9:17 AM, Adamantios Corais 
> <adamantios.corais@gmail.com> wrote:
>
>     Hi,
>
>     Assuming that I have run the following pipeline and obtained the
>     best logistic regression model, how can I then save that model for
>     later use? The following command throws an error:
>
>     cvModel.bestModel.save("/my/path")
>
>     Also, is it possible to get a collection of the errors, one for each
>     combination of parameters?
>
>     I am using spark 1.6.2
>
>     import org.apache.spark.ml.Pipeline
>     import org.apache.spark.ml.classification.LogisticRegression
>     import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
>     import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
>
>     val lr = new LogisticRegression()
>
>     val pipeline = new Pipeline().
>         setStages(Array(lr))
>
>     val paramGrid = new ParamGridBuilder().
>         addGrid(lr.elasticNetParam, Array(0.1)).
>         addGrid(lr.maxIter, Array(10)).
>         addGrid(lr.regParam, Array(0.1)).
>         build()
>
>     val cv = new CrossValidator().
>         setEstimator(pipeline).
>         setEvaluator(new BinaryClassificationEvaluator).
>         setEstimatorParamMaps(paramGrid).
>         setNumFolds(2)
>
>     val cvModel = cv.
>         fit(training)
>
