1 /* =============================================================================
4 * -- Learns structure of Bayesian net from data
6 * =============================================================================
8 * Copyright (C) Stanford University, 2006. All Rights Reserved.
10 * Ported to Java June 2009 Alokika Dash
11 * University of California, Irvine
14 * =============================================================================
16 * The penalized log-likelihood score (Friedman & Yakhini, 1996) is used to
17 * evaluate the "goodness" of a Bayesian net:
21 * score = -N_params \ln(R) / 2 + R \sum_{j=1}^{M} \sum_{X_j} \sum_{v=1}^{n_j} P((a_j = v), X_j) \ln P(a_j = v | X_j)
 *   where R is the number of records
27 * N_params total number of parents across all variables
29 * M number of variables
30 * X_j parents of the jth variable
31 * n_j number of attributes of the jth variable
34 * The second summation, over X_j, ranges across all possible assignments of
35 * values to the parents X_j.
39 * "local log likelihood" is P((a_j = v), X_j) ln P(a_j = v | X_j)
40 * "log likelihood" is everything to the right of the '+', i.e., "R ... X_j)"
41 * "base penalty" is -ln(R) / 2
42 * "penalty" is N_params * -ln(R) / 2
43 * "score" is the entire expression
45 * For more notes, refer to:
47 * A. Moore and M.-S. Lee. Cached sufficient statistics for efficient machine
48 * learning with large datasets. Journal of Artificial Intelligence Research 8
51 * =============================================================================
53 * The search strategy uses a combination of local and global structure search.
54 * Similar to the technique described in:
56 * D. M. Chickering, D. Heckerman, and C. Meek. A Bayesian approach to learning
57 * Bayesian networks with local structure. In Proceedings of Thirteenth
58 * Conference on Uncertainty in Artificial Intelligence (1997), pp. 80-89.
60 * =============================================================================
62 * For the license of bayes/sort.h and bayes/sort.c, please see the header
65 * ------------------------------------------------------------------------
67 * Unless otherwise noted, the following license applies to STAMP files:
69 * Copyright (c) 2007, Stanford University
70 * All rights reserved.
72 * Redistribution and use in source and binary forms, with or without
73 * modification, are permitted provided that the following conditions are
76 * * Redistributions of source code must retain the above copyright
77 * notice, this list of conditions and the following disclaimer.
79 * * Redistributions in binary form must reproduce the above copyright
80 * notice, this list of conditions and the following disclaimer in
81 * the documentation and/or other materials provided with the
84 * * Neither the name of Stanford University nor the names of its
85 * contributors may be used to endorse or promote products derived
86 * from this software without specific prior written permission.
88 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
89 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
90 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
91 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
92 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
93 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
94 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
95 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
96 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
97 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
98 * THE POSSIBILITY OF SUCH DAMAGE.
100 * =============================================================================
// Preprocessor constants. This Java source is fed through a C-style
// preprocessor (see the #define / #ifdef directives), so these are textual
// macros rather than Java fields.
// NOTE(review): CACHE_LINE_SIZE is not referenced in the part of the file visible here.
104 #define CACHE_LINE_SIZE 64
// Wildcard query value (presumably "match any value"). Queries are initialized
// to it and reset to it after each likelihood evaluation.
105 #define QUERY_VALUE_WILDCARD -1
// Operation codes stored in LearnerTask.op and passed to Net.net_applyOperation.
106 #define OPERATION_INSERT 0
107 #define OPERATION_REMOVE 1
108 #define OPERATION_REVERSE 2
// Number of distinct operation codes defined above.
109 #define NUM_OPERATION 3
/*
 * Learner -- state for learning the structure of a Bayesian network: the
 * per-variable local log likelihoods, their running total, and the tunable
 * search parameters declared below.
 */
111 public class Learner {
// Local log likelihood of each variable given its current parents, indexed by
// variable id (allocated with dataPtr.numVar entries in learner_alloc).
114 float[] localBaseLogLikelihoods;
// Total of the local base log likelihoods, accumulated by createTaskList.
115 float baseLogLikelihood;
// Multiplier applied to the new parent count when scoring an insert operation.
119 int global_insertPenalty;
// Parent-count limit checked by TMfindBestInsertTask; a negative value disables it.
120 int global_maxNumEdgeLearned;
// Quality factor read by learnStructure.
121 float global_operationQualityFactor;
// Default search parameters: no edge limit, insert penalty 1, quality factor 1.0.
125 global_maxNumEdgeLearned = -1;
126 global_insertPenalty = 1;
127 global_operationQualityFactor = 1.0F;
131 /* =============================================================================
 * learner_alloc
 * -- Creates a Learner over dataPtr.numVar variables. Allocates the net, the
 *    per-variable likelihood and task arrays, and an empty task list, then
 *    copies in the search parameters. The adtree is stored by reference.
133 * =============================================================================
135 public static Learner
136 learner_alloc (Data dataPtr,
139 int global_insertPenalty,
140 int global_maxNumEdgeLearned,
141 float global_operationQualityFactor)
143 Learner learnerPtr = new Learner();
// In Java `new` never yields null, so this guard (kept from the C original,
// which checked its allocation) always passes.
145 if (learnerPtr != null) {
146 learnerPtr.adtreePtr = adtreePtr;
147 learnerPtr.netPtr = Net.net_alloc(dataPtr.numVar);
148 learnerPtr.localBaseLogLikelihoods = new float[dataPtr.numVar];
149 learnerPtr.baseLogLikelihood = 0.0f;
// One task slot per variable; createTaskList fills tasks[v].
150 learnerPtr.tasks = new LearnerTask[dataPtr.numVar];
151 learnerPtr.taskListPtr = List.list_alloc();
152 learnerPtr.numTotalParent = 0;
// These overwrite the defaults set by the no-argument constructor.
154 learnerPtr.global_insertPenalty = global_insertPenalty;
155 learnerPtr.global_maxNumEdgeLearned = global_maxNumEdgeLearned;
156 learnerPtr.global_operationQualityFactor = global_operationQualityFactor;
164 /* =============================================================================
 * -- Releases this learner's task list and drops its reference to the
 *    per-variable likelihood array.
166 * =============================================================================
171 taskListPtr.list_free();
173 localBaseLogLikelihoods = null;
178 /* =============================================================================
179 * computeSpecificLocalLogLikelihood
180 * -- Query vectors should not contain wildcards
 * -- Computes one term of the local log likelihood,
 *        P(query) * ln(count(query) / count(parentQuery)),
 *    with P(query) = count(query) / numRecord and both counts taken from the
 *    adtree.
181 * =============================================================================
184 computeSpecificLocalLogLikelihood (Adtree adtreePtr,
185 Vector_t queryVectorPtr,
186 Vector_t parentQueryVectorPtr)
188 int count = adtreePtr.adtree_getCount(queryVectorPtr);
193 double probability = (double)count / (double)adtreePtr.numRecord;
194 int parentCount = adtreePtr.adtree_getCount(parentQueryVectorPtr);
// Callers build the parent query as a subset of the full query, so its count
// should never be smaller. This guard catches that inconsistency and
// non-positive parent counts (its handling is not shown in this listing).
196 if(parentCount < count || parentCount <= 0) {
// count / parentCount estimates the conditional P(a_j = v | X_j).
200 float fval = (float)(probability * (Math.log((double)count/ (double)parentCount)));
206 /* =============================================================================
 * createPartition
 * -- Splits [min, max) into n chunks of about (max - min) / n elements
 *    (rounded to nearest, at least 1) and computes the bounds of chunk `id`.
 *    Callers read the result from lss.i_start / lss.i_stop.
208 * =============================================================================
211 createPartition (int min, int max, int id, int n, LocalStartStop lss)
213 int range = max - min;
// NOTE(review): Math here is presumably a project helper class;
// java.lang.Math has no imax/imin.
214 int chunk = Math.imax(1, ((range + n/2) / n)); // rounded
215 int start = min + chunk * id;
// Clamp so that no chunk extends past max.
220 stop = Math.imin(max, (start + chunk));
227 /* =============================================================================
 * createTaskList
 * -- For this thread's share of the variables (chosen by createPartition):
 *    1. computes each variable's local log likelihood with no parents and adds
 *       the partial sum into learnerPtr.baseLogLikelihood;
 *    2. for each variable, tries every other variable as a single parent and
 *       queues an OPERATION_INSERT task for the best one if it beats the base.
 *    Variables are treated as binary: only values 0 and 1 are evaluated.
229 * -- baseLogLikelihoods and taskListPtr are updated
230 * =============================================================================
233 createTaskList (int myId, int numThread, Learner learnerPtr)
// Two reusable query slots: the variable itself and (in phase 2) one candidate parent.
237 Query[] queries = new Query[2];
238 queries[0] = new Query();
239 queries[1] = new Query();
241 Vector_t queryVectorPtr = Vector_t.vector_alloc(2);
242 if(queryVectorPtr == null) {
243 System.out.println("Assert failed: cannot allocate vector");
247 if((status = queryVectorPtr.vector_pushBack(queries[0])) == false) {
248 System.out.println("Assert failed: status = "+ status + "vector_pushBack failed in createTaskList()");
// Single-parent query. It is pushed into parentQueryVectorPtr only after
// phase 1, so phase 1 evaluates against an empty parent query.
252 Query parentQuery = new Query();
253 Vector_t parentQueryVectorPtr = Vector_t.vector_alloc(1);
255 if(parentQueryVectorPtr == null) {
256 System.out.println("Assert failed: for vector_alloc at createTaskList()");
260 int numVar = learnerPtr.adtreePtr.numVar;
261 int numRecord = learnerPtr.adtreePtr.numRecord;
// Partial sum over this thread's variables only.
262 float baseLogLikelihood = 0.0f;
263 float penalty = (float)(-0.5f * Math.log((double)numRecord)); // only add 1 edge
265 LocalStartStop lss = new LocalStartStop();
266 learnerPtr.createPartition(0, numVar, myId, numThread, lss);
269 * Compute base log likelihood for each variable and total base loglikelihood
272 for (int v = lss.i_start; v < lss.i_stop; v++) {
274 float localBaseLogLikelihood = 0.0f;
275 queries[0].index = v;
277 queries[0].value = 0;
278 localBaseLogLikelihood +=
279 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
281 parentQueryVectorPtr);
283 queries[0].value = 1;
284 localBaseLogLikelihood +=
285 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
287 parentQueryVectorPtr);
289 learnerPtr.localBaseLogLikelihoods[v] = localBaseLogLikelihood;
290 baseLogLikelihood += localBaseLogLikelihood;
292 } // for each variable
// Add this thread's partial sum into the learner-wide total.
// NOTE(review): this read-modify-write has no visible synchronization; confirm
// it is serialized when several threads run createTaskList.
295 float globalBaseLogLikelihood =
296 learnerPtr.baseLogLikelihood;
297 learnerPtr.baseLogLikelihood = (baseLogLikelihood + globalBaseLogLikelihood);
301 * For each variable, find if the addition of any edge _to_ it is better
// From here on the parent query vector holds exactly parentQuery.
// NOTE(review): this message (and the one below) names createPartition(),
// but the code is in createTaskList().
304 if((status = parentQueryVectorPtr.vector_pushBack(parentQuery)) == false) {
305 System.out.println("Assert failed: status = "+ status + " vector_pushBack failed in createPartition()");
309 for (int v = lss.i_start; v < lss.i_stop; v++) {
311 //Compute base log likelihood for this variable
313 queries[0].index = v;
// v itself means "no better parent found yet".
314 int bestLocalIndex = v;
315 float bestLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[v];
// Add the second query slot for the candidate parent; popped after the vv loop.
317 if((status = queryVectorPtr.vector_pushBack(queries[1])) == false) {
318 System.out.println("Assert failed: status = "+ status + " vector_pushBack failed in createPartition()");
322 for (int vv = 0; vv < numVar; vv++) {
327 parentQuery.index = vv;
// Keep the two query slots ordered by variable index. The (vv < v) tests
// below rely on queries[0] holding vv exactly when vv < v (the branch
// condition itself is not visible in this listing).
329 queries[0].index = v;
330 queries[1].index = vv;
332 queries[0].index = vv;
333 queries[1].index = v;
// Sum over the four 0/1 assignments of (v, vv); parentQuery.value always
// mirrors the value given to vv's slot.
336 float newLocalLogLikelihood = 0.0f;
338 queries[0].value = 0;
339 queries[1].value = 0;
340 parentQuery.value = 0;
341 newLocalLogLikelihood +=
342 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
344 parentQueryVectorPtr);
346 queries[0].value = 0;
347 queries[1].value = 1;
348 parentQuery.value = ((vv < v) ? 0 : 1);
349 newLocalLogLikelihood +=
350 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
352 parentQueryVectorPtr);
354 queries[0].value = 1;
355 queries[1].value = 0;
356 parentQuery.value = ((vv < v) ? 1 : 0);
357 newLocalLogLikelihood +=
358 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
360 parentQueryVectorPtr);
362 queries[0].value = 1;
363 queries[1].value = 1;
364 parentQuery.value = 1;
365 newLocalLogLikelihood +=
366 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
368 parentQueryVectorPtr);
// NOTE(review): the task below is queued only if bestLocalIndex changes. It is
// visibly initialized to v but never updated in the lines shown; confirm it
// is set to vv here.
370 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
372 bestLocalLogLikelihood = newLocalLogLikelihood;
375 } // foreach other variable
377 queryVectorPtr.vector_popBack();
379 if (bestLocalIndex != v) {
// Score = single-edge penalty + numRecord * (base log likelihood with v's
// term replaced by the improved one). The "+ +" is a harmless unary plus.
// Note that baseLogLikelihood is this thread's partial sum, not the global total.
380 float logLikelihood = numRecord * (baseLogLikelihood +
381 + bestLocalLogLikelihood
382 - learnerPtr.localBaseLogLikelihoods[v]);
383 float score = penalty + logLikelihood;
385 learnerPtr.tasks[v] = new LearnerTask();
386 LearnerTask taskPtr = learnerPtr.tasks[v];
387 taskPtr.op = OPERATION_INSERT;
388 taskPtr.fromId = bestLocalIndex;
390 taskPtr.score = score;
392 status = learnerPtr.taskListPtr.list_insert(taskPtr);
395 if(status == false) {
396 System.out.println("Assert failed: atomic list insert failed at createTaskList()");
401 } // for each variable
404 queryVectorPtr.vector_free();
405 parentQueryVectorPtr.vector_free();
// Debug dump of the queued tasks, compiled under TEST_LEARNER (the matching
// #ifdef is not visible in this listing).
408 ListNode it = learnerPtr.taskListPtr.head;
410 while (it.nextPtr!=null) {
412 LearnerTask taskPtr = it.dataPtr;
413 System.out.println("[task] op= "+ taskPtr.op +" from= "+taskPtr.fromId+" to= " +taskPtr.toId+
414 " score= " + taskPtr.score);
416 #endif // TEST_LEARNER
420 /* =============================================================================
 * TMpopTask
422 * -- Returns null if list is empty
 * -- Otherwise removes one task from taskListPtr and returns it.
 *    Emptiness is tested via head.nextPtr, which suggests head is a sentinel
 *    node (confirm against List).
423 * =============================================================================
425 public LearnerTask TMpopTask (List taskListPtr)
427 LearnerTask taskPtr = null;
429 ListNode it = taskListPtr.head;
430 if (it.nextPtr!=null) {
432 taskPtr = it.dataPtr;
433 boolean status = taskListPtr.list_remove(taskPtr);
434 if(status == false) {
435 System.out.println("Assert failed: when removing from a list in TMpopTask()");
444 /* =============================================================================
445 * populateParentQuery
446 * -- Modifies contents of parentQueryVectorPtr
 * -- Clears parentQueryVectorPtr, then appends queries[p] for every parent p
 *    of variable `id` in netPtr. The Query objects themselves are appended, so
 *    later changes to queries[p].value are visible through the vector.
447 * =============================================================================
450 populateParentQueryVector (Net netPtr,
453 Vector_t parentQueryVectorPtr)
455 parentQueryVectorPtr.vector_clear();
457 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
458 IntListNode it = parentIdListPtr.head;
459 while (it.nextPtr!=null) {
461 int parentId = it.dataPtr;
462 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
463 if(status == false) {
464 System.out.println("Assert failed: unable to pushBack in queue");
471 /* =============================================================================
472 * TMpopulateParentQuery
473 * -- Modifies contents of parentQueryVectorPtr
 * -- Same steps as populateParentQueryVector: clears the vector, then appends
 *    queries[p] for each parent p of `id`. Only the failure message differs.
474 * =============================================================================
477 TMpopulateParentQueryVector (Net netPtr,
480 Vector_t parentQueryVectorPtr)
482 parentQueryVectorPtr.vector_clear();
484 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
485 IntListNode it = parentIdListPtr.head;
487 while (it.nextPtr!=null) {
489 int parentId = it.dataPtr;
490 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
491 if(status == false) {
492 System.out.println("Assert failed: unable to pushBack in queue in TMpopulateParentQueryVector()");
499 /* =============================================================================
500 * populateQueryVectors
501 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
 * -- parentQueryVectorPtr gets the parents of `id`. queryVectorPtr becomes a
 *    copy of it plus queries[id], then sorted (presumably by Query.index;
 *    confirm in Vector_t).
502 * =============================================================================
505 populateQueryVectors (Net netPtr,
508 Vector_t queryVectorPtr,
509 Vector_t parentQueryVectorPtr)
511 populateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
514 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
515 if(status == false ) {
516 System.out.println("Assert failed: while vector copy in populateQueryVectors()");
520 status = queryVectorPtr.vector_pushBack(queries[id]);
521 if(status == false ) {
522 System.out.println("Assert failed: while vector pushBack in populateQueryVectors()");
526 queryVectorPtr.vector_sort();
530 /* =============================================================================
531 * TMpopulateQueryVectors
532 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
 * -- Same steps as populateQueryVectors, but fills the parent vector with
 *    TMpopulateParentQueryVector: parents of `id`, then the query vector
 *    becomes the parents plus queries[id], sorted.
533 * =============================================================================
536 TMpopulateQueryVectors (Net netPtr,
539 Vector_t queryVectorPtr,
540 Vector_t parentQueryVectorPtr)
542 TMpopulateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
545 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
546 if(status == false ) {
547 System.out.println("Assert failed: while vector copy in TMpopulateQueryVectors()");
550 status = queryVectorPtr.vector_pushBack(queries[id]);
551 if(status == false ) {
552 System.out.println("Assert failed: while vector pushBack in TMpopulateQueryVectors()");
556 queryVectorPtr.vector_sort();
559 /* =============================================================================
560 * computeLocalLogLikelihoodHelper
561 * -- Recursive helper routine
 * -- Enumerates the values 0 and 1 for parents i .. numParent-1 by setting
 *    queries[parentIndex].value, and sums computeSpecificLocalLogLikelihood
 *    over every complete assignment. Each parent's value is reset to
 *    QUERY_VALUE_WILDCARD before returning.
562 * =============================================================================
565 computeLocalLogLikelihoodHelper (int i,
569 Vector_t queryVectorPtr,
570 Vector_t parentQueryVectorPtr)
// Base case: every parent has a concrete value.
572 if (i >= numParent) {
573 return computeSpecificLocalLogLikelihood(adtreePtr,
575 parentQueryVectorPtr);
578 float localLogLikelihood = 0.0f;
580 Query parentQueryPtr = (Query) (parentQueryVectorPtr.vector_at(i));
581 int parentIndex = parentQueryPtr.index;
// The query vectors hold these same Query objects (see
// populateParentQueryVector), so setting the value here changes both queries.
583 queries[parentIndex].value = 0;
584 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
589 parentQueryVectorPtr);
591 queries[parentIndex].value = 1;
592 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
597 parentQueryVectorPtr);
599 queries[parentIndex].value = QUERY_VALUE_WILDCARD;
601 return localLogLikelihood;
605 /* =============================================================================
606 * computeLocalLogLikelihood
607 * -- Populate the query vectors before passing as args
 * -- Returns the sum of the local log likelihood terms of variable `id` over
 *    its values {0, 1} and every 0/1 assignment of its parents. Afterwards
 *    queries[id].value is QUERY_VALUE_WILDCARD.
608 * =============================================================================
611 computeLocalLogLikelihood (int id,
615 Vector_t queryVectorPtr,
616 Vector_t parentQueryVectorPtr)
618 int numParent = parentQueryVectorPtr.vector_getSize();
619 float localLogLikelihood = 0.0f;
621 queries[id].value = 0;
622 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
627 parentQueryVectorPtr);
629 queries[id].value = 1;
630 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
635 parentQueryVectorPtr);
637 queries[id].value = QUERY_VALUE_WILDCARD;
639 return localLogLikelihood;
643 /* =============================================================================
644 * TMfindBestInsertTask
 * -- Finds the single new parent `fromId` of argPtr.toId that most improves
 *    toId's local log likelihood and returns it as an OPERATION_INSERT task.
 *    If nothing improves, the task has fromId == toId and score 0.
645 * =============================================================================
648 TMfindBestInsertTask (FindBestTaskArg argPtr)
650 int toId = argPtr.toId;
651 Learner learnerPtr = argPtr.learnerPtr;
652 Query[] queries = argPtr.queries;
653 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
654 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
655 int numTotalParent = argPtr.numTotalParent;
656 float basePenalty = argPtr.basePenalty;
657 float baseLogLikelihood = argPtr.baseLogLikelihood;
// Scratch bitmap: bits set here mark variables that may not become a parent.
658 BitMap invalidBitmapPtr = argPtr.bitmapPtr;
659 Queue workQueuePtr = argPtr.workQueuePtr;
660 Vector_t baseParentQueryVectorPtr = argPtr.aQueryVectorPtr;
661 Vector_t baseQueryVectorPtr = argPtr.bQueryVectorPtr;
664 Adtree adtreePtr = learnerPtr.adtreePtr;
665 Net netPtr = learnerPtr.netPtr;
667 TMpopulateParentQueryVector(netPtr, toId, queries, parentQueryVectorPtr);
670 * Create base query and parentQuery
673 status = Vector_t.vector_copy(baseParentQueryVectorPtr, parentQueryVectorPtr);
674 if(status == false) {
675 System.out.println("Assert failed: copying baseParentQuery vector in TMfindBestInsertTask");
679 status = Vector_t.vector_copy(baseQueryVectorPtr, baseParentQueryVectorPtr);
680 if(status == false) {
681 System.out.println("Assert failed: copying baseQuery vector in TMfindBestInsertTask");
685 status = baseQueryVectorPtr.vector_pushBack(queries[toId]);
686 if(status == false ) {
687 System.out.println("Assert failed: while vector pushBack in TMfindBestInsertTask()");
// NOTE(review): this sorts queryVectorPtr, not the baseQueryVectorPtr just
// built. It has no effect on results because queryVectorPtr is overwritten
// and re-sorted for every candidate below. baseQueryVectorPtr stays unsorted
// but is only used as a copy source.
691 queryVectorPtr.vector_sort();
694 * Search all possible valid operations for better local log likelihood
697 int bestFromId = toId; // flag for not found
698 float oldLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
699 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Mark toId's descendants as invalid (presumably because an edge from a
// descendant to toId would close a cycle).
701 status = netPtr.net_findDescendants(toId, invalidBitmapPtr, workQueuePtr);
702 if(status == false) {
703 System.out.println("Assert failed: while net_findDescendants in TMfindBestInsertTask()");
709 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(toId);
711 int maxNumEdgeLearned = global_maxNumEdgeLearned;
// Search for a new parent only while toId's parent count is within the limit;
// a negative limit disables the check.
713 if ((maxNumEdgeLearned < 0) ||
714 (parentIdListPtr.list_getSize() <= maxNumEdgeLearned))
717 IntListNode it = parentIdListPtr.head;
719 while(it.nextPtr!=null) {
721 int parentId = it.dataPtr;
722 invalidBitmapPtr.bitmap_set(parentId); // invalid since already have edge
// Visit every variable whose bit is still clear: neither a descendant nor an
// existing parent.
725 while ((fromId = invalidBitmapPtr.bitmap_findClear((fromId + 1))) >= 0) {
// toId itself is special-cased here; its branch body is not visible in this listing.
727 if (fromId == toId) {
// Candidate query = toId's current parents + toId + fromId.
731 status = Vector_t.vector_copy(queryVectorPtr, baseQueryVectorPtr);
732 if(status == false) {
733 System.out.println("Assert failed: copying query vector in TMfindBestInsertTask");
737 status = queryVectorPtr.vector_pushBack(queries[fromId]);
738 if(status == false) {
739 System.out.println("Assert failed: vector pushback for query in TMfindBestInsertTask");
743 queryVectorPtr.vector_sort();
// Candidate parent query = toId's current parents + fromId.
745 status = Vector_t.vector_copy(parentQueryVectorPtr, baseParentQueryVectorPtr);
746 if(status == false) {
747 System.out.println("Assert failed: copying parentQuery vector in TMfindBestInsertTask");
751 status = parentQueryVectorPtr.vector_pushBack(queries[fromId]);
752 if(status == false) {
753 System.out.println("Assert failed: vector pushBack for parentQuery in TMfindBestInsertTask");
757 parentQueryVectorPtr.vector_sort();
759 float newLocalLogLikelihood =
760 computeLocalLogLikelihood(toId,
765 parentQueryVectorPtr);
// NOTE(review): bestFromId must also be set to fromId here for a task to be
// returned. That assignment is not among the lines shown; confirm it exists.
767 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
768 bestLocalLogLikelihood = newLocalLogLikelihood;
772 } // foreach valid parent
774 } // if have not exceeded max number of edges to learn
777 * Return best task; Note: if none is better, fromId will equal toId
780 LearnerTask bestTask = new LearnerTask();
781 bestTask.op = OPERATION_INSERT;
782 bestTask.fromId = bestFromId;
783 bestTask.toId = toId;
784 bestTask.score = 0.0f;
786 if (bestFromId != toId) {
787 int numRecord = adtreePtr.numRecord;
// toId's parent count after the insert.
788 int numParent = parentIdListPtr.list_getSize() + 1;
// Penalty grows with the total parent count plus the weighted new parent count.
790 (numTotalParent + numParent * global_insertPenalty) * basePenalty;
// The "+ +" below is a harmless unary plus.
791 float logLikelihood = numRecord * (baseLogLikelihood +
792 + bestLocalLogLikelihood
793 - oldLocalLogLikelihood);
794 float bestScore = penalty + logLikelihood;
795 bestTask.score = bestScore;
801 #ifdef LEARNER_TRY_REMOVE
802 /* =============================================================================
803 * TMfindBestRemoveTask
 * -- Compiled only when LEARNER_TRY_REMOVE is defined. Tries removing each
 *    current parent of argPtr.toId and returns the best removal as an
 *    OPERATION_REMOVE task (fromId == toId and score 0 if none improves).
804 * =============================================================================
807 TMfindBestRemoveTask (FindBestTaskArg argPtr)
809 int toId = argPtr.toId;
810 Learner learnerPtr = argPtr.learnerPtr;
811 Query[] queries = argPtr.queries;
812 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
813 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
814 int numTotalParent = argPtr.numTotalParent;
815 float basePenalty = argPtr.basePenalty;
816 float baseLogLikelihood = argPtr.baseLogLikelihood;
817 Vector_t origParentQueryVectorPtr = argPtr.aQueryVectorPtr;
820 Adtree adtreePtr = learnerPtr.adtreePtr;
821 Net netPtr = learnerPtr.netPtr;
822 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
// Snapshot of toId's current parents; candidates are taken from here.
824 TMpopulateParentQueryVector(netPtr, toId, queries, origParentQueryVectorPtr);
825 int numParent = origParentQueryVectorPtr.vector_getSize();
828 * Search all possible valid operations for better local log likelihood
831 int bestFromId = toId; // flag for not found
832 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
833 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// i is the candidate's position in the parent vector; fromId is its variable id.
836 for (i = 0; i < numParent; i++) {
838 Query queryPtr = (Query) (origParentQueryVectorPtr.vector_at(i));
839 int fromId = queryPtr.index;
842 * Create parent query (subset of parents since remove an edge)
845 parentQueryVectorPtr.vector_clear();
// NOTE(review): for this to be a strict subset, the parent at position i must
// be skipped. The filter line is not visible here. If it compares p with
// fromId, that mixes a position (p) with a variable id (fromId); the intended
// test is p != i.
847 for (int p = 0; p < numParent; p++) {
849 Query tmpqueryPtr = (Query) (origParentQueryVectorPtr.vector_at(p));
850 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
851 if(status == false) {
852 System.out.println("Assert failed: vector_pushBack to parentQuery in TMfindBestRemoveTask()");
856 } // create new parent query
862 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
863 if(status == false) {
864 System.out.println("Assert failed: while vector copy to query in TMfindBestRemoveTask()");
868 status = queryVectorPtr.vector_pushBack(queries[toId]);
869 if(status == false) {
870 System.out.println("Assert failed: while vector_pushBack to query in TMfindBestRemoveTask()");
874 queryVectorPtr.vector_sort();
877 * See if removing parent is better
880 float newLocalLogLikelihood =
881 computeLocalLogLikelihood(toId,
886 parentQueryVectorPtr);
888 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
889 bestLocalLogLikelihood = newLocalLogLikelihood;
896 * Return best task; Note: if none is better, fromId will equal toId
899 LearnerTask bestTask = new LearnerTask();
900 bestTask.op = OPERATION_REMOVE;
901 bestTask.fromId = bestFromId;
902 bestTask.toId = toId;
903 bestTask.score = 0.0f;
905 if (bestFromId != toId) {
906 int numRecord = adtreePtr.numRecord;
// Removing an edge reduces the total parent count by one.
907 float penalty = (numTotalParent - 1) * basePenalty;
// The "+ +" below is a harmless unary plus.
908 float logLikelihood = numRecord * (baseLogLikelihood +
909 + bestLocalLogLikelihood
910 - oldLocalLogLikelihood);
911 float bestScore = penalty + logLikelihood;
912 bestTask.score = bestScore;
917 #endif /* LEARNER_TRY_REMOVE */
920 #ifdef LEARNER_TRY_REVERSE
921 /* =============================================================================
922 * TMfindBestReverseTask
 * -- Compiled only when LEARNER_TRY_REVERSE is defined. For each parent
 *    fromId of argPtr.toId, scores reversing the edge fromId -> toId. The
 *    score combines toId losing the parent and fromId gaining toId as a
 *    parent. If a candidate is picked, its validity is checked against the
 *    net, and the result is returned as an OPERATION_REVERSE task.
923 * =============================================================================
926 TMfindBestReverseTask (FindBestTaskArg argPtr)
928 int toId = argPtr.toId;
929 Learner learnerPtr = argPtr.learnerPtr;
930 Query[] queries = argPtr.queries;
931 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
932 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
933 int numTotalParent = argPtr.numTotalParent;
934 float basePenalty = argPtr.basePenalty;
935 float baseLogLikelihood = argPtr.baseLogLikelihood;
936 BitMap visitedBitmapPtr = argPtr.bitmapPtr;
937 Queue workQueuePtr = argPtr.workQueuePtr;
938 Vector_t toOrigParentQueryVectorPtr = argPtr.aQueryVectorPtr;
939 Vector_t fromOrigParentQueryVectorPtr = argPtr.bQueryVectorPtr;
942 Adtree adtreePtr = learnerPtr.adtreePtr;
943 Net netPtr = learnerPtr.netPtr;
944 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
946 TMpopulateParentQueryVector(netPtr, toId, queries, toOrigParentQueryVectorPtr);
947 int numParent = toOrigParentQueryVectorPtr.vector_getSize();
950 * Search all possible valid operations for better local log likelihood
953 int bestFromId = toId; // flag for not found
954 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
955 float bestLocalLogLikelihood = oldLocalLogLikelihood;
958 for (int i = 0; i < numParent; i++) {
960 Query queryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(i));
961 fromId = queryPtr.index;
// NOTE(review): the threshold is reset each iteration to this candidate's own
// baseline (toId's + fromId's current local likelihoods), so the comparison
// below is per-candidate. bestFromId ends up as the last parent that beats its
// own baseline, not necessarily the best overall. Also, bestLocalLogLikelihood
// after the loop belongs to the last candidate, which may differ from bestFromId.
963 bestLocalLogLikelihood =
964 oldLocalLogLikelihood + localBaseLogLikelihoods[fromId];
// Current parents of fromId (arguments partly elided in this listing).
966 TMpopulateParentQueryVector(netPtr,
969 fromOrigParentQueryVectorPtr);
972 * Create parent query (subset of parents since remove an edge)
975 parentQueryVectorPtr.vector_clear();
// NOTE(review): the parent at position i must be skipped for this to be a
// subset. The filter line is not visible here; if it compares p with fromId,
// that mixes a position with a variable id.
977 for (int p = 0; p < numParent; p++) {
979 Query tmpqueryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(p));
980 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
981 if(status == false) {
982 System.out.println("Assert failed: while vector_pushBack parentQueryVectorPtr");
986 } // create new parent query
992 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
993 if(status == false) {
994 System.out.println("Assert failed: while vector_copy in TMfindBestReverseTask()");
998 status = queryVectorPtr.vector_pushBack(queries[toId]);
999 if(status == false) {
1000 System.out.println("Assert failed: while vector_pushBack in TMfindBestReverseTask()");
1004 queryVectorPtr.vector_sort();
1007 * Get log likelihood for removing parent from toId
1010 float newLocalLogLikelihood =
1011 computeLocalLogLikelihood(toId,
1016 parentQueryVectorPtr);
1019 * Get log likelihood for adding parent to fromId
// New parent query for fromId = its current parents + toId.
1022 status = Vector_t.vector_copy(parentQueryVectorPtr, fromOrigParentQueryVectorPtr);
1023 if(status == false) {
1024 System.out.println("Assert failed: while parentQuery vector_copy in TMfindBestReverseTask()");
1028 status = parentQueryVectorPtr.vector_pushBack(queries[toId]);
1029 if(status == false) {
1030 System.out.println("Assert failed: while vector_pushBack into parentQuery on TMfindBestReverseTask()");
1034 parentQueryVectorPtr.vector_sort();
// New query for fromId = that parent query + fromId itself.
1036 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
1037 if(status == false) {
1038 System.out.println("Assert failed: while vector_copy on TMfindBestReverseTask()");
1042 status = queryVectorPtr.vector_pushBack(queries[fromId]);
1043 if(status == false) {
1044 System.out.println("Assert failed: while vector_pushBack on TMfindBestReverseTask()");
1048 queryVectorPtr.vector_sort();
// Combined new local likelihood of toId and fromId after the reversal.
1050 newLocalLogLikelihood +=
1051 computeLocalLogLikelihood(fromId,
1056 parentQueryVectorPtr);
1062 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
1063 bestLocalLogLikelihood = newLocalLogLikelihood;
1064 bestFromId = fromId;
1067 } // for each parent
1070 * Check validity of best
// Temporarily drop the edge bestFromId -> toId and test whether a path still
// links them (arguments partly elided). If it does, adding toId -> bestFromId
// would presumably close a cycle, so the reversal is invalid. The edge is
// restored afterwards.
1073 if (bestFromId != toId) {
1074 boolean isTaskValid = true;
1075 netPtr.net_applyOperation(OPERATION_REMOVE, bestFromId, toId);
1076 if (netPtr.net_isPath(bestFromId,
1081 isTaskValid = false;
1083 netPtr.net_applyOperation(OPERATION_INSERT, bestFromId, toId);
1090 * Return best task; Note: if none is better, fromId will equal toId
1093 LearnerTask bestTask = new LearnerTask();
1094 bestTask.op = OPERATION_REVERSE;
1095 bestTask.fromId = bestFromId;
1096 bestTask.toId = toId;
1097 bestTask.score = 0.0f;
1099 if (bestFromId != toId) {
1100 float fromLocalLogLikelihood = localBaseLogLikelihoods[bestFromId];
1101 int numRecord = adtreePtr.numRecord;
// A reversal keeps the total parent count unchanged.
1102 float penalty = numTotalParent * basePenalty;
// Swap both variables' old local terms for the combined new one. The "+ +"
// is a harmless unary plus.
1103 float logLikelihood = numRecord * (baseLogLikelihood +
1104 + bestLocalLogLikelihood
1105 - oldLocalLogLikelihood
1106 - fromLocalLogLikelihood);
1107 float bestScore = penalty + logLikelihood;
1108 bestTask.score = bestScore;
1114 #endif /* LEARNER_TRY_REVERSE */
1117 /* =============================================================================
1120 * Note it is okay if the score is not exact, as we are relaxing the greedy
1121 * search. This means we do not need to communicate baseLogLikelihood across
1123 * =============================================================================
1126 learnStructure (int myId, int numThread, Learner learnerPtr)
1129 int numRecord = learnerPtr.adtreePtr.numRecord;
1131 float operationQualityFactor = learnerPtr.global_operationQualityFactor;
1133 BitMap visitedBitmapPtr = BitMap.bitmap_alloc(learnerPtr.adtreePtr.numVar);
1134 if(visitedBitmapPtr == null) {
1135 System.out.println("Assert failed: for bitmap alloc in learnStructure()");
1139 Queue workQueuePtr = Queue.queue_alloc(-1);
1140 if(workQueuePtr == null) {
1141 System.out.println("Assert failed: for vector alloc in learnStructure()");
1145 int numVar = learnerPtr.adtreePtr.numVar;
1146 Query[] queries = new Query[numVar];
1148 if(queries == null) {
1149 System.out.println("Assert failed: for queries alloc in learnStructure()");
1153 for (int v = 0; v < numVar; v++) {
1154 queries[v] = new Query();
1155 queries[v].index = v;
1156 queries[v].value = QUERY_VALUE_WILDCARD;
1159 float basePenalty = (float)(-0.5 * Math.log((double)numRecord));
1161 Vector_t queryVectorPtr = Vector_t.vector_alloc(1);
1162 if(queryVectorPtr == null) {
1163 System.out.println("Assert failed: for vector_alloc in learnStructure()");
1167 Vector_t parentQueryVectorPtr = Vector_t.vector_alloc(1);
1168 if(parentQueryVectorPtr == null) {
1169 System.out.println("Assert failed: for vector_alloc in learnStructure()");
1173 Vector_t aQueryVectorPtr = Vector_t.vector_alloc(1);
1174 if(aQueryVectorPtr == null) {
1175 System.out.println("Assert failed: for vector_alloc in learnStructure()");
1179 Vector_t bQueryVectorPtr = Vector_t.vector_alloc(1);
1180 if(bQueryVectorPtr == null) {
1181 System.out.println("Assert failed: for vector_alloc in learnStructure()");
1186 FindBestTaskArg arg = new FindBestTaskArg();
1187 arg.learnerPtr = learnerPtr;
1188 arg.queries = queries;
1189 arg.queryVectorPtr = queryVectorPtr;
1190 arg.parentQueryVectorPtr = parentQueryVectorPtr;
1191 arg.bitmapPtr = visitedBitmapPtr;
1192 arg.workQueuePtr = workQueuePtr;
1193 arg.aQueryVectorPtr = aQueryVectorPtr;
1194 arg.bQueryVectorPtr = bQueryVectorPtr;
1198 LearnerTask taskPtr;
1201 taskPtr = learnerPtr.TMpopTask(learnerPtr.taskListPtr);
1204 if (taskPtr == null) {
1208 int op = taskPtr.op;
1209 int fromId = taskPtr.fromId;
1210 int toId = taskPtr.toId;
1212 boolean isTaskValid;
1216 * Check if task is still valid
1220 if(op == OPERATION_INSERT) {
1221 if(learnerPtr.netPtr.net_hasEdge(fromId, toId) ||
1222 learnerPtr.netPtr.net_isPath(toId,
1227 isTaskValid = false;
1229 } else if (op == OPERATION_REMOVE) {
1230 // Can never create cycle, so always valid
1232 } else if (op == OPERATION_REVERSE) {
1233 // Temporarily remove edge for check
1234 learnerPtr.netPtr.net_applyOperation(OPERATION_REMOVE, fromId, toId);
1235 if(learnerPtr.netPtr.net_isPath(fromId,
1240 isTaskValid = false;
1242 learnerPtr.netPtr.net_applyOperation(OPERATION_INSERT, fromId, toId);
1244 System.out.println("Assert failed: We shouldn't get here in learnStructure()");
1250 System.out.println("[task] op= " + taskPtr.op + " from= " + taskPtr.fromId + " to= " +
1251 taskPtr.toId + " score= " + taskPtr.score + " valid= " + (isTaskValid ? "yes" : "no"));
1255 * Perform task: update graph and probabilities
1259 learnerPtr.netPtr.net_applyOperation(op, fromId, toId);
1264 float deltaLogLikelihood = 0.0f;
1267 float newBaseLogLikelihood;
1268 if(op == OPERATION_INSERT) {
1270 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1274 parentQueryVectorPtr);
1275 newBaseLogLikelihood =
1276 learnerPtr.computeLocalLogLikelihood(toId,
1277 learnerPtr.adtreePtr,
1281 parentQueryVectorPtr);
1282 float toLocalBaseLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
1283 deltaLogLikelihood +=
1284 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1285 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1289 int numTotalParent = learnerPtr.numTotalParent;
1290 learnerPtr.numTotalParent = numTotalParent + 1;
1293 #ifdef LEARNER_TRY_REMOVE
1294 } else if(op == OPERATION_REMOVE) {
1296 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1300 parentQueryVectorPtr);
1301 newBaseLogLikelihood =
1302 learnerPtr. computeLocalLogLikelihood(fromId,
1303 learnerPtr.adtreePtr,
1307 parentQueryVectorPtr);
1308 float fromLocalBaseLogLikelihood =
1309 learnerPtr.localBaseLogLikelihoods[fromId];
1310 deltaLogLikelihood +=
1311 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1312 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1316 int numTotalParent = learnerPtr.numTotalParent;
1317 learnerPtr.numTotalParent = numTotalParent - 1;
1320 #endif // LEARNER_TRY_REMOVE
1321 #ifdef LEARNER_TRY_REVERSE
1322 } else if(op == OPERATION_REVERSE) {
1324 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1328 parentQueryVectorPtr);
1329 newBaseLogLikelihood =
1330 learnerPtr.computeLocalLogLikelihood(fromId,
1331 learnerPtr.adtreePtr,
1335 parentQueryVectorPtr);
1336 float fromLocalBaseLogLikelihood =
1337 learnerPtr.localBaseLogLikelihoods[fromId];
1338 deltaLogLikelihood +=
1339 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1340 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1344 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1348 parentQueryVectorPtr);
1349 newBaseLogLikelihood =
1350 learnerPtr.computeLocalLogLikelihood(toId,
1351 learnerPtr.adtreePtr,
1355 parentQueryVectorPtr);
1356 float toLocalBaseLogLikelihood =
1357 learnerPtr.localBaseLogLikelihoods[toId];
1358 deltaLogLikelihood +=
1359 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1360 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1363 #endif // LEARNER_TRY_REVERSE
1365 System.out.println("Assert failed: We should not reach here in learnerPtr()");
1372 * Update/read globals
1375 float baseLogLikelihood;
1379 float oldBaseLogLikelihood = learnerPtr.baseLogLikelihood;
1380 float newBaseLogLikelihood = oldBaseLogLikelihood + deltaLogLikelihood;
1381 learnerPtr.baseLogLikelihood = newBaseLogLikelihood;
1382 baseLogLikelihood = newBaseLogLikelihood;
1383 numTotalParent = learnerPtr.numTotalParent;
1391 float baseScore = ((float)numTotalParent * basePenalty)
1392 + (numRecord * baseLogLikelihood);
1394 LearnerTask bestTask = new LearnerTask();
1395 bestTask.op = NUM_OPERATION;
1397 bestTask.fromId = -1;
1398 bestTask.score = baseScore;
1400 LearnerTask newTask = new LearnerTask();
1403 arg.numTotalParent = numTotalParent;
1404 arg.basePenalty = basePenalty;
1405 arg.baseLogLikelihood = baseLogLikelihood;
1408 newTask = learnerPtr.TMfindBestInsertTask(arg);
1411 if ((newTask.fromId != newTask.toId) &&
1412 (newTask.score > (bestTask.score / operationQualityFactor)))
1417 #ifdef LEARNER_TRY_REMOVE
1419 newTask = learnerPtr.TMfindBestRemoveTask(arg);
1422 if ((newTask.fromId != newTask.toId) &&
1423 (newTask.score > (bestTask.score / operationQualityFactor)))
1427 #endif // LEARNER_TRY_REMOVE
1429 #ifdef LEARNER_TRY_REVERSE
1431 newTask = learnerPtr.TMfindBestReverseTask(arg);
1434 if ((newTask.fromId != newTask.toId) &&
1435 (newTask.score > (bestTask.score / operationQualityFactor)))
1439 #endif // LEARNER_TRY_REVERSE
1441 if (bestTask.toId != -1) {
1442 LearnerTask[] tasks = learnerPtr.tasks;
1443 tasks[toId] = bestTask;
1445 learnerPtr.taskListPtr.list_insert(tasks[toId]);
1448 System.out.println("[new] op= " + bestTask.op + " from= "+ bestTask.fromId + " to= "+ bestTask.toId +
1449 " score= " + bestTask.score);
1455 visitedBitmapPtr.bitmap_free();
1456 workQueuePtr.queue_free();
1457 bQueryVectorPtr.vector_free();
1458 aQueryVectorPtr.vector_free();
1459 queryVectorPtr.vector_free();
1460 parentQueryVectorPtr.vector_free();
1465 /* =============================================================================
1467 * -- Call adtree_make before this
1468 * =============================================================================
1470 // learner_run: runs createTaskList() and then learnStructure(), forwarding myId, numThread and learnerPtr unchanged to both. Currently not called anywhere in the parallel code.
1472 learner_run (int myId, int numThread, Learner learnerPtr) // per the header above, adtree_make must have run first
1475 createTaskList(myId, numThread, learnerPtr); // phase 1; NOTE(review): presumably seeds learnerPtr.taskListPtr -- confirm in createTaskList
1478 learnStructure(myId, numThread, learnerPtr); // phase 2: repeatedly pops tasks from learnerPtr.taskListPtr and applies them to learnerPtr.netPtr
1482 /* =============================================================================
1484 * -- Score entire network
1485 * =============================================================================
1491 Vector_t queryVectorPtr = Vector_t.vector_alloc(1);
1492 Vector_t parentQueryVectorPtr = Vector_t.vector_alloc(1);
1494 int numVar = adtreePtr.numVar;
1495 Query[] queries = new Query[numVar];
1497 for (int v = 0; v < numVar; v++) {
1498 queries[v] = new Query();
1499 queries[v].index = v;
1500 queries[v].value = QUERY_VALUE_WILDCARD;
1503 int numTotalParent = 0;
1504 float logLikelihood = 0.0f;
1506 for (int v = 0; v < numVar; v++) {
1508 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(v);
1509 numTotalParent += parentIdListPtr.list_getSize();
1511 populateQueryVectors(netPtr,
1515 parentQueryVectorPtr);
1516 float localLogLikelihood = computeLocalLogLikelihood(v,
1521 parentQueryVectorPtr);
1522 logLikelihood += localLogLikelihood;
1525 queryVectorPtr.vector_free();
1526 parentQueryVectorPtr.vector_free();
1530 int numRecord = adtreePtr.numRecord;
1531 float penalty = (float)(-0.5f * (double)numTotalParent * Math.log((double)numRecord));
1532 float score = penalty + (float)numRecord * logLikelihood;
1538 /* =============================================================================
1540 * End of learner.java
1542 * =============================================================================