mahout-user mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From damodar shetyo <akshay.she...@gmail.com>
Subject Continued: simple OnlineLogisticRegression classification example using Mahout
Date Thu, 28 Jun 2012 13:59:19 GMT
This post is a continuation of another mailing thread that is going on. Sorry
for creating a new thread, but I was not receiving mails from the group before.

The following code was written by Ted Dunning. Now I have a few questions:

1) The point (x, y) has 2 dimensions. Why are we using 3 instead of 2
when creating the DenseVector?
  Vector v = new DenseVector(3);   // why 3, why not 2?

2) In the getVector method, why do we call v.set(2, 1)?

3) What is the purpose of setting lambda?

4) What happens if I increase or decrease the learning rate?

I have read the book "Mahout in Action" but am still not able to understand
what these two parameters are for.


import com.google.common.collect.Lists;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ClassifierExample {

    public static class Point {
        public int x;
        public int y;

        public Point(int x, int y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point) arg0;
            return ((this.x == p.x) && (this.y == p.y));
        }

        @Override
        public String toString() {
            // TODO Auto-generated method stub
            return this.x + " , " + this.y;
        }
    }

    public static void main(String[] args) {

        Map<Point, Integer> points = new HashMap<Point,
                Integer>();

        points.put(new Point(0, 0), 0);
        points.put(new Point(1, 1), 0);
        points.put(new Point(1, 0), 0);
        points.put(new Point(0, 1), 0);
        points.put(new Point(2, 2), 0);


        points.put(new Point(8, 8), 1);
        points.put(new Point(8, 9), 1);
        points.put(new Point(9, 8), 1);
        points.put(new Point(9, 9), 1);


        OnlineLogisticRegression learningAlgo = new
OnlineLogisticRegression(2, 3, new L1());
        // this is a really big value which will make the model very
cautious
        // for lambda = 0.1, the first example below should be about .83
certain
        // for lambda = 0.01, the first example below should be about 0.98
certain

        learningAlgo.lambda(0.1);
        learningAlgo.learningRate(4);

        System.out.println("training model  \n");
        final List<Point> keys = Lists.newArrayList(points.keySet());
        // 200 times through the training data is probably over-kill.
 Itdoesn't matter
        // for tiny data.  The key here is total number of points seen, not
number of passes.

        for (int i = 0; i < 200; i++) {
            // randomize training data on each iteration
            Collections.shuffle(keys);
            for (Point point : keys) {
                Vector v = getVector(point);
                learningAlgo.train(points.get(point), v);
            }
        }
        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(3);
        v.set(0, 0.5);
        v.set(1, 0.5);
        v.set(2, 1);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = ");
        System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
        System.out.printf("no of features = %d\n",
learningAlgo.numFeatures());
        System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
        System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));

        v.set(0, 4.5);
        v.set(1, 6.5);
        v.set(2, 1);

        r = learningAlgo.classifyFull(v);

        System.out.println("ans = ");
        System.out.printf("no of categories = %d\n",
learningAlgo.numCategories());
        System.out.printf("no of features =
%d\n",learningAlgo.numFeatures());
        System.out.printf("Probability of cluster 0 = %.3f\n", r.get(0));
        System.out.printf("Probability of cluster 1 = %.3f\n", r.get(1));

        // show how the score varies along a line from 0,0 to 1,1
        System.out.printf("\nx\tscore\n");
        for (int i = 0; i < 100; i++) {
            final double x = 0.0 + i / 10.0;
            v.set(0, x);
            v.set(1, x);
            v.set(2, 1);

            r = learningAlgo.classifyFull(v);

            System.out.printf("%.2f\t%.3f\n", x, r.get(1));
        }

    }

    public static Vector getVector(Point point) {
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1);

        return v;
    }
}


-- 
Regards,
Damodar Shetyo

Mime
  • Unnamed multipart/alternative (inline, None, 0 bytes)
View raw message