flink-issues mailing list archives

From "ASF GitHub Bot (JIRA)" <j...@apache.org>
Subject [jira] [Commented] (FLINK-1723) Add cross validation for model evaluation
Date Wed, 08 Jul 2015 13:46:05 GMT

    [ https://issues.apache.org/jira/browse/FLINK-1723?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14618606#comment-14618606 ]

ASF GitHub Bot commented on FLINK-1723:
---------------------------------------

Github user thvasilo commented on a diff in the pull request:

    https://github.com/apache/flink/pull/891#discussion_r34148455
  
    --- Diff: flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/evaluation/CrossValidationITSuite.scala ---
    @@ -0,0 +1,123 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *     http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package org.apache.flink.ml.evaluation
    +
    +import org.apache.flink.api.scala._
    +import org.apache.flink.ml.common.ParameterMap
    +import org.apache.flink.ml.preprocessing.StandardScaler
    +import org.apache.flink.ml.regression.RegressionData._
    +import org.apache.flink.ml.regression.{MultipleLinearRegression, RegressionData}
    +import org.apache.flink.test.util.FlinkTestBase
    +
    +import org.scalatest.{FlatSpec, Matchers}
    +
    +class CrossValidationITSuite extends FlatSpec with Matchers with FlinkTestBase {
    +  behavior of "the cross-validation suite"
    +
    +  it should "be able to split the input into K folds" in {
    +    // Original code from the Apache Spark project
    +    val env = ExecutionEnvironment.getExecutionEnvironment
    +
    +    val data = env.fromCollection(1 to 100)
    +    val collectedData = data.collect().sorted
    +
    +    val twoFolds = KFold(2).folds(data, 42L)
    +    twoFolds(0)._1.collect().sorted shouldEqual twoFolds(1)._2.collect().sorted
    +    twoFolds(0)._2.collect().sorted shouldEqual twoFolds(1)._1.collect().sorted
    +
    +    for (folds <- 2 to 10) {
    +      for (seed <- 1 to 5) {
    +        val foldedDataSets = KFold(folds).folds(data, seed)
    +        foldedDataSets.length shouldEqual folds
    +
    +        foldedDataSets.foreach { case (training, testing) =>
    +          val result = testing.union(training).collect().sorted
    +          val testingSize = testing.collect().size.toDouble
    +          testingSize should be > 0.0
    +
    +          // Within 4 standard deviations of the mean
    +          val p = 1 / folds.toDouble
    +          val range = 4 * math.sqrt(100 * p * (1 - p))
    +          val expected = 100 * p
    +          val lowerBound = expected - range
    +          val upperBound = expected + range
    +          // Ensure size of test data is within expected bounds
    +          testingSize should be > lowerBound
    +          testingSize should be < upperBound
    +          training.collect().size should be > 0
    +
    +          // The combined set should contain all data
    +          result shouldEqual collectedData
    +        }
    +        // K fold cross validation should only have each element in the validation set exactly once
    +        foldedDataSets.map(_._2).reduce((x, y) => x.union(y)).collect().sorted shouldEqual
    +          data.collect().sorted
    +      }
    +    }
    +  }
    +
    +  def fixture = new {
    +    val env = ExecutionEnvironment.getExecutionEnvironment
    +
    +    import RegressionData._
    +
    +
    +    val inputDS = env.fromCollection(data)
    +
    +    val mlr = MultipleLinearRegression()
    +      .setStepsize(10.0)
    +      .setIterations(100)
    +
    +    println()
    --- End diff --
    
    It prints a blank line between consecutive test runs; I only have it there so I can more easily see what is happening. The tests don't really do anything yet, they just print results.
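    The split test in the diff checks two properties: each fold's test set size stays within four standard deviations of its expected value, and every element lands in exactly one test set. The sketch below reproduces that logic with plain Scala collections instead of Flink `DataSet`s. It assumes each element is assigned to a fold uniformly at random, so one fold's test size is Binomial(n, 1/k) with mean n*p and standard deviation sqrt(n*p*(1-p)); the `KFold` implementation under review is not shown here, and `kFold` and `fourSigmaBounds` are hypothetical helpers, not part of FlinkML.

    ```scala
    import scala.util.Random

    object KFoldSketch {
      /** Hypothetical helper (not the PR's KFold class): assigns each element to one
        * of `k` folds uniformly at random and returns one (training, testing) pair
        * per fold. Every element appears in exactly one testing set. */
      def kFold[T](data: Seq[T], k: Int, seed: Long): Seq[(Seq[T], Seq[T])] = {
        require(k >= 2, "k must be at least 2")
        val rng = new Random(seed)
        // Draw one fold index per element.
        val assigned = data.map(x => (x, rng.nextInt(k)))
        (0 until k).map { fold =>
          // Elements drawn for this fold form the test set, the rest the training set.
          val (test, train) = assigned.partition(_._2 == fold)
          (train.map(_._1), test.map(_._1))
        }
      }

      /** Bounds used by the test: mean n*p plus/minus 4 * sqrt(n*p*(1-p)), p = 1/k. */
      def fourSigmaBounds(n: Int, k: Int): (Double, Double) = {
        val p = 1.0 / k
        val mean = n * p
        val range = 4 * math.sqrt(n * p * (1 - p))
        (mean - range, mean + range)
      }

      def main(args: Array[String]): Unit = {
        val data = 1 to 100
        val folds = kFold(data, k = 5, seed = 42L)
        val (lower, upper) = fourSigmaBounds(data.size, 5)
        folds.foreach { case (train, test) =>
          println(s"train=${train.size} test=${test.size} bounds=($lower, $upper)")
        }
      }
    }
    ```

    With n = 100 and k = 5 the bounds work out to 20 plus/minus 16, i.e. (4, 36), which is loose enough that a correct random split essentially never fails the check.
    
    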


> Add cross validation for model evaluation
> -----------------------------------------
>
>                 Key: FLINK-1723
>                 URL: https://issues.apache.org/jira/browse/FLINK-1723
>             Project: Flink
>          Issue Type: New Feature
>          Components: Machine Learning Library
>            Reporter: Till Rohrmann
>            Assignee: Theodore Vasiloudis
>              Labels: ML
>
> Cross validation [1] is a standard tool to estimate the test error for a model. As such it is a crucial tool for every machine learning library.
> The cross validation should work with arbitrary Estimators and error metrics. The first cross validation strategy it should support is k-fold cross validation.
> Resources:
> [1] [http://en.wikipedia.org/wiki/Cross-validation]
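The issue asks for cross validation that works with arbitrary Estimators and error metrics. A minimal sketch of that shape, assuming a simplified model in which a "trainer" is any function from labeled training data to a predictor and a "metric" is any function over (truth, prediction) pairs. `Trainer`, `Metric`, `crossValidate` and `throughOrigin` are hypothetical names, not FlinkML's actual Estimator/Predictor API, and plain Scala collections stand in for Flink `DataSet`s. For determinism the sketch uses round-robin fold assignment (index mod k) instead of the random assignment exercised in the test above.

```scala
object CrossValidationSketch {
  // Hypothetical stand-in for an Estimator: builds a predictor from (x, y) pairs.
  type Trainer = Seq[(Double, Double)] => (Double => Double)
  // Hypothetical error metric over (truth, prediction) pairs.
  type Metric = Seq[(Double, Double)] => Double

  val meanSquaredError: Metric = pairs =>
    pairs.map { case (truth, pred) => (truth - pred) * (truth - pred) }.sum / pairs.size

  // Example trainer: least-squares line through the origin, slope = sum(xy) / sum(x^2).
  val throughOrigin: Trainer = train => {
    val slope = train.map { case (x, y) => x * y }.sum / train.map { case (x, _) => x * x }.sum
    x => slope * x
  }

  /** Returns one error score per fold: train on k-1 folds, evaluate on the held-out one. */
  def crossValidate(data: Seq[(Double, Double)], k: Int,
                    trainer: Trainer, metric: Metric): Seq[Double] = {
    require(k >= 2 && data.size >= k, "need k >= 2 and at least k points")
    val indexed = data.zipWithIndex
    (0 until k).map { fold =>
      // Round-robin assignment: point i belongs to fold i % k.
      val test = indexed.collect { case (p, i) if i % k == fold => p }
      val train = indexed.collect { case (p, i) if i % k != fold => p }
      val predict = trainer(train)
      metric(test.map { case (x, y) => (y, predict(x)) })
    }
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 10).map(i => (i.toDouble, 2.0 * i + (if (i % 2 == 0) 0.5 else -0.5)))
    val scores = crossValidate(data, 5, throughOrigin, meanSquaredError)
    println(s"per-fold MSE: $scores, mean: ${scores.sum / scores.size}")
  }
}
```

Because the trainer and metric are passed in as plain functions, swapping in a different model or error measure does not touch the fold logic, which is the decoupling the issue description asks for.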



--
This message was sent by Atlassian JIRA
(v6.3.4#6332)
