1 /* =============================================================================
5 * =============================================================================
7 * Copyright (C) Stanford University, 2006. All Rights Reserved.
9 * Ported to Java June 2009 Alokika Dash
10 * University of California, Irvine
12 * =============================================================================
14 * For the license of bayes/sort.h and bayes/sort.c, please see the header
 * of the files.
17 * ------------------------------------------------------------------------
19 * Unless otherwise noted, the following license applies to STAMP files:
21 * Copyright (c) 2007, Stanford University
22 * All rights reserved.
24 * Redistribution and use in source and binary forms, with or without
25 * modification, are permitted provided that the following conditions are
 * met:
28 * * Redistributions of source code must retain the above copyright
29 * notice, this list of conditions and the following disclaimer.
31 * * Redistributions in binary form must reproduce the above copyright
32 * notice, this list of conditions and the following disclaimer in
33 * the documentation and/or other materials provided with the
 * distribution.
36 * * Neither the name of Stanford University nor the names of its
37 * contributors may be used to endorse or promote products derived
38 * from this software without specific prior written permission.
40 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
41 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
42 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
43 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
44 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
45 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
46 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
47 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
48 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
49 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
50 * THE POSSIBILITY OF SUCH DAMAGE.
52 * =============================================================================
 */
/*
 * Option keys. Each key is the character code of its command-line option
 * letter and is used as an index into the int[256] global_params table:
 *   e=101  i=105  n=110  p=112  r=114  s=115  t=116  v=118
 */
#define PARAM_EDGE 101
#define PARAM_INSERT 105
#define PARAM_NUMBER 110
#define PARAM_PERCENT 112
#define PARAM_RECORD 114
#define PARAM_SEED 115
#define PARAM_THREAD 116
/* Used by setDefaultParams, parseArgs and main; previously not defined here. */
#define PARAM_VAR 118

/*
 * Default option values installed by setDefaultParams. Negative values are
 * parenthesised so that the macro expands safely inside a larger expression
 * (unparenthesised, `x-PARAM_DEFAULT_EDGE` would become `x--1`).
 */
#define PARAM_DEFAULT_EDGE (-1)
#define PARAM_DEFAULT_INSERT 1
#define PARAM_DEFAULT_NUMBER 4
#define PARAM_DEFAULT_PERCENT 10
#define PARAM_DEFAULT_RECORD 4096
#define PARAM_DEFAULT_SEED 1
#define PARAM_DEFAULT_THREAD 1
#define PARAM_DEFAULT_VAR 32

/*
 * Default operation quality factor. It is float-valued, so it has no slot in
 * the int-valued global_params table and is assigned directly to
 * global_operationQualityFactor.
 */
#define PARAM_DEFAULT_QUALITY 1.0f
74 public class Bayes extends Thread {
75 public int[] global_params; /* 256 = ascii limit */
76 public int global_maxNumEdgeLearned;
77 public int global_insertPenalty;
78 public float global_operationQualityFactor;
80 /* Number of threads */
86 /* Global learn pointer */
/*
 * Creates the instance that main() configures: allocates the option table
 * (one slot per 8-bit character code) and installs the learner tuning
 * defaults. parseArgs later fills global_params.
 */
public Bayes() {
  global_params = new int[256];
  global_maxNumEdgeLearned = PARAM_DEFAULT_EDGE;
  global_insertPenalty = PARAM_DEFAULT_INSERT;
  global_operationQualityFactor = PARAM_DEFAULT_QUALITY;
}
/*
 * Creates a worker thread for the parallel structure-learning phase.
 *
 * @param numThread  total number of threads taking part in learning
 * @param myId       id of this thread, forwarded to the Learner phases in run()
 * @param learnerPtr learner that run() drives
 */
public Bayes(int numThread, int myId, Learner learnerPtr) {
  this.numThread = numThread;
  // Previously missing: without this every worker ran with the default id 0.
  this.myId = myId;
  this.learnerPtr = learnerPtr;
}
103 /* =============================================================================
 * Print the command-line usage message.
105 * =============================================================================
 */
// Prints the option summary: one line per option letter, each of which is
// given on the command line as "-<letter> <value>".
// NOTE(review): this function's header and closing lines are not visible in
// this view. The "q" entry advertises an option for which no "-q" branch is
// visible in parseArgs -- confirm it is handled, or keep this list in sync.
110 System.out.println("Usage: ./Bayes.bin [options]");
111 System.out.println(" e Max [e]dges learned per variable ");
112 System.out.println(" i Edge [i]nsert penalty ");
113 System.out.println(" n Max [n]umber of parents ");
114 System.out.println(" p [p]ercent chance of parent ");
115 System.out.println(" q Operation [q]uality factor ");
116 System.out.println(" r Number of [r]ecords ");
117 System.out.println(" s Random [s]eed ");
118 System.out.println(" t Number of [t]hreads ");
119 System.out.println(" v Number of [v]ariables ");
124 /* =============================================================================
126 * =============================================================================
131 global_params[PARAM_EDGE] = PARAM_DEFAULT_EDGE;
132 global_params[PARAM_INSERT] = PARAM_DEFAULT_INSERT;
133 global_params[PARAM_NUMBER] = PARAM_DEFAULT_NUMBER;
134 global_params[PARAM_PERCENT] = PARAM_DEFAULT_PERCENT;
135 global_params[PARAM_RECORD] = PARAM_DEFAULT_RECORD;
136 global_params[PARAM_SEED] = PARAM_DEFAULT_SEED;
137 global_params[PARAM_THREAD] = PARAM_DEFAULT_THREAD;
138 global_params[PARAM_VAR] = PARAM_DEFAULT_VAR;
142 /* =============================================================================
 * parseArgs
144 * =============================================================================
 */
// Resets b's options to their defaults, then overwrites b.global_params from
// the "-<letter> <integer>" pairs at the start of args. The loop stops at the
// first argument that does not begin with "-".
//
// NOTE(review): several lines of this function are not visible in this view:
// the modifiers before the name, the declarations of `i` and `arg`, the
// statement that advances to the next argument, and the bodies of the "-h"
// and zero-thread branches. From what is visible:
//  - an option given as the last argument with no value appears to be
//    skipped without a message (no else branch is visible on the
//    `i < args.length` guards);
//  - a non-numeric value makes `new Integer(String)` throw
//    NumberFormatException; Integer.parseInt is the non-deprecated equivalent;
//  - there is no "-q" branch, although the usage text lists a "q" option.
147 parseArgs (String[] args, Bayes b)
151 b.setDefaultParams();
152 while(i < args.length && args[i].startsWith("-")) {
// Each branch stores the value that follows the flag under that flag's PARAM_* key.
155 if(arg.equals("-e")) {
156 if(i < args.length) {
157 b.global_params[PARAM_EDGE] = new Integer(args[i++]).intValue();
159 } else if(arg.equals("-i")) {
160 if (i < args.length) {
161 b.global_params[PARAM_INSERT] = new Integer(args[i++]).intValue();
163 } else if (arg.equals("-n")) {
164 if (i < args.length) {
165 b.global_params[PARAM_NUMBER] = new Integer(args[i++]).intValue();
167 } else if (arg.equals("-p")) {
168 if (i < args.length) {
169 b.global_params[PARAM_PERCENT] = new Integer(args[i++]).intValue();
171 } else if (arg.equals("-r")) {
172 if (i < args.length) {
173 b.global_params[PARAM_RECORD] = new Integer(args[i++]).intValue();
175 } else if (arg.equals("-s")) {
176 if (i < args.length) {
177 b.global_params[PARAM_SEED] = new Integer(args[i++]).intValue();
179 } else if (arg.equals("-t")) {
180 if (i < args.length) {
181 b.global_params[PARAM_THREAD] = new Integer(args[i++]).intValue();
183 } else if (arg.equals("-v")) {
184 if (i < args.length) {
185 b.global_params[PARAM_VAR] = new Integer(args[i++]).intValue();
// "-h": help request (branch body not visible in this view).
187 } else if(arg.equals("-h")) {
// Zero threads is rejected here (branch body not visible in this view).
192 if (b.global_params[PARAM_THREAD] == 0) {
198 /* =============================================================================
200 * =============================================================================
202 public float score (Net netPtr, Adtree adtreePtr) {
204 * Create dummy data structures to conform to learner_score assumptions
207 Data dataPtr = new Data(1, 1, null);
209 Learner learnerPtr = new Learner(dataPtr, adtreePtr, 1, global_insertPenalty, global_maxNumEdgeLearned, global_operationQualityFactor);
211 Net tmpNetPtr = learnerPtr.netPtr;
212 learnerPtr.netPtr = netPtr;
214 float score = learnerPtr.learner_score();
215 learnerPtr.netPtr = tmpNetPtr;
216 learnerPtr.learner_free();
229 Barrier.enterBarrier();
230 Learner.createTaskList(myId, numThread, learnerPtr);
231 Barrier.enterBarrier();
233 Barrier.enterBarrier();
234 Learner.learnStructure(myId, numThread, learnerPtr);
235 Barrier.enterBarrier();
239 /* =============================================================================
241 * =============================================================================
244 public static void main(String[] args) {
248 Bayes b = new Bayes();
249 Bayes.parseArgs(args, b);
250 int numThread = b.global_params[PARAM_THREAD];
251 int numVar = b.global_params[PARAM_VAR];
252 int numRecord = b.global_params[PARAM_RECORD];
253 int randomSeed = b.global_params[PARAM_SEED];
254 int maxNumParent = b.global_params[PARAM_NUMBER];
255 int percentParent = b.global_params[PARAM_PERCENT];
256 b.global_insertPenalty = b.global_params[PARAM_INSERT];
257 b.global_maxNumEdgeLearned = b.global_params[PARAM_EDGE];
259 /* Initiate Barriers */
260 Barrier.setBarrier(numThread);
262 Bayes[] binit = new Bayes[numThread];
264 System.out.println("Number of threads " + numThread);
265 System.out.println("Random seed " + randomSeed);
266 System.out.println("Number of vars " + numVar);
267 System.out.println("Number of records " + numRecord);
268 System.out.println("Max num parents " + maxNumParent);
269 System.out.println("%% chance of parent " + percentParent);
270 System.out.println("Insert penalty " + b.global_insertPenalty);
271 System.out.println("Max num edge learned / var " + b.global_maxNumEdgeLearned);
272 System.out.println("Operation quality factor " + b.global_operationQualityFactor);
278 System.out.print("Generating data... ");
280 Random randomPtr = new Random();
281 randomPtr.random_alloc();
282 randomPtr.random_seed(randomSeed);
284 Data dataPtr = new Data(numVar, numRecord, randomPtr);
286 Net netPtr = dataPtr.data_generate(-1, maxNumParent, percentParent);
287 System.out.println("done.");
293 Adtree adtreePtr = new Adtree();
295 System.out.print("Generating adtree... ");
297 adtreePtr.adtree_make(dataPtr);
300 System.out.println("done.");
303 * Score original network
306 float actualScore = b.score(netPtr, adtreePtr);
310 * Learn structure of Bayesian network
313 Learner learnerPtr = new Learner(dataPtr, adtreePtr, numThread, b.global_insertPenalty, b.global_maxNumEdgeLearned, b.global_operationQualityFactor);
315 System.out.print("Learning structure...");
317 /* Create and Start Threads */
318 for(int i = 1; i<numThread; i++) {
319 binit[i] = new Bayes(i, numThread, learnerPtr);
322 for(int i = 1; i<numThread; i++) {
328 * Parallel work by all threads
331 Barrier.enterBarrier();
332 Learner.createTaskList(0, numThread, learnerPtr);
333 Barrier.enterBarrier();
335 Barrier.enterBarrier();
336 Learner.learnStructure(0, numThread, learnerPtr);
337 Barrier.enterBarrier();
339 System.out.println("done.");
345 boolean status = learnerPtr.netPtr.net_isCycle();
347 System.out.println("Assert failed: system has an incorrect result");
352 float learnScore = learnerPtr.learner_score();
353 System.out.println("Learn score= " + (double)learnScore);
355 System.out.println("Actual score= " + (double)actualScore);
362 /* =============================================================================
366 * =============================================================================