[ https://issues.apache.org/jira/browse/FLINK-1807?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14526610#comment-14526610 ]
ASF GitHub Bot commented on FLINK-1807:

Github user tillrohrmann commented on a diff in the pull request:
https://github.com/apache/flink/pull/613#discussion_r29583038
--- Diff: flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala ---

@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+import org.apache.flink.ml.common.{WeightVector, LabeledVector}
+import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
+
+
+abstract class LossFunction extends Serializable {
+
+
+ /** Calculates the loss for a given prediction/truth pair
+ *
+ * @param prediction The predicted value
+ * @param truth The true value
+ */
+ protected def loss(prediction: Double, truth: Double): Double
+
+ /** Calculates the derivative of the loss function with respect to the prediction
+ *
+ * @param prediction The predicted value
+ * @param truth The true value
+ */
+ protected def lossDerivative(prediction: Double, truth: Double): Double
+
+ /** Compute the gradient and the loss for the given data.
+ * The provided cumGradient is updated in place.
+ *
+ * @param example The features and the label associated with the example
+ * @param weights The current weight vector
+ * @param cumGradient The vector to which the gradient will be added, in place.
+ * @return A tuple containing the computed loss as its first element and the loss
+ * derivative as its second element.
+ */
+ def lossAndGradient(
+ example: LabeledVector,
+ weights: WeightVector,
+ cumGradient: FlinkVector,
+ regType: RegularizationType,
+ regParameter: Double): (Double, Double) = {
+ val features = example.vector
+ val label = example.label
+ // TODO(tvas): We could also provide for the case where we don't want an intercept value
+ // i.e. data already centered
+ val prediction = BLAS.dot(features, weights.weights) + weights.intercept
+ val lossValue: Double = loss(prediction, label)
+ // The loss derivative is used to update the intercept
+ val lossDeriv = lossDerivative(prediction, label)
+ BLAS.axpy(lossDeriv, features, cumGradient)
+ val adjustedLoss = {
+ regType match {
+ case x: DiffRegularizationType => {
+ x.regularizedLossAndGradient(lossValue, weights.weights, cumGradient, regParameter)
--- End diff ---

We don't have to calculate the regularization gradient for every example, since it only
depends on the old weight vector. Thus, it would be more efficient to calculate it once
per iteration, in the weight update step.
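A minimal sketch of the suggested restructuring, assuming squared loss, L2 regularization (whose gradient is `regParameter * w`), no intercept, and plain `Array[Double]` in place of Flink's `FlinkVector`/`WeightVector`. `RegOnceSketch`, `addLossGradient` and `updateWeights` are hypothetical names, not Flink ML API. The per-example function only accumulates the data-dependent gradient; the regularization term is added once when the weights are updated:

```scala
// Hypothetical, simplified stand-ins: plain arrays instead of Flink's
// FlinkVector/WeightVector, squared loss, L2 regularization, no intercept.
object RegOnceSketch {
  type Vec = Array[Double]

  def dot(a: Vec, b: Vec): Double = a.indices.map(i => a(i) * b(i)).sum

  // Per-example work: only the data-dependent loss gradient, accumulated
  // into cumGradient in place (the role BLAS.axpy plays in the diff).
  def addLossGradient(features: Vec, label: Double, w: Vec, cumGradient: Vec): Double = {
    val prediction = dot(features, w)
    val lossDeriv = prediction - label // derivative of 0.5 * (p - y)^2
    for (i <- w.indices) cumGradient(i) += lossDeriv * features(i)
    0.5 * (prediction - label) * (prediction - label)
  }

  // Update step: the L2 gradient (regParameter * w) depends only on the old
  // weights, so it is added once per update instead of once per example.
  def updateWeights(data: Seq[(Vec, Double)], w: Vec,
                    stepSize: Double, regParameter: Double): Vec = {
    val cumGradient = new Array[Double](w.length)
    data.foreach { case (x, y) => addLossGradient(x, y, w, cumGradient) }
    val n = data.size.toDouble
    Array.tabulate(w.length) { i =>
      val grad = cumGradient(i) / n + regParameter * w(i)
      w(i) - stepSize * grad
    }
  }
}
```

With this split, the regularization cost per iteration is one pass over the weight vector, independent of the number of examples.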
> Stochastic gradient descent optimizer for ML library
> 
>
> Key: FLINK-1807
> URL: https://issues.apache.org/jira/browse/FLINK-1807
> Project: Flink
> Issue Type: Improvement
> Components: Machine Learning Library
> Reporter: Till Rohrmann
> Assignee: Theodore Vasiloudis
> Labels: ML
>
> Stochastic gradient descent (SGD) is a widely used optimization technique in different
> ML algorithms. Thus, it would be helpful to provide a generalized SGD implementation which
> can be instantiated with the respective gradient computation. Such a building block would
> make the development of future algorithms easier.
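A hypothetical sketch of such a building block: the SGD loop is written once against a pointwise loss interface (mirroring the `loss`/`lossDerivative` pair in the `LossFunction` diff above), and each algorithm plugs in its own loss. `PointwiseLoss`, `SquaredLoss`, `HingeLoss` and `SimpleSGD` are illustrative names, not Flink ML classes; the sketch uses plain arrays, updates per example in a fixed order, and omits the intercept and regularization:

```scala
// Hypothetical names; not part of Flink ML.
trait PointwiseLoss extends Serializable {
  def loss(prediction: Double, truth: Double): Double
  def lossDerivative(prediction: Double, truth: Double): Double
}

object SquaredLoss extends PointwiseLoss {
  def loss(p: Double, y: Double): Double = 0.5 * (p - y) * (p - y)
  def lossDerivative(p: Double, y: Double): Double = p - y
}

// Hinge loss for labels in {-1, +1}; the derivative is a subgradient.
object HingeLoss extends PointwiseLoss {
  def loss(p: Double, y: Double): Double = math.max(0.0, 1.0 - y * p)
  def lossDerivative(p: Double, y: Double): Double = if (y * p < 1.0) -y else 0.0
}

object SimpleSGD {
  def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  // One weight update per example, using the chain rule for a linear model:
  // dL/dw = lossDerivative(prediction, label) * features.
  def optimize(data: Seq[(Array[Double], Double)], lossFn: PointwiseLoss,
               numEpochs: Int, stepSize: Double): Array[Double] = {
    val w = new Array[Double](data.head._1.length)
    for (_ <- 1 to numEpochs; (x, y) <- data) {
      val g = lossFn.lossDerivative(dot(x, w), y)
      for (i <- w.indices) w(i) -= stepSize * g * x(i)
    }
    w
  }
}
```

Swapping `SquaredLoss` for `HingeLoss` turns the same loop from least-squares regression into a linear SVM-style classifier, which is the kind of reuse the issue describes.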

This message was sent by Atlassian JIRA
(v6.3.4#6332)
