checking outstanding changes in my CVS
[IRC.git] / Robust / src / Benchmarks / SingleTM / Bayes / Learner.java
1 /* =============================================================================
2  *
3  * learner.java
4  * -- Learns structure of Bayesian net from data
5  *
6  * =============================================================================
7  *
8  * Copyright (C) Stanford University, 2006.  All Rights Reserved.
9  * Author: Chi Cao Minh
10  * Ported to Java June 2009 Alokika Dash
11  * University of California, Irvine
12  *
13  *
14  * =============================================================================
15  *
16  * The penalized log-likelihood score (Friedman & Yahkani, 1996) is used to
17  * evaluated the "goodness" of a Bayesian net:
18  *
19  *                             M      n_j
20  *                            --- --- ---
21  *  -N_params * ln(R) / 2 + R >   >   >   P((a_j = v), X_j) ln P(a_j = v | X_j)
22  *                            --- --- ---
23  *                            j=1 X_j v=1
24  *
25  * Where:
26  *
27  *     N_params     total number of parents across all variables
28  *     R            number of records
29  *     M            number of variables
30  *     X_j          parents of the jth variable
31  *     n_j          number of attributes of the jth variable
32  *     a_j          attribute
33  *
34  * The second summation of X_j varies across all possible assignments to the
35  * values of the parents X_j.
36  *
37  * In the code:
38  *
39  *    "local log likelihood" is  P((a_j = v), X_j) ln P(a_j = v | X_j)
40  *    "log likelihood" is everything to the right of the '+', i.e., "R ... X_j)"
41  *    "base penalty" is -ln(R) / 2
42  *    "penalty" is N_params * -ln(R) / 2
43  *    "score" is the entire expression
44  *
45  * For more notes, refer to:
46  *
47  * A. Moore and M.-S. Lee. Cached sufficient statistics for efficient machine
48  * learning with large datasets. Journal of Artificial Intelligence Research 8
49  * (1998), pp 67-91.
50  *
51  * =============================================================================
52  *
53  * The search strategy uses a combination of local and global structure search.
54  * Similar to the technique described in:
55  *
56  * D. M. Chickering, D. Heckerman, and C. Meek.  A Bayesian approach to learning
57  * Bayesian networks with local structure. In Proceedings of Thirteenth
58  * Conference on Uncertainty in Artificial Intelligence (1997), pp. 80-89.
59  *
60  * =============================================================================
61  *
62  * For the license of bayes/sort.h and bayes/sort.c, please see the header
63  * of the files.
64  * 
65  * ------------------------------------------------------------------------
66  *
67  * Unless otherwise noted, the following license applies to STAMP files:
68  * 
69  * Copyright (c) 2007, Stanford University
70  * All rights reserved.
71  * 
72  * Redistribution and use in source and binary forms, with or without
73  * modification, are permitted provided that the following conditions are
74  * met:
75
76 *     * Redistributions of source code must retain the above copyright
77 *       notice, this list of conditions and the following disclaimer.
78
79 *     * Redistributions in binary form must reproduce the above copyright
80 *       notice, this list of conditions and the following disclaimer in
81 *       the documentation and/or other materials provided with the
82 *       distribution.
83
84 *     * Neither the name of Stanford University nor the names of its
85 *       contributors may be used to endorse or promote products derived
86 *       from this software without specific prior written permission.
87
88 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
89 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
90 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
91 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
92 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
93 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
94 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
95 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
96 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
97 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
98 * THE POSSIBILITY OF SUCH DAMAGE.
99 *
100 * =============================================================================
101 */
102
103
104 #define CACHE_LINE_SIZE 64
105 #define QUERY_VALUE_WILDCARD -1
106 #define OPERATION_INSERT      0
107 #define OPERATION_REMOVE      1
108 #define OPERATION_REVERSE     2
109 #define NUM_OPERATION         3
110
111 public class Learner {
112   Adtree adtreePtr;
113   Net netPtr;
114   float[] localBaseLogLikelihoods;
115   float baseLogLikelihood;
116   LearnerTask[] tasks;
117   List taskListPtr;
118   int numTotalParent;
119   int global_insertPenalty;
120   int global_maxNumEdgeLearned;
121   float global_operationQualityFactor;
122
123   public Learner() {
124 #ifdef TEST_LEARNER
125     global_maxNumEdgeLearned = -1;
126     global_insertPenalty = 1;
127     global_operationQualityFactor = 1.0F;
128 #endif
129   }
130
131   /* =============================================================================
132    * learner_alloc
133    * =============================================================================
134    */
135   public Learner(Data dataPtr, 
136                  Adtree adtreePtr, 
137                  int numThread, 
138                  int global_insertPenalty,
139                  int global_maxNumEdgeLearned,
140                  float global_operationQualityFactor) {
141     this.adtreePtr = adtreePtr;
142     this.netPtr = new Net(dataPtr.numVar);
143     this.localBaseLogLikelihoods = new float[dataPtr.numVar];
144     this.baseLogLikelihood = 0.0f;
145     this.tasks = new LearnerTask[dataPtr.numVar];
146     this.taskListPtr = List.list_alloc();
147     this.numTotalParent = 0;
148 #ifndef TEST_LEARNER
149     this.global_insertPenalty = global_insertPenalty;
150     this.global_maxNumEdgeLearned = global_maxNumEdgeLearned;
151     this.global_operationQualityFactor = global_operationQualityFactor;
152 #endif
153   }
154
155   public void learner_free() {
156     adtreePtr=null;
157     netPtr=null;
158     localBaseLogLikelihoods=null;
159     tasks=null;
160     taskListPtr=null;
161   }
162
163
164   /* =============================================================================
165    * computeSpecificLocalLogLikelihood
166    * -- Query vectors should not contain wildcards
167    * =============================================================================
168    */
169   public float
170     computeSpecificLocalLogLikelihood (Adtree adtreePtr,
171         Vector_t queryVectorPtr,
172         Vector_t parentQueryVectorPtr)
173     {
174       int count = adtreePtr.adtree_getCount(queryVectorPtr);
175       if (count == 0) {
176         return 0.0f;
177       }
178
179       double probability = (double)count / (double)adtreePtr.numRecord;
180       int parentCount = adtreePtr.adtree_getCount(parentQueryVectorPtr);
181
182
183       float fval = (float)(probability * (Math.log((double)count/ (double)parentCount)));
184
185       return fval;
186     }
187
188
189   /* =============================================================================
190    * createPartition
191    * =============================================================================
192    */
193   public void
194     createPartition (int min, int max, int id, int n, LocalStartStop lss)
195     {
196       int range = max - min;
197       int chunk = Math.imax(1, ((range + n/2) / n)); // rounded 
198       int start = min + chunk * id;
199       int stop;
200       if (id == (n-1)) {
201         stop = max;
202       } else {
203         stop = Math.imin(max, (start + chunk));
204       }
205
206       lss.i_start = start;
207       lss.i_stop = stop;
208     }
209
210
211   /* =============================================================================
212    * createTaskList
213    * -- baseLogLikelihoods and taskListPtr are updated
214    * =============================================================================
215    */
216   public static void
217     createTaskList (int myId, int numThread, Learner learnerPtr)
218     {
219       boolean status;
220
221       Query[] queries = new Query[2];
222       queries[0] = new Query();
223       queries[1] = new Query();
224
225       Vector_t queryVectorPtr = new Vector_t(2);
226
227       status = queryVectorPtr.vector_pushBack(queries[0]);
228
229       Query parentQuery = new Query();
230       Vector_t parentQueryVectorPtr = new Vector_t(1); 
231
232       int numVar = learnerPtr.adtreePtr.numVar;
233       int numRecord = learnerPtr.adtreePtr.numRecord;
234       float baseLogLikelihood = 0.0f;
235       float penalty = (float)(-0.5f * Math.log((double)numRecord)); // only add 1 edge 
236
237       LocalStartStop lss = new LocalStartStop();
238       learnerPtr.createPartition(0, numVar, myId, numThread, lss);
239
240       /*
241        * Compute base log likelihood for each variable and total base loglikelihood
242        */
243
244       for (int v = lss.i_start; v < lss.i_stop; v++) {
245
246         float localBaseLogLikelihood = 0.0f;
247         queries[0].index = v;
248
249         queries[0].value = 0;
250         localBaseLogLikelihood +=
251           learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
252               queryVectorPtr,
253               parentQueryVectorPtr);
254
255         queries[0].value = 1;
256         localBaseLogLikelihood +=
257           learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
258               queryVectorPtr,
259               parentQueryVectorPtr);
260
261         learnerPtr.localBaseLogLikelihoods[v] = localBaseLogLikelihood;
262         baseLogLikelihood += localBaseLogLikelihood;
263
264       } // for each variable 
265
266       atomic {
267         float globalBaseLogLikelihood =
268           learnerPtr.baseLogLikelihood;
269         learnerPtr.baseLogLikelihood = (baseLogLikelihood + globalBaseLogLikelihood);
270       }
271
272       /*
273        * For each variable, find if the addition of any edge _to_ it is better
274        */
275
276       status = parentQueryVectorPtr.vector_pushBack(parentQuery);
277
278       for (int v = lss.i_start; v < lss.i_stop; v++) {
279
280          //Compute base log likelihood for this variable
281
282         queries[0].index = v;
283         int bestLocalIndex = v;
284         float bestLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[v];
285
286         status = queryVectorPtr.vector_pushBack(queries[1]);
287
288         for (int vv = 0; vv < numVar; vv++) {
289
290           if (vv == v) {
291             continue;
292           }
293           parentQuery.index = vv;
294           if (v < vv) {
295             queries[0].index = v;
296             queries[1].index = vv;
297           } else {
298             queries[0].index = vv;
299             queries[1].index = v;
300           }
301
302           float newLocalLogLikelihood = 0.0f;
303
304           queries[0].value = 0;
305           queries[1].value = 0;
306           parentQuery.value = 0;
307           newLocalLogLikelihood +=
308             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
309                 queryVectorPtr,
310                 parentQueryVectorPtr);
311
312           queries[0].value = 0;
313           queries[1].value = 1;
314           parentQuery.value = ((vv < v) ? 0 : 1);
315           newLocalLogLikelihood +=
316             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
317                 queryVectorPtr,
318                 parentQueryVectorPtr);
319
320           queries[0].value = 1;
321           queries[1].value = 0;
322           parentQuery.value = ((vv < v) ? 1 : 0);
323           newLocalLogLikelihood +=
324             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
325                 queryVectorPtr,
326                 parentQueryVectorPtr);
327
328           queries[0].value = 1;
329           queries[1].value = 1;
330           parentQuery.value = 1;
331           newLocalLogLikelihood +=
332             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
333                 queryVectorPtr,
334                 parentQueryVectorPtr);
335
336           if (newLocalLogLikelihood > bestLocalLogLikelihood) {
337             bestLocalIndex = vv;
338             bestLocalLogLikelihood = newLocalLogLikelihood;
339           }
340
341         } // foreach other variable 
342
343         queryVectorPtr.vector_popBack();
344
345         if (bestLocalIndex != v) {
346           float logLikelihood = numRecord * (baseLogLikelihood +
347               + bestLocalLogLikelihood
348               - learnerPtr.localBaseLogLikelihoods[v]);
349           float score = penalty + logLikelihood;
350
351           learnerPtr.tasks[v] = new LearnerTask();
352           LearnerTask taskPtr = learnerPtr.tasks[v];
353           taskPtr.op = OPERATION_INSERT;
354           taskPtr.fromId = bestLocalIndex;
355           taskPtr.toId = v;
356           taskPtr.score = score;
357           atomic {
358             status = learnerPtr.taskListPtr.list_insert(taskPtr);
359           }
360
361         }
362
363       } // for each variable 
364
365
366       queryVectorPtr.clear();
367       parentQueryVectorPtr.clear();
368
369 #ifdef TEST_LEARNER
370       ListNode it = learnerPtr.taskListPtr.head;
371
372      while (it.nextPtr!=null) {
373         it = it.nextPtr;
374         LearnerTask taskPtr = it.dataPtr;
375         System.out.println("[task] op= "+ taskPtr.op +" from= "+taskPtr.fromId+" to= " +taskPtr.toId+
376            " score= " + taskPtr.score);
377       }
378 #endif // TEST_LEARNER 
379
380     }
381
382   /* =============================================================================
383    * TMpopTask
384    * -- Returns null is list is empty
385    * =============================================================================
386    */
387   public LearnerTask TMpopTask (List taskListPtr)
388     {
389       LearnerTask taskPtr = null;
390
391       ListNode it = taskListPtr.head;
392       if (it.nextPtr!=null) {
393         it = it.nextPtr;
394         taskPtr = it.dataPtr;
395         boolean status = taskListPtr.list_remove(taskPtr);
396       }
397
398       return taskPtr;
399     }
400
401
402   /* =============================================================================
403    * populateParentQuery
404    * -- Modifies contents of parentQueryVectorPtr
405    * =============================================================================
406    */
407   public void
408     populateParentQueryVector (Net netPtr,
409         int id,
410         Query[] queries,
411         Vector_t parentQueryVectorPtr)
412     {
413       parentQueryVectorPtr.vector_clear();
414
415       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
416       IntListNode it = parentIdListPtr.head;
417       while (it.nextPtr!=null) {
418         it = it.nextPtr;
419         int parentId = it.dataPtr;
420         boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
421       }
422     }
423
424
425   /* =============================================================================
426    * TMpopulateParentQuery
427    * -- Modifies contents of parentQueryVectorPtr
428    * =============================================================================
429    */
430   public void
431     TMpopulateParentQueryVector (Net netPtr,
432         int id,
433         Query[] queries,
434         Vector_t parentQueryVectorPtr)
435     {
436       parentQueryVectorPtr.vector_clear();
437
438       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
439       IntListNode it = parentIdListPtr.head;
440
441       while (it.nextPtr!=null) {
442         it = it.nextPtr;
443         int parentId = it.dataPtr;
444         boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
445       }
446     }
447
448
449   /* =============================================================================
450    * populateQueryVectors
451    * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
452    * =============================================================================
453    */
454   public void
455     populateQueryVectors (Net netPtr,
456         int id,
457         Query[] queries,
458         Vector_t queryVectorPtr,
459         Vector_t parentQueryVectorPtr)
460     {
461       populateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
462
463       boolean status;
464       status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
465       status = queryVectorPtr.vector_pushBack(queries[id]);
466
467
468       queryVectorPtr.vector_sort();
469     }
470
471
472   /* =============================================================================
473    * TMpopulateQueryVectors
474    * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
475    * =============================================================================
476    */
477   public void
478     TMpopulateQueryVectors (Net netPtr,
479         int id,
480         Query[] queries,
481         Vector_t queryVectorPtr,
482         Vector_t parentQueryVectorPtr)
483     {
484       TMpopulateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
485
486       boolean status;
487       status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
488       status = queryVectorPtr.vector_pushBack(queries[id]);
489
490       queryVectorPtr.vector_sort();
491     }
492
493   /* =============================================================================
494    * computeLocalLogLikelihoodHelper
495    * -- Recursive helper routine
496    * =============================================================================
497    */
498   public float
499     computeLocalLogLikelihoodHelper (int i,
500         int numParent,
501         Adtree adtreePtr,
502         Query[] queries,
503         Vector_t queryVectorPtr,
504         Vector_t parentQueryVectorPtr)
505     {
506       if (i >= numParent) {
507         return computeSpecificLocalLogLikelihood(adtreePtr,
508             queryVectorPtr,
509             parentQueryVectorPtr);
510       }
511
512       float localLogLikelihood = 0.0f;
513
514       Query parentQueryPtr = (Query) (parentQueryVectorPtr.vector_at(i));
515       int parentIndex = parentQueryPtr.index;
516
517       queries[parentIndex].value = 0;
518       localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
519           numParent,
520           adtreePtr,
521           queries,
522           queryVectorPtr,
523           parentQueryVectorPtr);
524
525       queries[parentIndex].value = 1;
526       localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
527           numParent,
528           adtreePtr,
529           queries,
530           queryVectorPtr,
531           parentQueryVectorPtr);
532
533       queries[parentIndex].value = QUERY_VALUE_WILDCARD;
534
535       return localLogLikelihood;
536     }
537
538
539   /* =============================================================================
540    * computeLocalLogLikelihood
541    * -- Populate the query vectors before passing as args
542    * =============================================================================
543    */
544   public float
545     computeLocalLogLikelihood (int id,
546         Adtree adtreePtr,
547         Net netPtr,
548         Query[] queries,
549         Vector_t queryVectorPtr,
550         Vector_t parentQueryVectorPtr)
551     {
552       int numParent = parentQueryVectorPtr.vector_getSize();
553       float localLogLikelihood = 0.0f;
554
555       queries[id].value = 0;
556       localLogLikelihood += computeLocalLogLikelihoodHelper(0,
557           numParent,
558           adtreePtr,
559           queries,
560           queryVectorPtr,
561           parentQueryVectorPtr);
562
563       queries[id].value = 1;
564       localLogLikelihood += computeLocalLogLikelihoodHelper(0,
565           numParent,
566           adtreePtr,
567           queries,
568           queryVectorPtr,
569           parentQueryVectorPtr);
570
571       queries[id].value = QUERY_VALUE_WILDCARD;
572
573       return localLogLikelihood;
574     }
575
576
577   /* =============================================================================
578    * TMfindBestInsertTask
579    * =============================================================================
580    */
581   public LearnerTask
582     TMfindBestInsertTask (FindBestTaskArg argPtr)
583     {
584       int       toId                      = argPtr.toId;
585       Learner   learnerPtr                = argPtr.learnerPtr;
586       Query[]   queries                   = argPtr.queries;
587       Vector_t  queryVectorPtr            = argPtr.queryVectorPtr;
588       Vector_t  parentQueryVectorPtr      = argPtr.parentQueryVectorPtr;
589       int       numTotalParent            = argPtr.numTotalParent;
590       float     basePenalty               = argPtr.basePenalty;
591       float     baseLogLikelihood         = argPtr.baseLogLikelihood;
592       BitMap    invalidBitmapPtr          = argPtr.bitmapPtr;
593       Queue     workQueuePtr              = argPtr.workQueuePtr;
594       Vector_t  baseParentQueryVectorPtr  = argPtr.aQueryVectorPtr;
595       Vector_t  baseQueryVectorPtr        = argPtr.bQueryVectorPtr;
596
597       boolean status;
598       Adtree adtreePtr               = learnerPtr.adtreePtr;
599       Net    netPtr                  = learnerPtr.netPtr;
600
601       TMpopulateParentQueryVector(netPtr, toId, queries, parentQueryVectorPtr);
602
603       /*
604        * Create base query and parentQuery
605        */
606
607       status = Vector_t.vector_copy(baseParentQueryVectorPtr, parentQueryVectorPtr);
608
609       status = Vector_t.vector_copy(baseQueryVectorPtr, baseParentQueryVectorPtr);
610
611       status = baseQueryVectorPtr.vector_pushBack(queries[toId]);
612
613       queryVectorPtr.vector_sort();
614
615       /*
616        * Search all possible valid operations for better local log likelihood
617        */
618
619       int bestFromId = toId; // flag for not found 
620       float oldLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
621       float bestLocalLogLikelihood = oldLocalLogLikelihood;
622
623       status = netPtr.net_findDescendants(toId, invalidBitmapPtr, workQueuePtr);
624
625       int fromId = -1;
626
627       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(toId);
628
629       int maxNumEdgeLearned = global_maxNumEdgeLearned;
630
631       if ((maxNumEdgeLearned < 0) ||
632           (parentIdListPtr.list_getSize() <= maxNumEdgeLearned))
633       {
634
635         IntListNode it = parentIdListPtr.head;
636
637         while(it.nextPtr!=null) {
638           it = it.nextPtr;
639           int parentId = it.dataPtr;
640           invalidBitmapPtr.bitmap_set(parentId); // invalid since already have edge 
641         }
642
643         while ((fromId = invalidBitmapPtr.bitmap_findClear((fromId + 1))) >= 0) {
644
645           if (fromId == toId) {
646             continue;
647           }
648
649           status = Vector_t.vector_copy(queryVectorPtr, baseQueryVectorPtr);
650
651           status = queryVectorPtr.vector_pushBack(queries[fromId]);
652
653           queryVectorPtr.vector_sort();
654
655           status = Vector_t.vector_copy(parentQueryVectorPtr, baseParentQueryVectorPtr);
656           status = parentQueryVectorPtr.vector_pushBack(queries[fromId]);
657
658           parentQueryVectorPtr.vector_sort();
659
660           float newLocalLogLikelihood =
661             computeLocalLogLikelihood(toId,
662                 adtreePtr,
663                 netPtr,
664                 queries,
665                 queryVectorPtr,
666                 parentQueryVectorPtr);
667
668           if (newLocalLogLikelihood > bestLocalLogLikelihood) {
669             bestLocalLogLikelihood = newLocalLogLikelihood;
670             bestFromId = fromId;
671           }
672
673         } // foreach valid parent 
674
675       } // if have not exceeded max number of edges to learn 
676
677       /*
678        * Return best task; Note: if none is better, fromId will equal toId
679        */
680
681       LearnerTask bestTask = new LearnerTask();
682       bestTask.op     = OPERATION_INSERT;
683       bestTask.fromId = bestFromId;
684       bestTask.toId   = toId;
685       bestTask.score  = 0.0f;
686
687       if (bestFromId != toId) {
688         int numRecord = adtreePtr.numRecord;
689         int numParent = parentIdListPtr.list_getSize() + 1;
690         float penalty =
691           (numTotalParent + numParent * global_insertPenalty) * basePenalty;
692         float logLikelihood = numRecord * (baseLogLikelihood +
693             + bestLocalLogLikelihood
694             - oldLocalLogLikelihood);
695         float bestScore = penalty + logLikelihood;
696         bestTask.score  = bestScore;
697       }
698
699       return bestTask;
700     }
701
702 #ifdef LEARNER_TRY_REMOVE
703   /* =============================================================================
704    * TMfindBestRemoveTask
705    * =============================================================================
706    */
707   public LearnerTask
708     TMfindBestRemoveTask (FindBestTaskArg argPtr)
709     {
710       int       toId                     = argPtr.toId;
711       Learner   learnerPtr               = argPtr.learnerPtr;
712       Query[]   queries                  = argPtr.queries;
713       Vector_t  queryVectorPtr           = argPtr.queryVectorPtr;
714       Vector_t  parentQueryVectorPtr     = argPtr.parentQueryVectorPtr;
715       int       numTotalParent           = argPtr.numTotalParent;
716       float     basePenalty              = argPtr.basePenalty;
717       float     baseLogLikelihood        = argPtr.baseLogLikelihood;
718       Vector_t  origParentQueryVectorPtr = argPtr.aQueryVectorPtr;
719
720       boolean status;
721       Adtree adtreePtr = learnerPtr.adtreePtr;
722       Net netPtr = learnerPtr.netPtr;
723       float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
724
725       TMpopulateParentQueryVector(netPtr, toId, queries, origParentQueryVectorPtr);
726       int numParent = origParentQueryVectorPtr.vector_getSize();
727
728       /*
729        * Search all possible valid operations for better local log likelihood
730        */
731
732       int bestFromId = toId; // flag for not found 
733       float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
734       float bestLocalLogLikelihood = oldLocalLogLikelihood;
735
736       int i;
737       for (i = 0; i < numParent; i++) {
738
739         Query queryPtr = (Query) (origParentQueryVectorPtr.vector_at(i));
740         int fromId = queryPtr.index;
741
742         /*
743          * Create parent query (subset of parents since remove an edge)
744          */
745
746         parentQueryVectorPtr.vector_clear();
747
748         for (int p = 0; p < numParent; p++) {
749           if (p != fromId) {
750             Query tmpqueryPtr = (Query) (origParentQueryVectorPtr.vector_at(p));
751             status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
752           }
753         } // create new parent query 
754
755         /*
756          * Create query
757          */
758
759         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
760         status = queryVectorPtr.vector_pushBack(queries[toId]);
761         queryVectorPtr.vector_sort();
762
763         /*
764          * See if removing parent is better
765          */
766
767         float newLocalLogLikelihood =
768           computeLocalLogLikelihood(toId,
769               adtreePtr,
770               netPtr,
771               queries,
772               queryVectorPtr,
773               parentQueryVectorPtr);
774
775         if (newLocalLogLikelihood > bestLocalLogLikelihood) {
776           bestLocalLogLikelihood = newLocalLogLikelihood;
777           bestFromId = fromId;
778         }
779
780       } // for each parent 
781
782       /*
783        * Return best task; Note: if none is better, fromId will equal toId
784        */
785
786       LearnerTask bestTask = new LearnerTask();
787       bestTask.op     = OPERATION_REMOVE;
788       bestTask.fromId = bestFromId;
789       bestTask.toId   = toId;
790       bestTask.score  = 0.0f;
791
792       if (bestFromId != toId) {
793         int numRecord = adtreePtr.numRecord;
794         float penalty = (numTotalParent - 1) * basePenalty;
795         float logLikelihood = numRecord * (baseLogLikelihood +
796             + bestLocalLogLikelihood
797             - oldLocalLogLikelihood);
798         float bestScore = penalty + logLikelihood;
799         bestTask.score  = bestScore;
800       }
801
802       return bestTask;
803     }
804 #endif /* LEARNER_TRY_REMOVE */
805
806
807 #ifdef LEARNER_TRY_REVERSE
808   /* =============================================================================
809    * TMfindBestReverseTask
810    * =============================================================================
811    */
812   public LearnerTask
813     TMfindBestReverseTask (FindBestTaskArg argPtr)
814     {
815       int       toId                         = argPtr.toId;
816       Learner   learnerPtr                   = argPtr.learnerPtr;
817       Query[]   queries                      = argPtr.queries;
818       Vector_t  queryVectorPtr               = argPtr.queryVectorPtr;
819       Vector_t  parentQueryVectorPtr         = argPtr.parentQueryVectorPtr;
820       int       numTotalParent               = argPtr.numTotalParent;
821       float     basePenalty                  = argPtr.basePenalty;
822       float     baseLogLikelihood            = argPtr.baseLogLikelihood;
823       BitMap    visitedBitmapPtr             = argPtr.bitmapPtr;
824       Queue     workQueuePtr                 = argPtr.workQueuePtr;
825       Vector_t  toOrigParentQueryVectorPtr   = argPtr.aQueryVectorPtr;
826       Vector_t  fromOrigParentQueryVectorPtr = argPtr.bQueryVectorPtr;
827
828       boolean status;
829       Adtree adtreePtr = learnerPtr.adtreePtr;
830       Net netPtr = learnerPtr.netPtr;
831       float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
832
833       TMpopulateParentQueryVector(netPtr, toId, queries, toOrigParentQueryVectorPtr);
834       int numParent = toOrigParentQueryVectorPtr.vector_getSize();
835
836       /*
837        * Search all possible valid operations for better local log likelihood
838        */
839
840       int bestFromId = toId; // flag for not found 
841       float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
842       float bestLocalLogLikelihood = oldLocalLogLikelihood;
843       int fromId = 0;
844
845       for (int i = 0; i < numParent; i++) {
846
847         Query queryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(i));
848         fromId = queryPtr.index;
849
850         bestLocalLogLikelihood =
851           oldLocalLogLikelihood + localBaseLogLikelihoods[fromId];
852
853         TMpopulateParentQueryVector(netPtr,
854             fromId,
855             queries,
856             fromOrigParentQueryVectorPtr);
857
858         /*
859          * Create parent query (subset of parents since remove an edge)
860          */
861
862         parentQueryVectorPtr.vector_clear();
863
864         for (int p = 0; p < numParent; p++) {
865           if (p != fromId) {
866             Query tmpqueryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(p));
867             status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
868           }
869         } // create new parent query 
870
871         /*
872          * Create query
873          */
874
875         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
876         status = queryVectorPtr.vector_pushBack(queries[toId]);
877
878         queryVectorPtr.vector_sort();
879
880         /*
881          * Get log likelihood for removing parent from toId
882          */
883
884         float newLocalLogLikelihood =
885           computeLocalLogLikelihood(toId,
886               adtreePtr,
887               netPtr,
888               queries,
889               queryVectorPtr,
890               parentQueryVectorPtr);
891
892         /*
893          * Get log likelihood for adding parent to fromId
894          */
895
896         status = Vector_t.vector_copy(parentQueryVectorPtr, fromOrigParentQueryVectorPtr);
897         status = parentQueryVectorPtr.vector_pushBack(queries[toId]);
898
899         parentQueryVectorPtr.vector_sort();
900
901         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
902
903         status = queryVectorPtr.vector_pushBack(queries[fromId]);
904
905         queryVectorPtr.vector_sort();
906
907         newLocalLogLikelihood +=
908           computeLocalLogLikelihood(fromId,
909               adtreePtr,
910               netPtr,
911               queries,
912               queryVectorPtr,
913               parentQueryVectorPtr);
914
915         /*
916          * Record best
917          */
918
919         if (newLocalLogLikelihood > bestLocalLogLikelihood) {
920           bestLocalLogLikelihood = newLocalLogLikelihood;
921           bestFromId = fromId;
922         }
923
924       } // for each parent 
925
926       /*
927        * Check validity of best
928        */
929
930       if (bestFromId != toId) {
931         boolean isTaskValid = true;
932         netPtr.net_applyOperation(OPERATION_REMOVE, bestFromId, toId);
933         if (netPtr.net_isPath(bestFromId,
934               toId,
935               visitedBitmapPtr,
936               workQueuePtr))
937         {
938           isTaskValid = false;
939         }
940         netPtr.net_applyOperation(OPERATION_INSERT, bestFromId, toId);
941         if (!isTaskValid) {
942           bestFromId = toId;
943         }
944       }
945
946       /*
947        * Return best task; Note: if none is better, fromId will equal toId
948        */
949
950       LearnerTask bestTask = new LearnerTask();
951       bestTask.op     = OPERATION_REVERSE;
952       bestTask.fromId = bestFromId;
953       bestTask.toId   = toId;
954       bestTask.score  = 0.0f;
955
956       if (bestFromId != toId) {
957         float fromLocalLogLikelihood = localBaseLogLikelihoods[bestFromId];
958         int numRecord = adtreePtr.numRecord;
959         float penalty = numTotalParent * basePenalty;
960         float logLikelihood = numRecord * (baseLogLikelihood +
961             + bestLocalLogLikelihood
962             - oldLocalLogLikelihood
963             - fromLocalLogLikelihood);
964         float bestScore = penalty + logLikelihood;
965         bestTask.score  = bestScore;
966       }
967
968       return bestTask;
969     }
970
971 #endif /* LEARNER_TRY_REVERSE */
972
973
974   /* =============================================================================
975    * learnStructure
976    *
977    * Note it is okay if the score is not exact, as we are relaxing the greedy
978    * search. This means we do not need to communicate baseLogLikelihood across
979    * threads.
980    * =============================================================================
981    */
982   public static void
983     learnStructure (int myId, int numThread, Learner learnerPtr)
984     {
985
986       int numRecord = learnerPtr.adtreePtr.numRecord;
987
988       float operationQualityFactor = learnerPtr.global_operationQualityFactor;
989
990       BitMap visitedBitmapPtr = BitMap.bitmap_alloc(learnerPtr.adtreePtr.numVar);
991
992       Queue workQueuePtr = Queue.queue_alloc(-1);
993
994       int numVar = learnerPtr.adtreePtr.numVar;
995       Query[] queries = new Query[numVar];
996
997       for (int v = 0; v < numVar; v++) {
998         queries[v] = new Query();
999         queries[v].index = v;
1000         queries[v].value = QUERY_VALUE_WILDCARD;
1001       }
1002
1003       float basePenalty = (float)(-0.5 * Math.log((double)numRecord));
1004
1005       Vector_t queryVectorPtr = new Vector_t(1);
1006       Vector_t parentQueryVectorPtr = new Vector_t(1);
1007       Vector_t aQueryVectorPtr = new Vector_t(1);
1008       Vector_t bQueryVectorPtr = new Vector_t(1);
1009
1010       FindBestTaskArg arg = new FindBestTaskArg();
1011       arg.learnerPtr           = learnerPtr;
1012       arg.queries              = queries;
1013       arg.queryVectorPtr       = queryVectorPtr;
1014       arg.parentQueryVectorPtr = parentQueryVectorPtr;
1015       arg.bitmapPtr            = visitedBitmapPtr;
1016       arg.workQueuePtr         = workQueuePtr;
1017       arg.aQueryVectorPtr      = aQueryVectorPtr;
1018       arg.bQueryVectorPtr      = bQueryVectorPtr;
1019
1020       while (true) {
1021
1022         LearnerTask taskPtr;
1023
1024         atomic {
1025           taskPtr = learnerPtr.TMpopTask(learnerPtr.taskListPtr);
1026         }
1027
1028         if (taskPtr == null) {
1029           break;
1030         }
1031
1032         int op = taskPtr.op;
1033         int fromId = taskPtr.fromId;
1034         int toId = taskPtr.toId;
1035
1036         boolean isTaskValid;
1037
1038         atomic {
1039           /*
1040            * Check if task is still valid
1041            */
1042
1043           isTaskValid = true;
1044           if(op == OPERATION_INSERT) {
1045             if(learnerPtr.netPtr.net_hasEdge(fromId, toId) || 
1046                 learnerPtr.netPtr.net_isPath(toId,
1047                   fromId,
1048                   visitedBitmapPtr,
1049                   workQueuePtr))
1050             {
1051               isTaskValid = false;
1052             }
1053           } else if (op == OPERATION_REMOVE) {
1054             // Can never create cycle, so always valid
1055             ;
1056           } else if (op == OPERATION_REVERSE) {
1057             // Temporarily remove edge for check
1058             learnerPtr.netPtr.net_applyOperation(OPERATION_REMOVE, fromId, toId);
1059             if(learnerPtr.netPtr.net_isPath(fromId,
1060                   toId,
1061                   visitedBitmapPtr,
1062                   workQueuePtr))
1063             {
1064               isTaskValid = false;
1065             }
1066             learnerPtr.netPtr.net_applyOperation(OPERATION_INSERT, fromId, toId);
1067           }
1068
1069
1070 #ifdef TEST_LEARNER
1071           System.out.println("[task] op= " + taskPtr.op + " from= " + taskPtr.fromId + " to= " + 
1072               taskPtr.toId + " score= " + taskPtr.score + " valid= " + (isTaskValid ? "yes" : "no"));
1073 #endif
1074
1075           /*
1076            * Perform task: update graph and probabilities
1077            */
1078
1079           if (isTaskValid) {
1080             learnerPtr.netPtr.net_applyOperation(op, fromId, toId);
1081           }
1082
1083         }
1084
1085         float deltaLogLikelihood = 0.0f;
1086
1087         if (isTaskValid) {
1088           float newBaseLogLikelihood;
1089           if(op == OPERATION_INSERT) {
1090             atomic {
1091               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1092                                                 toId,
1093                                                 queries,
1094                                                 queryVectorPtr,
1095                                                 parentQueryVectorPtr);
1096               newBaseLogLikelihood =
1097                 learnerPtr.computeLocalLogLikelihood(toId,
1098                                                 learnerPtr.adtreePtr,
1099                                                 learnerPtr.netPtr,
1100                                                 queries,
1101                                                 queryVectorPtr,
1102                                                 parentQueryVectorPtr);
1103               float toLocalBaseLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
1104               deltaLogLikelihood +=
1105                 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1106               learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1107             }
1108
1109             atomic {
1110               int numTotalParent = learnerPtr.numTotalParent;
1111               learnerPtr.numTotalParent = numTotalParent + 1;
1112             }
1113
1114 #ifdef LEARNER_TRY_REMOVE
1115           } else if(op == OPERATION_REMOVE) {
1116             atomic {
1117               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1118                                                 fromId,
1119                                                 queries,
1120                                                 queryVectorPtr,
1121                                                 parentQueryVectorPtr);
1122               newBaseLogLikelihood =
1123                 learnerPtr. computeLocalLogLikelihood(fromId,
1124                                               learnerPtr.adtreePtr,
1125                                               learnerPtr.netPtr,
1126                                               queries,
1127                                               queryVectorPtr, 
1128                                               parentQueryVectorPtr);
1129               float fromLocalBaseLogLikelihood =
1130                     learnerPtr.localBaseLogLikelihoods[fromId];
1131               deltaLogLikelihood +=
1132                     fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1133               learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1134             }
1135
1136             atomic{ 
1137               int numTotalParent = learnerPtr.numTotalParent;
1138               learnerPtr.numTotalParent = numTotalParent - 1;
1139             }
1140
1141 #endif // LEARNER_TRY_REMOVE
1142 #ifdef LEARNER_TRY_REVERSE
1143           } else if(op == OPERATION_REVERSE) {
1144             atomic {
1145               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1146                                                 fromId,
1147                                                 queries,
1148                                                 queryVectorPtr,
1149                                                 parentQueryVectorPtr);
1150               newBaseLogLikelihood =
1151                 learnerPtr.computeLocalLogLikelihood(fromId,
1152                                         learnerPtr.adtreePtr,
1153                                         learnerPtr.netPtr,
1154                                         queries,
1155                                         queryVectorPtr,
1156                                         parentQueryVectorPtr);
1157               float fromLocalBaseLogLikelihood =
1158                           learnerPtr.localBaseLogLikelihoods[fromId];
1159               deltaLogLikelihood +=
1160                 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1161               learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1162             }
1163
1164             atomic {
1165               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1166                                                 toId,
1167                                                 queries,
1168                                                 queryVectorPtr,
1169                                                 parentQueryVectorPtr);
1170               newBaseLogLikelihood =
1171                 learnerPtr.computeLocalLogLikelihood(toId,
1172                                         learnerPtr.adtreePtr,
1173                                         learnerPtr.netPtr,
1174                                         queries,
1175                                         queryVectorPtr,
1176                                         parentQueryVectorPtr);
1177               float toLocalBaseLogLikelihood =
1178                         learnerPtr.localBaseLogLikelihoods[toId];
1179               deltaLogLikelihood +=
1180                 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1181               learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1182             }
1183
1184 #endif // LEARNER_TRY_REVERSE 
1185           }
1186
1187         } //if isTaskValid
1188
1189         /*
1190          * Update/read globals
1191          */
1192
1193         float baseLogLikelihood;
1194         int numTotalParent;
1195
1196         atomic {
1197           float oldBaseLogLikelihood = learnerPtr.baseLogLikelihood;
1198           float newBaseLogLikelihood = oldBaseLogLikelihood + deltaLogLikelihood;
1199           learnerPtr.baseLogLikelihood = newBaseLogLikelihood;
1200           baseLogLikelihood = newBaseLogLikelihood;
1201           numTotalParent = learnerPtr.numTotalParent;
1202         }
1203
1204         /*
1205          * Find next task
1206          */
1207
1208
1209         float baseScore = ((float)numTotalParent * basePenalty)
1210           + (numRecord * baseLogLikelihood);
1211
1212         LearnerTask bestTask = new LearnerTask();
1213         bestTask.op     = NUM_OPERATION;
1214         bestTask.toId   = -1;
1215         bestTask.fromId = -1;
1216         bestTask.score  = baseScore;
1217
1218         LearnerTask newTask = new LearnerTask();
1219
1220         arg.toId              = toId;
1221         arg.numTotalParent    = numTotalParent;
1222         arg.basePenalty       = basePenalty;
1223         arg.baseLogLikelihood = baseLogLikelihood;
1224
1225         atomic {
1226           newTask = learnerPtr.TMfindBestInsertTask(arg);
1227         }
1228
1229         if ((newTask.fromId != newTask.toId) &&
1230             (newTask.score > (bestTask.score / operationQualityFactor)))
1231         {
1232           bestTask = newTask;
1233         }
1234
1235 #ifdef LEARNER_TRY_REMOVE
1236         atomic {
1237           newTask = learnerPtr.TMfindBestRemoveTask(arg);
1238         }
1239
1240         if ((newTask.fromId != newTask.toId) &&
1241             (newTask.score > (bestTask.score / operationQualityFactor)))
1242         {
1243           bestTask = newTask;
1244         }
1245 #endif // LEARNER_TRY_REMOVE 
1246
1247 #ifdef LEARNER_TRY_REVERSE
1248         atomic {
1249           newTask = learnerPtr.TMfindBestReverseTask(arg);
1250         }
1251
1252         if ((newTask.fromId != newTask.toId) &&
1253             (newTask.score > (bestTask.score / operationQualityFactor)))
1254         {
1255           bestTask = newTask;
1256         }
1257 #endif // LEARNER_TRY_REVERSE 
1258
1259         if (bestTask.toId != -1) {
1260           LearnerTask[] tasks = learnerPtr.tasks;
1261           tasks[toId] = bestTask;
1262           atomic {
1263             learnerPtr.taskListPtr.list_insert(tasks[toId]);
1264           }
1265 #ifdef TEST_LEARNER
1266           System.out.println("[new]  op= " + bestTask.op + " from= "+ bestTask.fromId + " to= "+ bestTask.toId + 
1267               " score= " + bestTask.score);
1268 #endif
1269         }
1270
1271       } // while (tasks) 
1272
1273       visitedBitmapPtr.bitmap_free();
1274       workQueuePtr.queue_free();
1275       bQueryVectorPtr.clear();
1276       aQueryVectorPtr.clear();
1277       queryVectorPtr.clear();
1278       parentQueryVectorPtr.clear();
1279       queries = null;
1280     }
1281
1282
1283   /* =============================================================================
1284    * learner_run
1285    * -- Call adtree_make before this
1286    * =============================================================================
1287    */
1288   //Is not called anywhere now parallel code
1289   public void
1290     learner_run (int myId, int numThread, Learner learnerPtr)
1291     {
1292       {
1293         createTaskList(myId, numThread, learnerPtr);
1294       }
1295       {
1296         learnStructure(myId, numThread, learnerPtr);
1297       }
1298     }
1299
1300   /* =============================================================================
1301    * learner_score
1302    * -- Score entire network
1303    * =============================================================================
1304    */
1305   public float
1306     learner_score ()
1307     {
1308
1309       Vector_t queryVectorPtr = new Vector_t(1);
1310       Vector_t parentQueryVectorPtr = new Vector_t(1);
1311
1312       int numVar = adtreePtr.numVar;
1313       Query[] queries = new Query[numVar];
1314
1315       for (int v = 0; v < numVar; v++) {
1316         queries[v] = new Query();
1317         queries[v].index = v;
1318         queries[v].value = QUERY_VALUE_WILDCARD;
1319       }
1320
1321       int numTotalParent = 0;
1322       float logLikelihood = 0.0f;
1323
1324       for (int v = 0; v < numVar; v++) {
1325
1326         IntList parentIdListPtr = netPtr.net_getParentIdListPtr(v);
1327         numTotalParent += parentIdListPtr.list_getSize();
1328
1329         populateQueryVectors(netPtr,
1330             v,
1331             queries,
1332             queryVectorPtr,
1333             parentQueryVectorPtr);
1334         float localLogLikelihood = computeLocalLogLikelihood(v,
1335             adtreePtr,
1336             netPtr,
1337             queries,
1338             queryVectorPtr,
1339             parentQueryVectorPtr);
1340         logLikelihood += localLogLikelihood;
1341       }
1342
1343       queryVectorPtr.clear();
1344       parentQueryVectorPtr.clear();
1345       queries = null;
1346
1347
1348       int numRecord = adtreePtr.numRecord;
1349       float penalty = (float)(-0.5f * (double)numTotalParent * Math.log((double)numRecord));
1350       float score = penalty + (float)numRecord * logLikelihood;
1351
1352       return score;
1353     }
1354 }
1355
1356 /* =============================================================================
1357  *
1358  * End of learner.java
1359  *
1360  * =============================================================================
1361  */