hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1395024 - /hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Date Sat, 06 Oct 2012 12:26:05 GMT
Author: tommaso
Date: Sat Oct  6 12:26:04 2012
New Revision: 1395024

URL: http://svn.apache.org/viewvc?rev=1395024&view=rev
Log:
[HAMA-651] - add collection of cost and theta as output

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395024&r1=1395023&r2=1395024&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sat Oct  6 12:26:04 2012
@@ -35,7 +35,7 @@ import java.io.IOException;
  * A gradient descent (see <code>http://en.wikipedia.org/wiki/Gradient_descent</code>)
BSP based abstract implementation.
  * Each extending class should implement the #hypothesis(DoubleVector theta, DoubleVector
x) method for a specific
  */
-public abstract class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable, NullWritable,
NullWritable, VectorWritable> {
+public abstract class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> {
 
   private static final Logger log = LoggerFactory.getLogger(GradientDescentBSP.class);
   static final String INITIAL_THETA_VALUES = "initial.theta.values";
@@ -45,12 +45,12 @@ public abstract class GradientDescentBSP
   private DoubleVector theta;
 
   @Override
-  public void setup(BSPPeer<VectorWritable, DoubleWritable, NullWritable, NullWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
+  public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
     master = peer.getPeerIndex() == peer.getNumPeers() / 2;
   }
 
   @Override
-  public void bsp(BSPPeer<VectorWritable, DoubleWritable, NullWritable, NullWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
+  public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
 
     while (true) {
 
@@ -137,6 +137,10 @@ public abstract class GradientDescentBSP
         if (log.isInfoEnabled()) {
           log.info("new theta for cost " + totalCost + " is " + theta.toArray().toString());
         }
+        // master writes down the output
+        if (master) {
+          peer.write(new VectorWritable(theta), new DoubleWritable(totalCost));
+        }
       }
       peer.sync();
 
@@ -159,7 +163,7 @@ public abstract class GradientDescentBSP
   public abstract double hypothesis(DoubleVector theta, DoubleVector x);
 
 
-  public void getTheta(BSPPeer<VectorWritable, DoubleWritable, NullWritable, NullWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
+  public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException, SyncException, InterruptedException {
     if (master && theta == null) {
       int size = getXSize(peer);
       theta = new DenseDoubleVector(size, peer.getConfiguration().getInt(INITIAL_THETA_VALUES,
10));
@@ -174,7 +178,7 @@ public abstract class GradientDescentBSP
     }
   }
 
-  private int getXSize(BSPPeer<VectorWritable, DoubleWritable, NullWritable, NullWritable,
VectorWritable> peer) throws IOException {
+  private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable,
VectorWritable> peer) throws IOException {
     VectorWritable key = null;
     peer.readNext(key, null);
     peer.reopenInput(); // reset input to start



Mime
View raw message