hama-commits mailing list archives

From tomm...@apache.org
Subject svn commit: r1395161 - /hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Date Sat, 06 Oct 2012 19:51:17 GMT
Author: tommaso
Date: Sat Oct  6 19:51:17 2012
New Revision: 1395161

URL: http://svn.apache.org/viewvc?rev=1395161&view=rev
Log:
[HAMA-651] - added calculateCostForItem method since the cost function also depends on the specific algorithm used

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
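
For context, a minimal sketch of what an extending class could look like after this change. It is illustrative only: the class name LogisticRegressionBSP is hypothetical, the DoubleVector import path is assumed from the ml module layout, and the cross-entropy expression mirrors the inline formula removed from the first superstep below (negated, on the assumption that the aggregation step no longer flips the sign).

    import org.apache.hama.ml.math.DoubleVector;

    // Hypothetical subclass, not part of this commit: logistic regression
    // expressed through the two abstract methods of GradientDescentBSP.
    public class LogisticRegressionBSP extends GradientDescentBSP {

      @Override
      public double applyHypothesis(DoubleVector theta, DoubleVector x) {
        // sigmoid of the linear combination theta . x, built from the
        // get()/getLength() accessors that GradientDescentBSP itself uses
        double dot = 0d;
        for (int j = 0; j < theta.getLength(); j++) {
          dot += theta.get(j) * x.get(j);
        }
        return 1d / (1d + Math.exp(-dot));
      }

      @Override
      protected double calculateCostForItem(double y, DoubleVector x, DoubleVector theta) {
        // per-item cross-entropy cost; the leading minus keeps the aggregate
        // cost positive now that the caller divides by numRead without negating
        double h = applyHypothesis(theta, x);
        return -(y * Math.log(h) + (1 - y) * Math.log(1 - h));
      }
    }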

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395161&r1=1395160&r2=1395161&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sat Oct  6 19:51:17 2012
@@ -18,7 +18,6 @@
 package org.apache.hama.ml.regression;
 
 import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.NullWritable;
 import org.apache.hama.bsp.BSP;
 import org.apache.hama.bsp.BSPPeer;
 import org.apache.hama.bsp.sync.SyncException;
@@ -33,7 +32,7 @@ import java.io.IOException;
 
 /**
  * A gradient descent (see <code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based abstract implementation.
- * Each extending class should implement the #hypothesis(DoubleVector theta, DoubleVector x) method for a specific
+ * Each extending class should implement the #applyHypothesis(DoubleVector theta, DoubleVector x) method for a specific
  */
 public abstract class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
 
@@ -69,7 +68,7 @@ public abstract class GradientDescentBSP
         // calculate cost for given input
         double y = kvp.getValue().get();
         DoubleVector x = kvp.getKey().getVector();
-        double costForX = y * Math.log(hypothesis(theta, x)) + (1 - y) * Math.log(1 - hypothesis(theta, x));
+        double costForX = calculateCostForItem(y, x, theta);
 
         // adds to local cost
         localCost += costForX;
@@ -84,7 +83,7 @@ public abstract class GradientDescentBSP
       }
       peer.sync();
 
-      // second superstep : cost calculation
+      // second superstep : aggregate cost calculation
 
       VectorWritable costResult;
       while ((costResult = peer.getCurrentMessage()) != null) {
@@ -92,7 +91,8 @@ public abstract class GradientDescentBSP
         numRead += costResult.getVector().get(1);
       }
 
-      totalCost = totalCost * (-1 / numRead);
+      totalCost /= numRead;
+
       if (log.isInfoEnabled()) {
         log.info("cost is " + totalCost);
       }
@@ -103,11 +103,11 @@ public abstract class GradientDescentBSP
 
       double[] thetaDelta = new double[theta.getLength()];
 
-      // second superstep : calculate partial derivatives in parallel
+      // third superstep : calculate partial derivatives' deltas in parallel
       while ((kvp = peer.readNext()) != null) {
         DoubleVector x = kvp.getKey().getVector();
         double y = kvp.getValue().get();
-        double difference = hypothesis(theta, x) - y;
+        double difference = applyHypothesis(theta, x) - y;
         for (int j = 0; j < theta.getLength(); j++) {
           thetaDelta[j] += difference * x.get(j);
         }
@@ -120,6 +120,7 @@ public abstract class GradientDescentBSP
 
       peer.sync();
 
+      // fourth superstep : aggregate partial derivatives
       VectorWritable thetaDeltaSlice;
       while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
         double[] newTheta = new double[theta.getLength()];
@@ -154,13 +155,22 @@ public abstract class GradientDescentBSP
   }
 
   /**
-   * Applies the hypothesis given a set of parameters theta to a given input x
+   * Calculates the cost function for a given item (input x, output y)
+   * @param y the expected output for x
+   * @param x the input vector
+   * @param theta the parameters vector theta
+   * @return the calculated cost for input x and output y
+   */
+  protected abstract double calculateCostForItem(double y, DoubleVector x, DoubleVector theta);
+
+  /**
+   * Applies the hypothesis given a set of parameters theta to a given input x
    *
    * @param theta the parameters vector
    * @param x     the input
    * @return a <code>double</code> number
    */
-  public abstract double hypothesis(DoubleVector theta, DoubleVector x);
+  public abstract double applyHypothesis(DoubleVector theta, DoubleVector x);
 
 
   public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {


