clean out all my changes
[IRC.git] / Robust / src / Benchmarks / SingleTM / Bayes / Bayes.java
1 /* =============================================================================
2  *
3  * bayes.java
4  *
5  * =============================================================================
6  *
7  * Copyright (C) Stanford University, 2006.  All Rights Reserved.
8  * Author: Chi Cao Minh
9  * Ported to Java June 2009 Alokika Dash
10  * University of California, Irvine
11  *
12  * =============================================================================
13  *
14  * For the license of bayes/sort.h and bayes/sort.c, please see the header
15  * of the files.
16  * 
17  * ------------------------------------------------------------------------
18  * 
19  * Unless otherwise noted, the following license applies to STAMP files:
20  * 
21  * Copyright (c) 2007, Stanford University
22  * All rights reserved.
23  * 
24  * Redistribution and use in source and binary forms, with or without
25  * modification, are permitted provided that the following conditions are
26  * met:
27  * 
28  *     * Redistributions of source code must retain the above copyright
29  *       notice, this list of conditions and the following disclaimer.
30  * 
31  *     * Redistributions in binary form must reproduce the above copyright
32  *       notice, this list of conditions and the following disclaimer in
33  *       the documentation and/or other materials provided with the
34  *       distribution.
35  * 
36  *     * Neither the name of Stanford University nor the names of its
37  *       contributors may be used to endorse or promote products derived
38  *       from this software without specific prior written permission.
39  * 
40  * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
41  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
42  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
43  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
44  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
45  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
46  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
47  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
48  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
49  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
50  * THE POSSIBILITY OF SUCH DAMAGE.
51  *
52  * =============================================================================
53  */
54
55 #define PARAM_DEFAULT_QUALITY   1.0f
56 #define PARAM_EDGE      101
57 #define PARAM_INSERT    105
58 #define PARAM_NUMBER    110
59 #define PARAM_PERCENT   112
60 #define PARAM_RECORD    114
61 #define PARAM_SEED      115
62 #define PARAM_THREAD    116
63 #define PARAM_VAR       118
64
65 #define PARAM_DEFAULT_EDGE     -1
66 #define PARAM_DEFAULT_INSERT   1
67 #define PARAM_DEFAULT_NUMBER   4
68 #define PARAM_DEFAULT_PERCENT  10
69 #define PARAM_DEFAULT_RECORD   4096
70 #define PARAM_DEFAULT_SEED     1
71 #define PARAM_DEFAULT_THREAD   1
72 #define PARAM_DEFAULT_VAR      32
73
74 public class Bayes extends Thread {
75   public int[] global_params; /* 256 = ascii limit */
76   public int global_maxNumEdgeLearned;
77   public int global_insertPenalty;
78   public float global_operationQualityFactor;
79
80   /* Number of threads */
81   int numThread;
82
83   /* thread id */
84   int myId;
85
86   /* Global learn pointer */
87   Learner learnerPtr;
88
89   public Bayes() {
90     global_params = new int[256];
91     global_maxNumEdgeLearned = PARAM_DEFAULT_EDGE;
92     global_insertPenalty = PARAM_DEFAULT_INSERT;
93     global_operationQualityFactor = PARAM_DEFAULT_QUALITY;
94   }
95
96   public Bayes(int numThread, int myId, Learner learnerPtr) {
97     this.numThread = numThread;
98     this.myId = myId;
99     this.learnerPtr = learnerPtr;
100   }
101
102
103   /* =============================================================================
104    * displayUsage
105    * =============================================================================
106    */
107   public void
108     displayUsage ()
109     {
110       System.out.println("Usage: ./Bayes.bin [options]");
111       System.out.println("    e Max [e]dges learned per variable  ");
112       System.out.println("    i Edge [i]nsert penalty             ");
113       System.out.println("    n Max [n]umber of parents           ");
114       System.out.println("    p [p]ercent chance of parent        ");
115       System.out.println("    q Operation [q]uality factor        ");
116       System.out.println("    r Number of [r]ecords               ");
117       System.out.println("    s Random [s]eed                     ");
118       System.out.println("    t Number of [t]hreads               ");
119       System.out.println("    v Number of [v]ariables             ");
120       System.exit(1);
121     }
122
123
124   /* =============================================================================
125    * setDefaultParams
126    * =============================================================================
127    */
128   public void
129     setDefaultParams ()
130     {
131       global_params[PARAM_EDGE]    = PARAM_DEFAULT_EDGE;
132       global_params[PARAM_INSERT]  = PARAM_DEFAULT_INSERT;
133       global_params[PARAM_NUMBER]  = PARAM_DEFAULT_NUMBER;
134       global_params[PARAM_PERCENT] = PARAM_DEFAULT_PERCENT;
135       global_params[PARAM_RECORD]  = PARAM_DEFAULT_RECORD;
136       global_params[PARAM_SEED]    = PARAM_DEFAULT_SEED;
137       global_params[PARAM_THREAD]  = PARAM_DEFAULT_THREAD;
138       global_params[PARAM_VAR]     = PARAM_DEFAULT_VAR;
139     }
140
141
142   /* =============================================================================
143    * parseArgs
144    * =============================================================================
145    */
146   public static void
147     parseArgs (String[] args, Bayes b)
148     {
149       int i = 0;
150       String arg;
151       b.setDefaultParams();
152       while(i < args.length && args[i].startsWith("-")) {
153         arg = args[i++];
154         //check options
155         if(arg.equals("-e")) {
156           if(i < args.length) {
157             b.global_params[PARAM_EDGE] = new Integer(args[i++]).intValue();
158           }
159         } else if(arg.equals("-i")) {
160           if (i < args.length) {
161             b.global_params[PARAM_INSERT] = new Integer(args[i++]).intValue();
162           }
163         } else if (arg.equals("-n")) {
164           if (i < args.length) {
165             b.global_params[PARAM_NUMBER] = new Integer(args[i++]).intValue();
166           }
167         } else if (arg.equals("-p")) {
168           if (i < args.length) {
169             b.global_params[PARAM_PERCENT] = new Integer(args[i++]).intValue();
170           }
171         } else if (arg.equals("-r")) {
172           if (i < args.length) {
173             b.global_params[PARAM_RECORD] = new Integer(args[i++]).intValue();
174           }
175         } else if (arg.equals("-s")) {
176           if (i < args.length) {
177             b.global_params[PARAM_SEED] = new Integer(args[i++]).intValue();
178           }
179         } else if (arg.equals("-t")) {
180           if (i < args.length) {
181             b.global_params[PARAM_THREAD] = new Integer(args[i++]).intValue();
182           }
183         } else if (arg.equals("-v")) {
184           if (i < args.length) {
185             b.global_params[PARAM_VAR] = new Integer(args[i++]).intValue();
186           }
187         } else if(arg.equals("-h")) {
188           b.displayUsage();
189         }
190       }
191
192       if (b.global_params[PARAM_THREAD] == 0) {
193         b.displayUsage();
194       }
195     }
196
197
198   /* =============================================================================
199    * score
200    * =============================================================================
201    */
202   public float score (Net netPtr, Adtree adtreePtr) {
203     /*
204      * Create dummy data structures to conform to learner_score assumptions
205      */
206     
207     Data dataPtr = new Data(1, 1, null);
208     
209     Learner learnerPtr = new Learner(dataPtr, adtreePtr, 1, global_insertPenalty, global_maxNumEdgeLearned, global_operationQualityFactor);
210     
211     Net tmpNetPtr = learnerPtr.netPtr;
212     learnerPtr.netPtr = netPtr;
213     
214     float score = learnerPtr.learner_score();
215     learnerPtr.netPtr = tmpNetPtr;
216     learnerPtr.learner_free();
217     dataPtr.data_free();
218
219
220
221     return score;
222   }
223
224
225   /**
226    * parallel execution
227    **/
228   public void run() {
229     Barrier.enterBarrier();
230     Learner.createTaskList(myId, numThread, learnerPtr);
231     Barrier.enterBarrier();
232
233     Barrier.enterBarrier();
234     Learner.learnStructure(myId, numThread, learnerPtr);
235     Barrier.enterBarrier();
236   }
237     
238
239   /* =============================================================================
240    * main
241    * =============================================================================
242    */
243
244   public static void main(String[] args) {
245     /*
246      * Initialization
247      */
248     Bayes b = new Bayes();
249     Bayes.parseArgs(args, b);
250     int numThread     = b.global_params[PARAM_THREAD];
251     int numVar        = b.global_params[PARAM_VAR];
252     int numRecord     = b.global_params[PARAM_RECORD];
253     int randomSeed    = b.global_params[PARAM_SEED];
254     int maxNumParent  = b.global_params[PARAM_NUMBER];
255     int percentParent = b.global_params[PARAM_PERCENT];
256     b.global_insertPenalty = b.global_params[PARAM_INSERT];
257     b.global_maxNumEdgeLearned = b.global_params[PARAM_EDGE];
258
259     /* Initiate Barriers */
260     Barrier.setBarrier(numThread);
261
262     Bayes[] binit = new Bayes[numThread];
263
264     System.out.println("Number of threads          " + numThread);
265     System.out.println("Random seed                " + randomSeed);
266     System.out.println("Number of vars             " + numVar);
267     System.out.println("Number of records          " + numRecord);
268     System.out.println("Max num parents            " + maxNumParent);
269     System.out.println("%% chance of parent        " + percentParent);
270     System.out.println("Insert penalty             " + b.global_insertPenalty);
271     System.out.println("Max num edge learned / var " + b.global_maxNumEdgeLearned);
272     System.out.println("Operation quality factor   " + b.global_operationQualityFactor);
273
274     /*
275      * Generate data
276      */
277
278     System.out.print("Generating data... ");
279
280     Random randomPtr = new Random();
281     randomPtr.random_alloc();
282     randomPtr.random_seed(randomSeed);
283
284     Data dataPtr = new Data(numVar, numRecord, randomPtr); 
285
286     Net netPtr = dataPtr.data_generate(-1, maxNumParent, percentParent);
287     System.out.println("done.");
288
289     /*
290      * Generate adtree
291      */
292
293     Adtree adtreePtr = new Adtree();
294
295     System.out.print("Generating adtree... ");
296
297     adtreePtr.adtree_make(dataPtr);
298     dataPtr.data_free();
299
300     System.out.println("done.");
301
302     /*
303      * Score original network
304      */
305
306     float actualScore = b.score(netPtr, adtreePtr);
307     netPtr.net_free();
308
309     /*
310      * Learn structure of Bayesian network
311      */
312
313     Learner learnerPtr = new Learner(dataPtr, adtreePtr, numThread, b.global_insertPenalty, b.global_maxNumEdgeLearned, b.global_operationQualityFactor);
314
315     System.out.print("Learning structure...");
316
317     /* Create and Start Threads */
318     for(int i = 1; i<numThread; i++) {
319       binit[i] = new Bayes(i, numThread, learnerPtr);
320     }
321
322     for(int i = 1; i<numThread; i++) {
323       binit[i].start();
324     }
325
326
327     /** 
328       * Parallel work by all threads
329       **/
330     long start=System.currentTimeMillis();
331
332     Barrier.enterBarrier();
333     Learner.createTaskList(0, numThread, learnerPtr);
334     Barrier.enterBarrier();
335
336     Barrier.enterBarrier();
337     Learner.learnStructure(0, numThread, learnerPtr);
338     Barrier.enterBarrier();
339     long stop=System.currentTimeMillis();
340
341     long diff=stop-start;
342     System.out.println("TIME="+diff);
343
344     System.out.println("done.");
345
346     /*
347      * Check solution
348      */
349
350     boolean status = learnerPtr.netPtr.net_isCycle();
351     if(status) {
352       System.out.println("Assert failed: system has an incorrect result");
353       System.exit(0);
354     }
355
356 #ifndef SIMULATOR
357     float learnScore = learnerPtr.learner_score();
358     System.out.println("Learn score= " + (double)learnScore);
359 #endif
360     System.out.println("Actual score= " + (double)actualScore);
361
362     /*
363      * Clean up
364      */
365   }
366 }
367 /* =============================================================================
368  *
369  * End of bayes.java
370  *
371  * =============================================================================
372  */