4c66afff91566c7adf9ccef148d9221a44c7f930
[IRC.git] / Robust / src / Benchmarks / SSJava / EyeTracking / ClassifierTree.java
1 /*
2  * Copyright 2009 (c) Florian Frankenberger (darkblue.de)
3  * 
4  * This file is part of LEA.
5  * 
6  * LEA is free software: you can redistribute it and/or modify it under the
7  * terms of the GNU Lesser General Public License as published by the Free
8  * Software Foundation, either version 3 of the License, or (at your option) any
9  * later version.
10  * 
11  * LEA is distributed in the hope that it will be useful, but WITHOUT ANY
12  * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13  * A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
14  * details.
15  * 
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with LEA. If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 import java.awt.RenderingHints;
21 import java.awt.color.ColorSpace;
22 import java.awt.geom.Rectangle2D;
23 import java.awt.image.BufferedImage;
24 import java.awt.image.BufferedImageOp;
25 import java.awt.image.ColorConvertOp;
26 import java.io.IOException;
27 import java.io.InputStream;
28 import java.io.InputStreamReader;
29 import java.io.OutputStream;
30 import java.io.OutputStreamWriter;
31 import java.io.PrintWriter;
32 import java.io.Reader;
33 import java.util.ArrayList;
34 import java.util.Collections;
35 import java.util.List;
36
37 /**
38  * 
39  * @author Florian
40  */
41 public class ClassifierTree {
42
43   private List<Classifier> classifiers;
44   private static XStream xStream = new XStream(new DomDriver());
45
46   static {
47     xStream.alias("ClassifierTree", ClassifierTree.class);
48     xStream.alias("Classifier", Classifier.class);
49     xStream.alias("ScanArea", ScanArea.class);
50   }
51
52   public ClassifierTree(List<Classifier> classifier) {
53     this.classifiers = new ArrayList<Classifier>(classifier);
54     Collections.sort(this.classifiers);
55   }
56
57   @Override
58   public String toString() {
59     StringBuilder sb = new StringBuilder();
60     sb.append("ClassifierTree {\n");
61     for (Classifier classifier : this.classifiers) {
62       sb.append(classifier.toString());
63       sb.append('\n');
64     }
65     sb.append("}\n");
66     return sb.toString();
67   }
68
69   public static BufferedImage resizeImageFittingInto(BufferedImage image, int dimension) {
70
71     int newHeight = 0;
72     int newWidth = 0;
73     float factor = 0;
74     if (image.getWidth() > image.getHeight()) {
75       factor = dimension / (float) image.getWidth();
76       newWidth = dimension;
77       newHeight = (int) (factor * image.getHeight());
78     } else {
79       factor = dimension / (float) image.getHeight();
80       newHeight = dimension;
81       newWidth = (int) (factor * image.getWidth());
82     }
83
84     if (factor > 1) {
85       BufferedImageOp op = new ColorConvertOp(ColorSpace.getInstance(ColorSpace.CS_GRAY), null);
86       BufferedImage tmpImage = op.filter(image, null);
87
88       return tmpImage;
89     }
90
91     BufferedImage resizedImage = new BufferedImage(newWidth, newHeight, BufferedImage.TYPE_INT_RGB);
92
93     Graphics2D g2D = resizedImage.createGraphics();
94     g2D.setRenderingHint(RenderingHints.KEY_INTERPOLATION,
95         RenderingHints.VALUE_INTERPOLATION_NEAREST_NEIGHBOR);
96
97     g2D.drawImage(image, 0, 0, newWidth - 1, newHeight - 1, 0, 0, image.getWidth() - 1,
98         image.getHeight() - 1, null);
99
100     BufferedImageOp op = new ColorConvertOp(ColorSpace.getInstance(ColorSpace.CS_GRAY), null);
101     BufferedImage tmpImage = op.filter(resizedImage, null);
102
103     return tmpImage;
104   }
105
106   /**
107    * Image should have 100x100px and should be in b/w
108    * 
109    * @param image
110    */
111   public void learn(BufferedImage image, boolean isFace) {
112     IntegralImageData imageData = new IntegralImageData(image);
113     for (Classifier classifier : this.classifiers) {
114       classifier.learn(imageData, isFace);
115     }
116   }
117
118   public int getLearnedFacesYes() {
119     return this.classifiers.get(0).getLearnedFacesYes();
120   }
121
122   public int getLearnedFacesNo() {
123     return this.classifiers.get(0).getLearnedFacesNo();
124   }
125
126   /**
127    * Locates a face by linear iteration through all probable face positions
128    * 
129    * @deprecated use locateFaceRadial instead for improved performance
130    * @param image
131    * @return an rectangle representing the actual face position on success or
132    *         null if no face could be detected
133    */
134   public Rectangle2D locateFace(BufferedImage image) {
135     long timeStart = System.currentTimeMillis();
136
137     int resizeTo = 600;
138
139     BufferedImage smallImage = resizeImageFittingInto(image, resizeTo);
140     IntegralImageData imageData = new IntegralImageData(smallImage);
141
142     float factor = image.getWidth() / (float) smallImage.getWidth();
143
144     int maxIterations = 0;
145
146     // first we calculate the maximum scale factor for our 200x200 image
147     float maxScaleFactor = Math.min(imageData.getWidth() / 100f, imageData.getHeight() / 100f);
148
149     // we simply won't recognize faces that are smaller than 40x40 px
150     float minScaleFactor = 0.5f;
151
152     // border for faceYes-possibility must be greater that that
153     float maxBorder = 0.999f;
154
155     for (float scale = maxScaleFactor; scale > minScaleFactor; scale -= 0.25) {
156       int actualDimension = (int) (scale * 100);
157       int borderX = imageData.getWidth() - actualDimension;
158       int borderY = imageData.getHeight() - actualDimension;
159       for (int x = 0; x <= borderX; ++x) {
160         yLines: for (int y = 0; y <= borderY; ++y) {
161
162           for (int iterations = 0; iterations < this.classifiers.size(); ++iterations) {
163             Classifier classifier = this.classifiers.get(iterations);
164
165             float borderline =
166                 0.8f + (iterations / this.classifiers.size() - 1) * (maxBorder - 0.8f);
167             if (iterations > maxIterations)
168               maxIterations = iterations;
169             if (!classifier.classifyFace(imageData, scale, x, y, borderline)) {
170               continue yLines;
171             }
172           }
173
174           // if we reach here we have a face recognized because our image went
175           // through all
176           // classifiers
177
178           Rectangle2D faceRect =
179               new Rectangle2D.Float(x * factor, y * factor, actualDimension * factor,
180                   actualDimension * factor);
181
182           System.out.println("Time: " + (System.currentTimeMillis() - timeStart) + "ms");
183           return faceRect;
184
185         }
186       }
187     }
188
189     return null;
190   }
191
192   /**
193    * Locates a face by searching radial starting at the last known position. If
194    * lastCoordinates are null we simply start in the center of the image.
195    * <p>
196    * TODO: This method could quite possible be tweaked so that face recognition
197    * would be much faster
198    * 
199    * @param image
200    *          the image to process
201    * @param lastCoordinates
202    *          the last known coordinates or null if unknown
203    * @return an rectangle representing the actual face position on success or
204    *         null if no face could be detected
205    */
206   public Rectangle2D locateFaceRadial(BufferedImage image, Rectangle2D lastCoordinates) {
207
208     int resizeTo = 600;
209
210     BufferedImage smallImage = resizeImageFittingInto(image, resizeTo);
211     float originalImageFactor = image.getWidth() / (float) smallImage.getWidth();
212     IntegralImageData imageData = new IntegralImageData(smallImage);
213
214     if (lastCoordinates == null) {
215       // if we don't have a last coordinate we just begin in the center
216       int smallImageMaxDimension = Math.min(smallImage.getWidth(), smallImage.getHeight());
217       lastCoordinates =
218           new Rectangle2D.Float((smallImage.getWidth() - smallImageMaxDimension) / 2.0f,
219               (smallImage.getHeight() - smallImageMaxDimension) / 2.0f, smallImageMaxDimension,
220               smallImageMaxDimension);
221     } else {
222       // first we have to scale the last coodinates back relative to the resized
223       // image
224       lastCoordinates =
225           new Rectangle2D.Float((float) (lastCoordinates.getX() * (1 / originalImageFactor)),
226               (float) (lastCoordinates.getY() * (1 / originalImageFactor)),
227               (float) (lastCoordinates.getWidth() * (1 / originalImageFactor)),
228               (float) (lastCoordinates.getHeight() * (1 / originalImageFactor)));
229     }
230
231     float startFactor = (float) (lastCoordinates.getWidth() / 100.0f);
232
233     // first we calculate the maximum scale factor for our 200x200 image
234     float maxScaleFactor = Math.min(imageData.getWidth() / 100f, imageData.getHeight() / 100f);
235     // maxScaleFactor = 1.0f;
236
237     // we simply won't recognize faces that are smaller than 40x40 px
238     float minScaleFactor = 0.5f;
239
240     float maxScaleDifference =
241         Math.max(Math.abs(maxScaleFactor - startFactor), Math.abs(minScaleFactor - startFactor));
242
243     // border for faceYes-possibility must be greater that that
244     float maxBorder = 0.999f;
245
246     int startPosX = (int) lastCoordinates.getX();
247     int startPosY = (int) lastCoordinates.getX();
248
249     for (float factorDiff = 0.0f; Math.abs(factorDiff) <= maxScaleDifference; factorDiff =
250         (factorDiff + sgn(factorDiff) * 0.1f) * -1 // we alternate between
251                                                    // negative and positiv
252                                                    // factors
253     ) {
254
255       float factor = startFactor + factorDiff;
256       if (factor > maxScaleFactor || factor < minScaleFactor)
257         continue;
258
259       // now we calculate the actualDimmension
260       int actualDimmension = (int) (100 * factor);
261       int maxX = imageData.getWidth() - actualDimmension;
262       int maxY = imageData.getHeight() - actualDimmension;
263
264       int maxDiffX = Math.max(Math.abs(startPosX - maxX), startPosX);
265       int maxDiffY = Math.max(Math.abs(startPosY - maxY), startPosY);
266
267       for (float xDiff = 0.1f; Math.abs(xDiff) <= maxDiffX; xDiff =
268           (xDiff + sgn(xDiff) * 0.5f) * -1) {
269         int xPos = Math.round(startPosX + xDiff);
270         if (xPos < 0 || xPos > maxX)
271           continue;
272
273         yLines: for (float yDiff = 0.1f; Math.abs(yDiff) <= maxDiffY; yDiff =
274             (yDiff + sgn(yDiff) * 0.5f) * -1) {
275           int yPos = Math.round(startPosY + yDiff);
276           if (yPos < 0 || yPos > maxY)
277             continue;
278
279           // by now we should have a valid coordinate to process which we should
280           // do now
281           for (int iterations = 0; iterations < this.classifiers.size(); ++iterations) {
282             Classifier classifier = this.classifiers.get(iterations);
283
284             float borderline =
285                 0.8f + (iterations / (this.classifiers.size() - 1)) * (maxBorder - 0.8f);
286
287             if (!classifier.classifyFace(imageData, factor, xPos, yPos, borderline)) {
288               continue yLines;
289             }
290           }
291
292           // if we reach here we have a face recognized because our image went
293           // through all
294           // classifiers
295
296           Rectangle2D faceRect =
297               new Rectangle2D.Float(xPos * originalImageFactor, yPos * originalImageFactor,
298                   actualDimmension * originalImageFactor, actualDimmension * originalImageFactor);
299
300           return faceRect;
301
302         }
303
304       }
305
306     }
307
308     // System.out.println("Time: "+(System.currentTimeMillis()-timeStart)+"ms");
309     return null;
310
311   }
312
313   public List<Classifier> getClassifiers() {
314     return new ArrayList<Classifier>(this.classifiers);
315   }
316
317   public static void saveToXml(OutputStream out, ClassifierTree tree) throws IOException {
318     PrintWriter writer = new PrintWriter(new OutputStreamWriter(out, "UTF-8"));
319     writer.write(xStream.toXML(tree));
320     writer.close();
321   }
322
323   public static ClassifierTree loadFromXml(InputStream in) throws IOException {
324     Reader reader = new InputStreamReader(in, "UTF-8");
325     StringBuilder sb = new StringBuilder();
326
327     char[] buffer = new char[1024];
328     int read = 0;
329     do {
330       read = reader.read(buffer);
331       if (read > 0) {
332         sb.append(buffer, 0, read);
333       }
334     } while (read > -1);
335     reader.close();
336
337     return (ClassifierTree) xStream.fromXML(sb.toString());
338   }
339
340   private static int sgn(float value) {
341     return (value < 0 ? -1 : (value > 0 ? +1 : 1));
342   }
343
344 }