89663d3fce8f929dc170c757db06a55d4c79d41e
[IRC.git] / Robust / src / Benchmarks / SingleTM / Bayes / Bayes.java
1 /* =============================================================================
2  *
3  * bayes.java
4  *
5  * =============================================================================
6  *
7  * Copyright (C) Stanford University, 2006.  All Rights Reserved.
8  * Author: Chi Cao Minh
9  * Ported to Java June 2009 Alokika Dash
10  * University of California, Irvine
11  *
12  * =============================================================================
13  *
14  * For the license of bayes/sort.h and bayes/sort.c, please see the header
15  * of the files.
16  * 
17  * ------------------------------------------------------------------------
18  * 
19  * Unless otherwise noted, the following license applies to STAMP files:
20  * 
21  * Copyright (c) 2007, Stanford University
22  * All rights reserved.
23  * 
24  * Redistribution and use in source and binary forms, with or without
25  * modification, are permitted provided that the following conditions are
26  * met:
27  * 
28  *     * Redistributions of source code must retain the above copyright
29  *       notice, this list of conditions and the following disclaimer.
30  * 
31  *     * Redistributions in binary form must reproduce the above copyright
32  *       notice, this list of conditions and the following disclaimer in
33  *       the documentation and/or other materials provided with the
34  *       distribution.
35  * 
36  *     * Neither the name of Stanford University nor the names of its
37  *       contributors may be used to endorse or promote products derived
38  *       from this software without specific prior written permission.
39  * 
40  * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
41  * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
42  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
43  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
44  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
45  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
46  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
47  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
48  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
49  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
50  * THE POSSIBILITY OF SUCH DAMAGE.
51  *
52  * =============================================================================
53  */
54
55 #define PARAM_DEFAULT_QUALITY   1.0f
56 #define PARAM_EDGE      101
57 #define PARAM_INSERT    105
58 #define PARAM_NUMBER    110
59 #define PARAM_PERCENT   112
60 #define PARAM_RECORD    114
61 #define PARAM_SEED      115
62 #define PARAM_THREAD    116
63 #define PARAM_VAR       118
64
65 #define PARAM_DEFAULT_EDGE     -1
66 #define PARAM_DEFAULT_INSERT   1
67 #define PARAM_DEFAULT_NUMBER   4
68 #define PARAM_DEFAULT_PERCENT  10
69 #define PARAM_DEFAULT_RECORD   4096
70 #define PARAM_DEFAULT_SEED     1
71 #define PARAM_DEFAULT_THREAD   1
72 #define PARAM_DEFAULT_VAR      32
73
74 public class Bayes extends Thread {
75   public int[] global_params; /* 256 = ascii limit */
76   public int global_maxNumEdgeLearned;
77   public int global_insertPenalty;
78   public float global_operationQualityFactor;
79
80   /* Number of threads */
81   int numThread;
82
83   /* thread id */
84   int myId;
85
86   /* Global learn pointer */
87   Learner learnerPtr;
88
89   public Bayes() {
90     global_params = new int[256];
91     global_maxNumEdgeLearned = PARAM_DEFAULT_EDGE;
92     global_insertPenalty = PARAM_DEFAULT_INSERT;
93     global_operationQualityFactor = PARAM_DEFAULT_QUALITY;
94   }
95
96   public Bayes(int numThread, int myId, Learner learnerPtr) {
97     this.numThread = numThread;
98     this.myId = myId;
99     this.learnerPtr = learnerPtr;
100   }
101
102
103   /* =============================================================================
104    * displayUsage
105    * =============================================================================
106    */
107   public void
108     displayUsage ()
109     {
110       System.out.println("Usage: ./Bayes.bin [options]");
111       System.out.println("    e Max [e]dges learned per variable  ");
112       System.out.println("    i Edge [i]nsert penalty             ");
113       System.out.println("    n Max [n]umber of parents           ");
114       System.out.println("    p [p]ercent chance of parent        ");
115       System.out.println("    q Operation [q]uality factor        ");
116       System.out.println("    r Number of [r]ecords               ");
117       System.out.println("    s Random [s]eed                     ");
118       System.out.println("    t Number of [t]hreads               ");
119       System.out.println("    v Number of [v]ariables             ");
120       System.exit(1);
121     }
122
123
124   /* =============================================================================
125    * setDefaultParams
126    * =============================================================================
127    */
128   public void
129     setDefaultParams ()
130     {
131       global_params[PARAM_EDGE]    = PARAM_DEFAULT_EDGE;
132       global_params[PARAM_INSERT]  = PARAM_DEFAULT_INSERT;
133       global_params[PARAM_NUMBER]  = PARAM_DEFAULT_NUMBER;
134       global_params[PARAM_PERCENT] = PARAM_DEFAULT_PERCENT;
135       global_params[PARAM_RECORD]  = PARAM_DEFAULT_RECORD;
136       global_params[PARAM_SEED]    = PARAM_DEFAULT_SEED;
137       global_params[PARAM_THREAD]  = PARAM_DEFAULT_THREAD;
138       global_params[PARAM_VAR]     = PARAM_DEFAULT_VAR;
139     }
140
141
142   /* =============================================================================
143    * parseArgs
144    * =============================================================================
145    */
146   public static void
147     parseArgs (String[] args, Bayes b)
148     {
149       int i = 0;
150       String arg;
151       b.setDefaultParams();
152       while(i < args.length && args[i].startsWith("-")) {
153         arg = args[i++];
154         //check options
155         if(arg.equals("-e")) {
156           if(i < args.length) {
157             b.global_params[PARAM_EDGE] = new Integer(args[i++]).intValue();
158           }
159         } else if(arg.equals("-i")) {
160           if (i < args.length) {
161             b.global_params[PARAM_INSERT] = new Integer(args[i++]).intValue();
162           }
163         } else if (arg.equals("-n")) {
164           if (i < args.length) {
165             b.global_params[PARAM_NUMBER] = new Integer(args[i++]).intValue();
166           }
167         } else if (arg.equals("-p")) {
168           if (i < args.length) {
169             b.global_params[PARAM_PERCENT] = new Integer(args[i++]).intValue();
170           }
171         } else if (arg.equals("-r")) {
172           if (i < args.length) {
173             b.global_params[PARAM_RECORD] = new Integer(args[i++]).intValue();
174           }
175         } else if (arg.equals("-s")) {
176           if (i < args.length) {
177             b.global_params[PARAM_SEED] = new Integer(args[i++]).intValue();
178           }
179         } else if (arg.equals("-t")) {
180           if (i < args.length) {
181             b.global_params[PARAM_THREAD] = new Integer(args[i++]).intValue();
182           }
183         } else if (arg.equals("-v")) {
184           if (i < args.length) {
185             b.global_params[PARAM_VAR] = new Integer(args[i++]).intValue();
186           }
187         } else if(arg.equals("-h")) {
188           b.displayUsage();
189         }
190       }
191
192       if (b.global_params[PARAM_THREAD] == 0) {
193         b.displayUsage();
194       }
195     }
196
197
198   /* =============================================================================
199    * score
200    * =============================================================================
201    */
202   public float
203     score (Net netPtr, Adtree adtreePtr)
204     {
205       /*
206        * Create dummy data structures to conform to learner_score assumptions
207        */
208
209       Data dataPtr = Data.data_alloc(1, 1, null);
210
211       Learner learnerPtr = Learner.learner_alloc(dataPtr, adtreePtr, 1, global_insertPenalty, global_maxNumEdgeLearned, global_operationQualityFactor);
212
213       Net tmpNetPtr = learnerPtr.netPtr;
214       learnerPtr.netPtr = netPtr;
215
216       float score = learnerPtr.learner_score();
217
218       learnerPtr.netPtr = tmpNetPtr;
219       learnerPtr.learner_free();
220       dataPtr.data_free();
221
222       return score;
223     }
224
225
226   /**
227    * parallel execution
228    **/
229   public void run() {
230     Barrier.enterBarrier();
231     Learner.createTaskList(myId, numThread, learnerPtr);
232     Barrier.enterBarrier();
233
234     Barrier.enterBarrier();
235     Learner.learnStructure(myId, numThread, learnerPtr);
236     Barrier.enterBarrier();
237   }
238     
239
240   /* =============================================================================
241    * main
242    * =============================================================================
243    */
244
245   public static void main(String[] args) {
246     /*
247      * Initialization
248      */
249     Bayes b = new Bayes();
250     Bayes.parseArgs(args, b);
251     int numThread     = b.global_params[PARAM_THREAD];
252     int numVar        = b.global_params[PARAM_VAR];
253     int numRecord     = b.global_params[PARAM_RECORD];
254     int randomSeed    = b.global_params[PARAM_SEED];
255     int maxNumParent  = b.global_params[PARAM_NUMBER];
256     int percentParent = b.global_params[PARAM_PERCENT];
257     b.global_insertPenalty = b.global_params[PARAM_INSERT];
258     b.global_maxNumEdgeLearned = b.global_params[PARAM_EDGE];
259
260     /* Initiate Barriers */
261     Barrier.setBarrier(numThread);
262
263     Bayes[] binit = new Bayes[numThread];
264
265     System.out.println("Number of threads          " + numThread);
266     System.out.println("Random seed                " + randomSeed);
267     System.out.println("Number of vars             " + numVar);
268     System.out.println("Number of records          " + numRecord);
269     System.out.println("Max num parents            " + maxNumParent);
270     System.out.println("%% chance of parent        " + percentParent);
271     System.out.println("Insert penalty             " + b.global_insertPenalty);
272     System.out.println("Max num edge learned / var " + b.global_maxNumEdgeLearned);
273     System.out.println("Operation quality factor   " + b.global_operationQualityFactor);
274
275     /*
276      * Generate data
277      */
278
279     System.out.print("Generating data... ");
280
281     Random randomPtr = new Random();
282     randomPtr.random_alloc();
283     randomPtr.random_seed(randomSeed);
284
285     Data dataPtr = Data.data_alloc(numVar, numRecord, randomPtr); 
286
287     Net netPtr = dataPtr.data_generate(-1, maxNumParent, percentParent);
288     System.out.println("done.");
289
290     /*
291      * Generate adtree
292      */
293
294     Adtree adtreePtr = Adtree.adtree_alloc();
295
296     System.out.print("Generating adtree... ");
297
298     adtreePtr.adtree_make(dataPtr);
299
300     System.out.println("done.");
301
302     /*
303      * Score original network
304      */
305
306     float actualScore = b.score(netPtr, adtreePtr);
307     netPtr.net_free();
308
309     /*
310      * Learn structure of Bayesian network
311      */
312
313     Learner learnerPtr = Learner.learner_alloc(dataPtr, adtreePtr, numThread, b.global_insertPenalty, b.global_maxNumEdgeLearned, b.global_operationQualityFactor);
314
315     dataPtr.data_free(); /* save memory */
316
317     System.out.print("Learning structure...");
318
319     /* Create and Start Threads */
320     for(int i = 1; i<numThread; i++) {
321       binit[i] = new Bayes(i, numThread, learnerPtr);
322     }
323
324     for(int i = 1; i<numThread; i++) {
325       binit[i].start();
326     }
327
328
329     /** 
330       * Parallel work by all threads
331       **/
332
333     Barrier.enterBarrier();
334     Learner.createTaskList(0, numThread, learnerPtr);
335     Barrier.enterBarrier();
336
337     Barrier.enterBarrier();
338     Learner.learnStructure(0, numThread, learnerPtr);
339     Barrier.enterBarrier();
340
341     System.out.println("done.");
342
343     /*
344      * Check solution
345      */
346
347     boolean status = learnerPtr.netPtr.net_isCycle();
348     if(status) {
349       System.out.println("Assert failed: system has an incorrect result");
350       System.exit(0);
351     }
352
353 #ifndef SIMULATOR
354     float learnScore = learnerPtr.learner_score();
355     System.out.println("Learn score= " + (double)learnScore);
356 #endif
357     System.out.println("Actual score= " + (double)actualScore);
358
359     /*
360      * Clean up
361      */
362
363 #ifndef SIMULATOR
364     adtreePtr.adtree_free();
365 #  if 0    
366     learnerPtr.learner_free();
367 #  endif    
368 #endif
369
370   }
371 }
372 /* =============================================================================
373  *
374  * End of bayes.java
375  *
376  * =============================================================================
377  */