bug fixes
[IRC.git] / Robust / src / Benchmarks / SingleTM / Bayes / Learner.java
1 /* =============================================================================
2  *
3  * learner.java
4  * -- Learns structure of Bayesian net from data
5  *
6  * =============================================================================
7  *
8  * Copyright (C) Stanford University, 2006.  All Rights Reserved.
9  * Author: Chi Cao Minh
10  * Ported to Java June 2009 Alokika Dash
11  * University of California, Irvine
12  *
13  *
14  * =============================================================================
15  *
16  * The penalized log-likelihood score (Friedman & Yahkani, 1996) is used to
17  * evaluated the "goodness" of a Bayesian net:
18  *
19  *                             M      n_j
20  *                            --- --- ---
21  *  -N_params * ln(R) / 2 + R >   >   >   P((a_j = v), X_j) ln P(a_j = v | X_j)
22  *                            --- --- ---
23  *                            j=1 X_j v=1
24  *
25  * Where:
26  *
27  *     N_params     total number of parents across all variables
28  *     R            number of records
29  *     M            number of variables
30  *     X_j          parents of the jth variable
31  *     n_j          number of attributes of the jth variable
32  *     a_j          attribute
33  *
34  * The second summation of X_j varies across all possible assignments to the
35  * values of the parents X_j.
36  *
37  * In the code:
38  *
39  *    "local log likelihood" is  P((a_j = v), X_j) ln P(a_j = v | X_j)
40  *    "log likelihood" is everything to the right of the '+', i.e., "R ... X_j)"
41  *    "base penalty" is -ln(R) / 2
42  *    "penalty" is N_params * -ln(R) / 2
43  *    "score" is the entire expression
44  *
45  * For more notes, refer to:
46  *
47  * A. Moore and M.-S. Lee. Cached sufficient statistics for efficient machine
48  * learning with large datasets. Journal of Artificial Intelligence Research 8
49  * (1998), pp 67-91.
50  *
51  * =============================================================================
52  *
53  * The search strategy uses a combination of local and global structure search.
54  * Similar to the technique described in:
55  *
56  * D. M. Chickering, D. Heckerman, and C. Meek.  A Bayesian approach to learning
57  * Bayesian networks with local structure. In Proceedings of Thirteenth
58  * Conference on Uncertainty in Artificial Intelligence (1997), pp. 80-89.
59  *
60  * =============================================================================
61  *
62  * For the license of bayes/sort.h and bayes/sort.c, please see the header
63  * of the files.
64  * 
65  * ------------------------------------------------------------------------
66  *
67  * Unless otherwise noted, the following license applies to STAMP files:
68  * 
69  * Copyright (c) 2007, Stanford University
70  * All rights reserved.
71  * 
72  * Redistribution and use in source and binary forms, with or without
73  * modification, are permitted provided that the following conditions are
74  * met:
75
76 *     * Redistributions of source code must retain the above copyright
77 *       notice, this list of conditions and the following disclaimer.
78
79 *     * Redistributions in binary form must reproduce the above copyright
80 *       notice, this list of conditions and the following disclaimer in
81 *       the documentation and/or other materials provided with the
82 *       distribution.
83
84 *     * Neither the name of Stanford University nor the names of its
85 *       contributors may be used to endorse or promote products derived
86 *       from this software without specific prior written permission.
87
88 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
89 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
90 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
91 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
92 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
93 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
94 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
95 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
96 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
97 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
98 * THE POSSIBILITY OF SUCH DAMAGE.
99 *
100 * =============================================================================
101 */
102
103
104 #define CACHE_LINE_SIZE 64
105 #define QUERY_VALUE_WILDCARD -1
106 #define OPERATION_INSERT      0
107 #define OPERATION_REMOVE      1
108 #define OPERATION_REVERSE     2
109 #define NUM_OPERATION         3
110
111 public class Learner {
112   Adtree adtreePtr;
113   Net netPtr;
114   float[] localBaseLogLikelihoods;
115   float baseLogLikelihood;
116   LearnerTask[] tasks;
117   List taskListPtr;
118   int numTotalParent;
119   int global_insertPenalty;
120   int global_maxNumEdgeLearned;
121   float global_operationQualityFactor;
122
123   public Learner() {
124 #ifdef TEST_LEARNER
125     global_maxNumEdgeLearned = -1;
126     global_insertPenalty = 1;
127     global_operationQualityFactor = 1.0F;
128 #endif
129   }
130
131   /* =============================================================================
132    * learner_alloc
133    * =============================================================================
134    */
135   public Learner(Data dataPtr, 
136                  Adtree adtreePtr, 
137                  int numThread, 
138                  int global_insertPenalty,
139                  int global_maxNumEdgeLearned,
140                  float global_operationQualityFactor) {
141     this.adtreePtr = adtreePtr;
142     this.netPtr = new Net(dataPtr.numVar);
143     this.localBaseLogLikelihoods = new float[dataPtr.numVar];
144     this.baseLogLikelihood = 0.0f;
145     this.tasks = new LearnerTask[dataPtr.numVar];
146     this.taskListPtr = List.list_alloc();
147     this.numTotalParent = 0;
148 #ifndef TEST_LEARNER
149     this.global_insertPenalty = global_insertPenalty;
150     this.global_maxNumEdgeLearned = global_maxNumEdgeLearned;
151     this.global_operationQualityFactor = global_operationQualityFactor;
152 #endif
153   }
154
155   public void learner_free() {
156     adtreePtr=null;
157     netPtr=null;
158     localBaseLogLikelihoods=null;
159     tasks=null;
160     taskListPtr=null;
161   }
162
163
164   /* =============================================================================
165    * computeSpecificLocalLogLikelihood
166    * -- Query vectors should not contain wildcards
167    * =============================================================================
168    */
169   public float
170     computeSpecificLocalLogLikelihood (Adtree adtreePtr,
171         Vector_t queryVectorPtr,
172         Vector_t parentQueryVectorPtr)
173     {
174       int count = adtreePtr.adtree_getCount(queryVectorPtr);
175       if (count == 0) {
176         return 0.0f;
177       }
178
179       double probability = (double)count / (double)adtreePtr.numRecord;
180       int parentCount = adtreePtr.adtree_getCount(parentQueryVectorPtr);
181
182
183       float fval = (float)(probability * (Math.log((double)count/ (double)parentCount)));
184
185       return fval;
186     }
187
188
189   /* =============================================================================
190    * createPartition
191    * =============================================================================
192    */
193   public void
194     createPartition (int min, int max, int id, int n, LocalStartStop lss)
195     {
196       int range = max - min;
197       int chunk = Math.imax(1, ((range + n/2) / n)); // rounded 
198       int start = min + chunk * id;
199       int stop;
200       if (id == (n-1)) {
201         stop = max;
202       } else {
203         stop = Math.imin(max, (start + chunk));
204       }
205
206       lss.i_start = start;
207       lss.i_stop = stop;
208     }
209
210   /* =============================================================================
211    * createTaskList
212    * -- baseLogLikelihoods and taskListPtr are updated
213    * =============================================================================
214    */
215   public static void
216     createTaskList (int myId, int numThread, Learner learnerPtr)
217     {
218       boolean status;
219
220       Query[] queries = new Query[2];
221       queries[0] = new Query();
222       queries[1] = new Query();
223
224       Vector_t queryVectorPtr = new Vector_t(2);
225
226       status = queryVectorPtr.vector_pushBack(queries[0]);
227
228       Query parentQuery = new Query();
229       Vector_t parentQueryVectorPtr = new Vector_t(1); 
230
231       int numVar = learnerPtr.adtreePtr.numVar;
232       int numRecord = learnerPtr.adtreePtr.numRecord;
233       float baseLogLikelihood = 0.0f;
234       float penalty = (float)(-0.5f * Math.log((double)numRecord)); // only add 1 edge 
235
236       LocalStartStop lss = new LocalStartStop();
237       learnerPtr.createPartition(0, numVar, myId, numThread, lss);
238
239       /*
240        * Compute base log likelihood for each variable and total base loglikelihood
241        */
242
243       for (int v = lss.i_start; v < lss.i_stop; v++) {
244
245         float localBaseLogLikelihood = 0.0f;
246         queries[0].index = v;
247
248         queries[0].value = 0;
249         localBaseLogLikelihood +=
250           learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
251               queryVectorPtr,
252               parentQueryVectorPtr);
253
254         queries[0].value = 1;
255         localBaseLogLikelihood +=
256           learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
257               queryVectorPtr,
258               parentQueryVectorPtr);
259
260         learnerPtr.localBaseLogLikelihoods[v] = localBaseLogLikelihood;
261         baseLogLikelihood += localBaseLogLikelihood;
262
263       } // for each variable 
264
265       atomic {
266         float globalBaseLogLikelihood =
267           learnerPtr.baseLogLikelihood;
268         learnerPtr.baseLogLikelihood = (baseLogLikelihood + globalBaseLogLikelihood);
269       }
270
271       /*
272        * For each variable, find if the addition of any edge _to_ it is better
273        */
274
275       status = parentQueryVectorPtr.vector_pushBack(parentQuery);
276
277       for (int v = lss.i_start; v < lss.i_stop; v++) {
278
279          //Compute base log likelihood for this variable
280
281         queries[0].index = v;
282         int bestLocalIndex = v;
283         float bestLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[v];
284
285         status = queryVectorPtr.vector_pushBack(queries[1]);
286
287         for (int vv = 0; vv < numVar; vv++) {
288
289           if (vv == v) {
290             continue;
291           }
292           parentQuery.index = vv;
293           if (v < vv) {
294             queries[0].index = v;
295             queries[1].index = vv;
296           } else {
297             queries[0].index = vv;
298             queries[1].index = v;
299           }
300
301           float newLocalLogLikelihood = 0.0f;
302
303           queries[0].value = 0;
304           queries[1].value = 0;
305           parentQuery.value = 0;
306           newLocalLogLikelihood +=
307             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
308                 queryVectorPtr,
309                 parentQueryVectorPtr);
310
311           queries[0].value = 0;
312           queries[1].value = 1;
313           parentQuery.value = ((vv < v) ? 0 : 1);
314           newLocalLogLikelihood +=
315             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
316                 queryVectorPtr,
317                 parentQueryVectorPtr);
318
319           queries[0].value = 1;
320           queries[1].value = 0;
321           parentQuery.value = ((vv < v) ? 1 : 0);
322           newLocalLogLikelihood +=
323             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
324                 queryVectorPtr,
325                 parentQueryVectorPtr);
326
327           queries[0].value = 1;
328           queries[1].value = 1;
329           parentQuery.value = 1;
330           newLocalLogLikelihood +=
331             learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
332                 queryVectorPtr,
333                 parentQueryVectorPtr);
334
335           if (newLocalLogLikelihood > bestLocalLogLikelihood) {
336             bestLocalIndex = vv;
337             bestLocalLogLikelihood = newLocalLogLikelihood;
338           }
339
340         } // foreach other variable 
341
342         queryVectorPtr.vector_popBack();
343
344         if (bestLocalIndex != v) {
345           float logLikelihood = numRecord * (baseLogLikelihood +
346               + bestLocalLogLikelihood
347               - learnerPtr.localBaseLogLikelihoods[v]);
348           float score = penalty + logLikelihood;
349
350           learnerPtr.tasks[v] = new LearnerTask();
351           LearnerTask taskPtr = learnerPtr.tasks[v];
352           taskPtr.op = OPERATION_INSERT;
353           taskPtr.fromId = bestLocalIndex;
354           taskPtr.toId = v;
355           taskPtr.score = score;
356           atomic {
357             status = learnerPtr.taskListPtr.list_insert(taskPtr);
358           }
359
360         }
361
362       } // for each variable 
363
364
365       queryVectorPtr.clear();
366       parentQueryVectorPtr.clear();
367
368 #ifdef TEST_LEARNER
369       ListNode it = learnerPtr.taskListPtr.head;
370
371      while (it.nextPtr!=null) {
372         it = it.nextPtr;
373         LearnerTask taskPtr = it.dataPtr;
374         System.out.println("[task] op= "+ taskPtr.op +" from= "+taskPtr.fromId+" to= " +taskPtr.toId+
375            " score= " + taskPtr.score);
376       }
377 #endif // TEST_LEARNER 
378
379     }
380
381   /* =============================================================================
382    * TMpopTask
383    * -- Returns null is list is empty
384    * =============================================================================
385    */
386   public LearnerTask TMpopTask (List taskListPtr)
387     {
388       LearnerTask taskPtr = null;
389
390       ListNode it = taskListPtr.head;
391       if (it.nextPtr!=null) {
392         it = it.nextPtr;
393         taskPtr = it.dataPtr;
394         boolean status = taskListPtr.list_remove(taskPtr);
395       }
396
397       return taskPtr;
398     }
399
400
401   /* =============================================================================
402    * populateParentQuery
403    * -- Modifies contents of parentQueryVectorPtr
404    * =============================================================================
405    */
406   public void
407     populateParentQueryVector (Net netPtr,
408         int id,
409         Query[] queries,
410         Vector_t parentQueryVectorPtr)
411     {
412       parentQueryVectorPtr.vector_clear();
413
414       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
415       IntListNode it = parentIdListPtr.head;
416       while (it.nextPtr!=null) {
417         it = it.nextPtr;
418         int parentId = it.dataPtr;
419         boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
420       }
421     }
422
423
424   /* =============================================================================
425    * TMpopulateParentQuery
426    * -- Modifies contents of parentQueryVectorPtr
427    * =============================================================================
428    */
429   public void
430     TMpopulateParentQueryVector (Net netPtr,
431         int id,
432         Query[] queries,
433         Vector_t parentQueryVectorPtr)
434     {
435       parentQueryVectorPtr.vector_clear();
436
437       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
438       IntListNode it = parentIdListPtr.head;
439
440       while (it.nextPtr!=null) {
441         it = it.nextPtr;
442         int parentId = it.dataPtr;
443         boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
444       }
445     }
446
447
448   /* =============================================================================
449    * populateQueryVectors
450    * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
451    * =============================================================================
452    */
453   public void
454     populateQueryVectors (Net netPtr,
455         int id,
456         Query[] queries,
457         Vector_t queryVectorPtr,
458         Vector_t parentQueryVectorPtr)
459     {
460       populateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
461
462       boolean status;
463       status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
464       status = queryVectorPtr.vector_pushBack(queries[id]);
465
466
467       queryVectorPtr.vector_sort();
468     }
469
470
471   /* =============================================================================
472    * TMpopulateQueryVectors
473    * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
474    * =============================================================================
475    */
476   public void
477     TMpopulateQueryVectors (Net netPtr,
478         int id,
479         Query[] queries,
480         Vector_t queryVectorPtr,
481         Vector_t parentQueryVectorPtr)
482     {
483       TMpopulateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
484
485       boolean status;
486       status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
487       status = queryVectorPtr.vector_pushBack(queries[id]);
488
489       queryVectorPtr.vector_sort();
490     }
491
492   /* =============================================================================
493    * computeLocalLogLikelihoodHelper
494    * -- Recursive helper routine
495    * =============================================================================
496    */
497   public float
498     computeLocalLogLikelihoodHelper (int i,
499         int numParent,
500         Adtree adtreePtr,
501         Query[] queries,
502         Vector_t queryVectorPtr,
503         Vector_t parentQueryVectorPtr)
504     {
505       if (i >= numParent) {
506         return computeSpecificLocalLogLikelihood(adtreePtr,
507             queryVectorPtr,
508             parentQueryVectorPtr);
509       }
510
511       float localLogLikelihood = 0.0f;
512
513       Query parentQueryPtr = (Query) (parentQueryVectorPtr.vector_at(i));
514       int parentIndex = parentQueryPtr.index;
515
516       queries[parentIndex].value = 0;
517       localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
518           numParent,
519           adtreePtr,
520           queries,
521           queryVectorPtr,
522           parentQueryVectorPtr);
523
524       queries[parentIndex].value = 1;
525       localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
526           numParent,
527           adtreePtr,
528           queries,
529           queryVectorPtr,
530           parentQueryVectorPtr);
531
532       queries[parentIndex].value = QUERY_VALUE_WILDCARD;
533
534       return localLogLikelihood;
535     }
536
537
538   /* =============================================================================
539    * computeLocalLogLikelihood
540    * -- Populate the query vectors before passing as args
541    * =============================================================================
542    */
543   public float
544     computeLocalLogLikelihood (int id,
545         Adtree adtreePtr,
546         Net netPtr,
547         Query[] queries,
548         Vector_t queryVectorPtr,
549         Vector_t parentQueryVectorPtr)
550     {
551       int numParent = parentQueryVectorPtr.vector_getSize();
552       float localLogLikelihood = 0.0f;
553
554       queries[id].value = 0;
555       localLogLikelihood += computeLocalLogLikelihoodHelper(0,
556           numParent,
557           adtreePtr,
558           queries,
559           queryVectorPtr,
560           parentQueryVectorPtr);
561
562       queries[id].value = 1;
563       localLogLikelihood += computeLocalLogLikelihoodHelper(0,
564           numParent,
565           adtreePtr,
566           queries,
567           queryVectorPtr,
568           parentQueryVectorPtr);
569
570       queries[id].value = QUERY_VALUE_WILDCARD;
571
572       return localLogLikelihood;
573     }
574
575
576   /* =============================================================================
577    * TMfindBestInsertTask
578    * =============================================================================
579    */
580   public LearnerTask
581     TMfindBestInsertTask (FindBestTaskArg argPtr)
582     {
583       int       toId                      = argPtr.toId;
584       Learner   learnerPtr                = argPtr.learnerPtr;
585       Query[]   queries                   = argPtr.queries;
586       Vector_t  queryVectorPtr            = argPtr.queryVectorPtr;
587       Vector_t  parentQueryVectorPtr      = argPtr.parentQueryVectorPtr;
588       int       numTotalParent            = argPtr.numTotalParent;
589       float     basePenalty               = argPtr.basePenalty;
590       float     baseLogLikelihood         = argPtr.baseLogLikelihood;
591       BitMap    invalidBitmapPtr          = argPtr.bitmapPtr;
592       Queue     workQueuePtr              = argPtr.workQueuePtr;
593       Vector_t  baseParentQueryVectorPtr  = argPtr.aQueryVectorPtr;
594       Vector_t  baseQueryVectorPtr        = argPtr.bQueryVectorPtr;
595
596       boolean status;
597       Adtree adtreePtr               = learnerPtr.adtreePtr;
598       Net    netPtr                  = learnerPtr.netPtr;
599
600       TMpopulateParentQueryVector(netPtr, toId, queries, parentQueryVectorPtr);
601
602       /*
603        * Create base query and parentQuery
604        */
605
606       status = Vector_t.vector_copy(baseParentQueryVectorPtr, parentQueryVectorPtr);
607
608       status = Vector_t.vector_copy(baseQueryVectorPtr, baseParentQueryVectorPtr);
609
610       status = baseQueryVectorPtr.vector_pushBack(queries[toId]);
611
612       queryVectorPtr.vector_sort();
613
614       /*
615        * Search all possible valid operations for better local log likelihood
616        */
617
618       int bestFromId = toId; // flag for not found 
619       float oldLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
620       float bestLocalLogLikelihood = oldLocalLogLikelihood;
621
622       status = netPtr.net_findDescendants(toId, invalidBitmapPtr, workQueuePtr);
623
624       int fromId = -1;
625
626       IntList parentIdListPtr = netPtr.net_getParentIdListPtr(toId);
627
628       int maxNumEdgeLearned = global_maxNumEdgeLearned;
629
630       if ((maxNumEdgeLearned < 0) ||
631           (parentIdListPtr.list_getSize() <= maxNumEdgeLearned))
632       {
633
634         IntListNode it = parentIdListPtr.head;
635
636         while(it.nextPtr!=null) {
637           it = it.nextPtr;
638           int parentId = it.dataPtr;
639           invalidBitmapPtr.bitmap_set(parentId); // invalid since already have edge 
640         }
641
642         while ((fromId = invalidBitmapPtr.bitmap_findClear((fromId + 1))) >= 0) {
643
644           if (fromId == toId) {
645             continue;
646           }
647
648           status = Vector_t.vector_copy(queryVectorPtr, baseQueryVectorPtr);
649
650           status = queryVectorPtr.vector_pushBack(queries[fromId]);
651
652           queryVectorPtr.vector_sort();
653
654           status = Vector_t.vector_copy(parentQueryVectorPtr, baseParentQueryVectorPtr);
655           status = parentQueryVectorPtr.vector_pushBack(queries[fromId]);
656
657           parentQueryVectorPtr.vector_sort();
658
659           float newLocalLogLikelihood =
660             computeLocalLogLikelihood(toId,
661                 adtreePtr,
662                 netPtr,
663                 queries,
664                 queryVectorPtr,
665                 parentQueryVectorPtr);
666
667           if (newLocalLogLikelihood > bestLocalLogLikelihood) {
668             bestLocalLogLikelihood = newLocalLogLikelihood;
669             bestFromId = fromId;
670           }
671
672         } // foreach valid parent 
673
674       } // if have not exceeded max number of edges to learn 
675
676       /*
677        * Return best task; Note: if none is better, fromId will equal toId
678        */
679
680       LearnerTask bestTask = new LearnerTask();
681       bestTask.op     = OPERATION_INSERT;
682       bestTask.fromId = bestFromId;
683       bestTask.toId   = toId;
684       bestTask.score  = 0.0f;
685
686       if (bestFromId != toId) {
687         int numRecord = adtreePtr.numRecord;
688         int numParent = parentIdListPtr.list_getSize() + 1;
689         float penalty =
690           (numTotalParent + numParent * global_insertPenalty) * basePenalty;
691         float logLikelihood = numRecord * (baseLogLikelihood +
692             + bestLocalLogLikelihood
693             - oldLocalLogLikelihood);
694         float bestScore = penalty + logLikelihood;
695         bestTask.score  = bestScore;
696       }
697
698       return bestTask;
699     }
700
701 #ifdef LEARNER_TRY_REMOVE
702   /* =============================================================================
703    * TMfindBestRemoveTask
704    * =============================================================================
705    */
706   public LearnerTask
707     TMfindBestRemoveTask (FindBestTaskArg argPtr)
708     {
709       int       toId                     = argPtr.toId;
710       Learner   learnerPtr               = argPtr.learnerPtr;
711       Query[]   queries                  = argPtr.queries;
712       Vector_t  queryVectorPtr           = argPtr.queryVectorPtr;
713       Vector_t  parentQueryVectorPtr     = argPtr.parentQueryVectorPtr;
714       int       numTotalParent           = argPtr.numTotalParent;
715       float     basePenalty              = argPtr.basePenalty;
716       float     baseLogLikelihood        = argPtr.baseLogLikelihood;
717       Vector_t  origParentQueryVectorPtr = argPtr.aQueryVectorPtr;
718
719       boolean status;
720       Adtree adtreePtr = learnerPtr.adtreePtr;
721       Net netPtr = learnerPtr.netPtr;
722       float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
723
724       TMpopulateParentQueryVector(netPtr, toId, queries, origParentQueryVectorPtr);
725       int numParent = origParentQueryVectorPtr.vector_getSize();
726
727       /*
728        * Search all possible valid operations for better local log likelihood
729        */
730
731       int bestFromId = toId; // flag for not found 
732       float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
733       float bestLocalLogLikelihood = oldLocalLogLikelihood;
734
735       int i;
736       for (i = 0; i < numParent; i++) {
737
738         Query queryPtr = (Query) (origParentQueryVectorPtr.vector_at(i));
739         int fromId = queryPtr.index;
740
741         /*
742          * Create parent query (subset of parents since remove an edge)
743          */
744
745         parentQueryVectorPtr.vector_clear();
746
747         for (int p = 0; p < numParent; p++) {
748           if (p != fromId) {
749             Query tmpqueryPtr = (Query) (origParentQueryVectorPtr.vector_at(p));
750             status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
751           }
752         } // create new parent query 
753
754         /*
755          * Create query
756          */
757
758         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
759         status = queryVectorPtr.vector_pushBack(queries[toId]);
760         queryVectorPtr.vector_sort();
761
762         /*
763          * See if removing parent is better
764          */
765
766         float newLocalLogLikelihood =
767           computeLocalLogLikelihood(toId,
768               adtreePtr,
769               netPtr,
770               queries,
771               queryVectorPtr,
772               parentQueryVectorPtr);
773
774         if (newLocalLogLikelihood > bestLocalLogLikelihood) {
775           bestLocalLogLikelihood = newLocalLogLikelihood;
776           bestFromId = fromId;
777         }
778
779       } // for each parent 
780
781       /*
782        * Return best task; Note: if none is better, fromId will equal toId
783        */
784
785       LearnerTask bestTask = new LearnerTask();
786       bestTask.op     = OPERATION_REMOVE;
787       bestTask.fromId = bestFromId;
788       bestTask.toId   = toId;
789       bestTask.score  = 0.0f;
790
791       if (bestFromId != toId) {
792         int numRecord = adtreePtr.numRecord;
793         float penalty = (numTotalParent - 1) * basePenalty;
794         float logLikelihood = numRecord * (baseLogLikelihood +
795             + bestLocalLogLikelihood
796             - oldLocalLogLikelihood);
797         float bestScore = penalty + logLikelihood;
798         bestTask.score  = bestScore;
799       }
800
801       return bestTask;
802     }
803 #endif /* LEARNER_TRY_REMOVE */
804
805
806 #ifdef LEARNER_TRY_REVERSE
807   /* =============================================================================
808    * TMfindBestReverseTask
809    * =============================================================================
810    */
811   public LearnerTask
812     TMfindBestReverseTask (FindBestTaskArg argPtr)
813     {
814       int       toId                         = argPtr.toId;
815       Learner   learnerPtr                   = argPtr.learnerPtr;
816       Query[]   queries                      = argPtr.queries;
817       Vector_t  queryVectorPtr               = argPtr.queryVectorPtr;
818       Vector_t  parentQueryVectorPtr         = argPtr.parentQueryVectorPtr;
819       int       numTotalParent               = argPtr.numTotalParent;
820       float     basePenalty                  = argPtr.basePenalty;
821       float     baseLogLikelihood            = argPtr.baseLogLikelihood;
822       BitMap    visitedBitmapPtr             = argPtr.bitmapPtr;
823       Queue     workQueuePtr                 = argPtr.workQueuePtr;
824       Vector_t  toOrigParentQueryVectorPtr   = argPtr.aQueryVectorPtr;
825       Vector_t  fromOrigParentQueryVectorPtr = argPtr.bQueryVectorPtr;
826
827       boolean status;
828       Adtree adtreePtr = learnerPtr.adtreePtr;
829       Net netPtr = learnerPtr.netPtr;
830       float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
831
832       TMpopulateParentQueryVector(netPtr, toId, queries, toOrigParentQueryVectorPtr);
833       int numParent = toOrigParentQueryVectorPtr.vector_getSize();
834
835       /*
836        * Search all possible valid operations for better local log likelihood
837        */
838
839       int bestFromId = toId; // flag for not found 
840       float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
841       float bestLocalLogLikelihood = oldLocalLogLikelihood;
842       int fromId = 0;
843
844       for (int i = 0; i < numParent; i++) {
845
846         Query queryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(i));
847         fromId = queryPtr.index;
848
849         bestLocalLogLikelihood =
850           oldLocalLogLikelihood + localBaseLogLikelihoods[fromId];
851
852         TMpopulateParentQueryVector(netPtr,
853             fromId,
854             queries,
855             fromOrigParentQueryVectorPtr);
856
857         /*
858          * Create parent query (subset of parents since remove an edge)
859          */
860
861         parentQueryVectorPtr.vector_clear();
862
863         for (int p = 0; p < numParent; p++) {
864           if (p != fromId) {
865             Query tmpqueryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(p));
866             status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
867           }
868         } // create new parent query 
869
870         /*
871          * Create query
872          */
873
874         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
875         status = queryVectorPtr.vector_pushBack(queries[toId]);
876
877         queryVectorPtr.vector_sort();
878
879         /*
880          * Get log likelihood for removing parent from toId
881          */
882
883         float newLocalLogLikelihood =
884           computeLocalLogLikelihood(toId,
885               adtreePtr,
886               netPtr,
887               queries,
888               queryVectorPtr,
889               parentQueryVectorPtr);
890
891         /*
892          * Get log likelihood for adding parent to fromId
893          */
894
895         status = Vector_t.vector_copy(parentQueryVectorPtr, fromOrigParentQueryVectorPtr);
896         status = parentQueryVectorPtr.vector_pushBack(queries[toId]);
897
898         parentQueryVectorPtr.vector_sort();
899
900         status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
901
902         status = queryVectorPtr.vector_pushBack(queries[fromId]);
903
904         queryVectorPtr.vector_sort();
905
906         newLocalLogLikelihood +=
907           computeLocalLogLikelihood(fromId,
908               adtreePtr,
909               netPtr,
910               queries,
911               queryVectorPtr,
912               parentQueryVectorPtr);
913
914         /*
915          * Record best
916          */
917
918         if (newLocalLogLikelihood > bestLocalLogLikelihood) {
919           bestLocalLogLikelihood = newLocalLogLikelihood;
920           bestFromId = fromId;
921         }
922
923       } // for each parent 
924
925       /*
926        * Check validity of best
927        */
928
929       if (bestFromId != toId) {
930         boolean isTaskValid = true;
931         netPtr.net_applyOperation(OPERATION_REMOVE, bestFromId, toId);
932         if (netPtr.net_isPath(bestFromId,
933               toId,
934               visitedBitmapPtr,
935               workQueuePtr))
936         {
937           isTaskValid = false;
938         }
939         netPtr.net_applyOperation(OPERATION_INSERT, bestFromId, toId);
940         if (!isTaskValid) {
941           bestFromId = toId;
942         }
943       }
944
945       /*
946        * Return best task; Note: if none is better, fromId will equal toId
947        */
948
949       LearnerTask bestTask = new LearnerTask();
950       bestTask.op     = OPERATION_REVERSE;
951       bestTask.fromId = bestFromId;
952       bestTask.toId   = toId;
953       bestTask.score  = 0.0f;
954
955       if (bestFromId != toId) {
956         float fromLocalLogLikelihood = localBaseLogLikelihoods[bestFromId];
957         int numRecord = adtreePtr.numRecord;
958         float penalty = numTotalParent * basePenalty;
959         float logLikelihood = numRecord * (baseLogLikelihood +
960             + bestLocalLogLikelihood
961             - oldLocalLogLikelihood
962             - fromLocalLogLikelihood);
963         float bestScore = penalty + logLikelihood;
964         bestTask.score  = bestScore;
965       }
966
967       return bestTask;
968     }
969
970 #endif /* LEARNER_TRY_REVERSE */
971
972
973   /* =============================================================================
974    * learnStructure
975    *
976    * Note it is okay if the score is not exact, as we are relaxing the greedy
977    * search. This means we do not need to communicate baseLogLikelihood across
978    * threads.
979    * =============================================================================
980    */
981   public static void
982     learnStructure (int myId, int numThread, Learner learnerPtr)
983     {
984
985       int numRecord = learnerPtr.adtreePtr.numRecord;
986
987       float operationQualityFactor = learnerPtr.global_operationQualityFactor;
988
989       BitMap visitedBitmapPtr = BitMap.bitmap_alloc(learnerPtr.adtreePtr.numVar);
990
991       Queue workQueuePtr = Queue.queue_alloc(-1);
992
993       int numVar = learnerPtr.adtreePtr.numVar;
994       Query[] queries = new Query[numVar];
995
996       for (int v = 0; v < numVar; v++) {
997         queries[v] = new Query();
998         queries[v].index = v;
999         queries[v].value = QUERY_VALUE_WILDCARD;
1000       }
1001
1002       float basePenalty = (float)(-0.5 * Math.log((double)numRecord));
1003
1004       Vector_t queryVectorPtr = new Vector_t(1);
1005       Vector_t parentQueryVectorPtr = new Vector_t(1);
1006       Vector_t aQueryVectorPtr = new Vector_t(1);
1007       Vector_t bQueryVectorPtr = new Vector_t(1);
1008
1009       FindBestTaskArg arg = new FindBestTaskArg();
1010       arg.learnerPtr           = learnerPtr;
1011       arg.queries              = queries;
1012       arg.queryVectorPtr       = queryVectorPtr;
1013       arg.parentQueryVectorPtr = parentQueryVectorPtr;
1014       arg.bitmapPtr            = visitedBitmapPtr;
1015       arg.workQueuePtr         = workQueuePtr;
1016       arg.aQueryVectorPtr      = aQueryVectorPtr;
1017       arg.bQueryVectorPtr      = bQueryVectorPtr;
1018
1019       while (true) {
1020
1021         LearnerTask taskPtr;
1022
1023         atomic {
1024           taskPtr = learnerPtr.TMpopTask(learnerPtr.taskListPtr);
1025         }
1026
1027         if (taskPtr == null) {
1028           break;
1029         }
1030
1031         int op = taskPtr.op;
1032         int fromId = taskPtr.fromId;
1033         int toId = taskPtr.toId;
1034
1035         boolean isTaskValid;
1036
1037         atomic {
1038           /*
1039            * Check if task is still valid
1040            */
1041
1042           isTaskValid = true;
1043           if(op == OPERATION_INSERT) {
1044             if(learnerPtr.netPtr.net_hasEdge(fromId, toId) || 
1045                 learnerPtr.netPtr.net_isPath(toId,
1046                   fromId,
1047                   visitedBitmapPtr,
1048                   workQueuePtr))
1049             {
1050               isTaskValid = false;
1051             }
1052           } else if (op == OPERATION_REMOVE) {
1053             // Can never create cycle, so always valid
1054             ;
1055           } else if (op == OPERATION_REVERSE) {
1056             // Temporarily remove edge for check
1057             learnerPtr.netPtr.net_applyOperation(OPERATION_REMOVE, fromId, toId);
1058             if(learnerPtr.netPtr.net_isPath(fromId,
1059                   toId,
1060                   visitedBitmapPtr,
1061                   workQueuePtr))
1062             {
1063               isTaskValid = false;
1064             }
1065             learnerPtr.netPtr.net_applyOperation(OPERATION_INSERT, fromId, toId);
1066           }
1067
1068
1069 #ifdef TEST_LEARNER
1070           System.out.println("[task] op= " + taskPtr.op + " from= " + taskPtr.fromId + " to= " + 
1071               taskPtr.toId + " score= " + taskPtr.score + " valid= " + (isTaskValid ? "yes" : "no"));
1072 #endif
1073
1074           /*
1075            * Perform task: update graph and probabilities
1076            */
1077
1078           if (isTaskValid) {
1079             learnerPtr.netPtr.net_applyOperation(op, fromId, toId);
1080           }
1081
1082         }
1083
1084         float deltaLogLikelihood = 0.0f;
1085
1086         if (isTaskValid) {
1087           float newBaseLogLikelihood;
1088           if(op == OPERATION_INSERT) {
1089             atomic {
1090               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1091                                                 toId,
1092                                                 queries,
1093                                                 queryVectorPtr,
1094                                                 parentQueryVectorPtr);
1095               newBaseLogLikelihood =
1096                 learnerPtr.computeLocalLogLikelihood(toId,
1097                                                 learnerPtr.adtreePtr,
1098                                                 learnerPtr.netPtr,
1099                                                 queries,
1100                                                 queryVectorPtr,
1101                                                 parentQueryVectorPtr);
1102               float toLocalBaseLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
1103               deltaLogLikelihood +=
1104                 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1105               learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1106             }
1107
1108             atomic {
1109               int numTotalParent = learnerPtr.numTotalParent;
1110               learnerPtr.numTotalParent = numTotalParent + 1;
1111             }
1112
1113 #ifdef LEARNER_TRY_REMOVE
1114           } else if(op == OPERATION_REMOVE) {
1115             atomic {
1116               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1117                                                 fromId,
1118                                                 queries,
1119                                                 queryVectorPtr,
1120                                                 parentQueryVectorPtr);
1121               newBaseLogLikelihood =
1122                 learnerPtr. computeLocalLogLikelihood(fromId,
1123                                               learnerPtr.adtreePtr,
1124                                               learnerPtr.netPtr,
1125                                               queries,
1126                                               queryVectorPtr, 
1127                                               parentQueryVectorPtr);
1128               float fromLocalBaseLogLikelihood =
1129                     learnerPtr.localBaseLogLikelihoods[fromId];
1130               deltaLogLikelihood +=
1131                     fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1132               learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1133             }
1134
1135             atomic{ 
1136               int numTotalParent = learnerPtr.numTotalParent;
1137               learnerPtr.numTotalParent = numTotalParent - 1;
1138             }
1139
1140 #endif // LEARNER_TRY_REMOVE
1141 #ifdef LEARNER_TRY_REVERSE
1142           } else if(op == OPERATION_REVERSE) {
1143             atomic {
1144               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1145                                                 fromId,
1146                                                 queries,
1147                                                 queryVectorPtr,
1148                                                 parentQueryVectorPtr);
1149               newBaseLogLikelihood =
1150                 learnerPtr.computeLocalLogLikelihood(fromId,
1151                                         learnerPtr.adtreePtr,
1152                                         learnerPtr.netPtr,
1153                                         queries,
1154                                         queryVectorPtr,
1155                                         parentQueryVectorPtr);
1156               float fromLocalBaseLogLikelihood =
1157                           learnerPtr.localBaseLogLikelihoods[fromId];
1158               deltaLogLikelihood +=
1159                 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1160               learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1161             }
1162
1163             atomic {
1164               learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1165                                                 toId,
1166                                                 queries,
1167                                                 queryVectorPtr,
1168                                                 parentQueryVectorPtr);
1169               newBaseLogLikelihood =
1170                 learnerPtr.computeLocalLogLikelihood(toId,
1171                                         learnerPtr.adtreePtr,
1172                                         learnerPtr.netPtr,
1173                                         queries,
1174                                         queryVectorPtr,
1175                                         parentQueryVectorPtr);
1176               float toLocalBaseLogLikelihood =
1177                         learnerPtr.localBaseLogLikelihoods[toId];
1178               deltaLogLikelihood +=
1179                 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1180               learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1181             }
1182
1183 #endif // LEARNER_TRY_REVERSE 
1184           }
1185
1186         } //if isTaskValid
1187
1188         /*
1189          * Update/read globals
1190          */
1191
1192         float baseLogLikelihood;
1193         int numTotalParent;
1194
1195         atomic {
1196           float oldBaseLogLikelihood = learnerPtr.baseLogLikelihood;
1197           float newBaseLogLikelihood = oldBaseLogLikelihood + deltaLogLikelihood;
1198           learnerPtr.baseLogLikelihood = newBaseLogLikelihood;
1199           baseLogLikelihood = newBaseLogLikelihood;
1200           numTotalParent = learnerPtr.numTotalParent;
1201         }
1202
1203         /*
1204          * Find next task
1205          */
1206
1207
1208         float baseScore = ((float)numTotalParent * basePenalty)
1209           + (numRecord * baseLogLikelihood);
1210
1211         LearnerTask bestTask = new LearnerTask();
1212         bestTask.op     = NUM_OPERATION;
1213         bestTask.toId   = -1;
1214         bestTask.fromId = -1;
1215         bestTask.score  = baseScore;
1216
1217         LearnerTask newTask = new LearnerTask();
1218
1219         arg.toId              = toId;
1220         arg.numTotalParent    = numTotalParent;
1221         arg.basePenalty       = basePenalty;
1222         arg.baseLogLikelihood = baseLogLikelihood;
1223
1224         atomic {
1225           newTask = learnerPtr.TMfindBestInsertTask(arg);
1226         }
1227
1228         if ((newTask.fromId != newTask.toId) &&
1229             (newTask.score > (bestTask.score / operationQualityFactor)))
1230         {
1231           bestTask = newTask;
1232         }
1233
1234 #ifdef LEARNER_TRY_REMOVE
1235         atomic {
1236           newTask = learnerPtr.TMfindBestRemoveTask(arg);
1237         }
1238
1239         if ((newTask.fromId != newTask.toId) &&
1240             (newTask.score > (bestTask.score / operationQualityFactor)))
1241         {
1242           bestTask = newTask;
1243         }
1244 #endif // LEARNER_TRY_REMOVE 
1245
1246 #ifdef LEARNER_TRY_REVERSE
1247         atomic {
1248           newTask = learnerPtr.TMfindBestReverseTask(arg);
1249         }
1250
1251         if ((newTask.fromId != newTask.toId) &&
1252             (newTask.score > (bestTask.score / operationQualityFactor)))
1253         {
1254           bestTask = newTask;
1255         }
1256 #endif // LEARNER_TRY_REVERSE 
1257
1258         if (bestTask.toId != -1) {
1259           LearnerTask[] tasks = learnerPtr.tasks;
1260           tasks[toId] = bestTask;
1261           atomic {
1262             learnerPtr.taskListPtr.list_insert(tasks[toId]);
1263           }
1264 #ifdef TEST_LEARNER
1265           System.out.println("[new]  op= " + bestTask.op + " from= "+ bestTask.fromId + " to= "+ bestTask.toId + 
1266               " score= " + bestTask.score);
1267 #endif
1268         }
1269
1270       } // while (tasks) 
1271
1272       visitedBitmapPtr.bitmap_free();
1273       workQueuePtr.queue_free();
1274       bQueryVectorPtr.clear();
1275       aQueryVectorPtr.clear();
1276       queryVectorPtr.clear();
1277       parentQueryVectorPtr.clear();
1278       queries = null;
1279     }
1280
1281
1282   /* =============================================================================
1283    * learner_run
1284    * -- Call adtree_make before this
1285    * =============================================================================
1286    */
1287   //Is not called anywhere now parallel code
1288   public void
1289     learner_run (int myId, int numThread, Learner learnerPtr)
1290     {
1291       {
1292         createTaskList(myId, numThread, learnerPtr);
1293       }
1294       {
1295         learnStructure(myId, numThread, learnerPtr);
1296       }
1297     }
1298
1299   /* =============================================================================
1300    * learner_score
1301    * -- Score entire network
1302    * =============================================================================
1303    */
1304   public float
1305     learner_score ()
1306     {
1307
1308       Vector_t queryVectorPtr = new Vector_t(1);
1309       Vector_t parentQueryVectorPtr = new Vector_t(1);
1310
1311       int numVar = adtreePtr.numVar;
1312       Query[] queries = new Query[numVar];
1313
1314       for (int v = 0; v < numVar; v++) {
1315         queries[v] = new Query();
1316         queries[v].index = v;
1317         queries[v].value = QUERY_VALUE_WILDCARD;
1318       }
1319
1320       int numTotalParent = 0;
1321       float logLikelihood = 0.0f;
1322
1323       for (int v = 0; v < numVar; v++) {
1324
1325         IntList parentIdListPtr = netPtr.net_getParentIdListPtr(v);
1326         numTotalParent += parentIdListPtr.list_getSize();
1327
1328         populateQueryVectors(netPtr,
1329             v,
1330             queries,
1331             queryVectorPtr,
1332             parentQueryVectorPtr);
1333         float localLogLikelihood = computeLocalLogLikelihood(v,
1334             adtreePtr,
1335             netPtr,
1336             queries,
1337             queryVectorPtr,
1338             parentQueryVectorPtr);
1339         logLikelihood += localLogLikelihood;
1340       }
1341
1342       queryVectorPtr.clear();
1343       parentQueryVectorPtr.clear();
1344       queries = null;
1345
1346
1347       int numRecord = adtreePtr.numRecord;
1348       float penalty = (float)(-0.5f * (double)numTotalParent * Math.log((double)numRecord));
1349       float score = penalty + (float)numRecord * logLikelihood;
1350
1351       return score;
1352     }
1353 }
1354
1355 /* =============================================================================
1356  *
1357  * End of learner.java
1358  *
1359  * =============================================================================
1360  */