1 /* =============================================================================
4 * -- Learns structure of Bayesian net from data
6 * =============================================================================
8 * Copyright (C) Stanford University, 2006. All Rights Reserved.
10 * Ported to Java June 2009 Alokika Dash
11 * University of California, Irvine
14 * =============================================================================
16 * The penalized log-likelihood score (Friedman & Yakhini, 1996) is used to
17 * evaluate the "goodness" of a Bayesian net:
21 * -N_params * ln(R) / 2 + R * sum_j sum_v sum_{X_j} P((a_j = v), X_j) ln P(a_j = v | X_j)
27 * N_params total number of parents across all variables
29 * M number of variables
30 * X_j parents of the jth variable
31 * n_j number of attributes of the jth variable
34 * The second summation of X_j varies across all possible assignments to the
35 * values of the parents X_j.
39 * "local log likelihood" is P((a_j = v), X_j) ln P(a_j = v | X_j)
40 * "log likelihood" is everything to the right of the '+', i.e., "R ... X_j)"
41 * "base penalty" is -ln(R) / 2
42 * "penalty" is N_params * -ln(R) / 2
43 * "score" is the entire expression
45 * For more notes, refer to:
47 * A. Moore and M.-S. Lee. Cached sufficient statistics for efficient machine
48 * learning with large datasets. Journal of Artificial Intelligence Research 8
51 * =============================================================================
53 * The search strategy uses a combination of local and global structure search.
54 * Similar to the technique described in:
56 * D. M. Chickering, D. Heckerman, and C. Meek. A Bayesian approach to learning
57 * Bayesian networks with local structure. In Proceedings of Thirteenth
58 * Conference on Uncertainty in Artificial Intelligence (1997), pp. 80-89.
60 * =============================================================================
62 * For the license of bayes/sort.h and bayes/sort.c, please see the header
65 * ------------------------------------------------------------------------
67 * Unless otherwise noted, the following license applies to STAMP files:
69 * Copyright (c) 2007, Stanford University
70 * All rights reserved.
72 * Redistribution and use in source and binary forms, with or without
73 * modification, are permitted provided that the following conditions are
76 * * Redistributions of source code must retain the above copyright
77 * notice, this list of conditions and the following disclaimer.
79 * * Redistributions in binary form must reproduce the above copyright
80 * notice, this list of conditions and the following disclaimer in
81 * the documentation and/or other materials provided with the
84 * * Neither the name of Stanford University nor the names of its
85 * contributors may be used to endorse or promote products derived
86 * from this software without specific prior written permission.
88 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
89 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
90 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
91 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
92 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
93 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
94 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
95 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
96 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
97 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
98 * THE POSSIBILITY OF SUCH DAMAGE.
100 * =============================================================================
// NOTE(review): C-preprocessor constants in a Java source file -- presumably this
// file is run through cpp before compilation; confirm against the build setup.
104 #define CACHE_LINE_SIZE 64
105 #define QUERY_VALUE_WILDCARD -1
106 #define OPERATION_INSERT 0
107 #define OPERATION_REMOVE 1
108 #define OPERATION_REVERSE 2
109 #define NUM_OPERATION 3
111 public class Learner {
// Per-variable base log likelihood (no parents), indexed by variable id,
// and the running total across all variables.
114 float[] localBaseLogLikelihoods;
115 float baseLogLikelihood;
// Search tuning knobs; copied from constructor arguments below.
119 int global_insertPenalty;
120 int global_maxNumEdgeLearned;
121 float global_operationQualityFactor;
// Defaults (enclosing constructor not visible in this excerpt):
// -1 means no cap on the number of edges learned per variable.
125 global_maxNumEdgeLearned = -1;
126 global_insertPenalty = 1;
127 global_operationQualityFactor = 1.0F;
131 /* =============================================================================
// Constructor: allocates the net, likelihood table, per-variable task slots and
// the shared task list, then stores the search tuning parameters.
133 * =============================================================================
135 public Learner(Data dataPtr,
138 int global_insertPenalty,
139 int global_maxNumEdgeLearned,
140 float global_operationQualityFactor) {
// NOTE(review): adtreePtr is assigned but its parameter declaration is elided
// from this excerpt -- presumably a constructor argument; confirm.
141 this.adtreePtr = adtreePtr;
142 this.netPtr = new Net(dataPtr.numVar);
143 this.localBaseLogLikelihoods = new float[dataPtr.numVar];
144 this.baseLogLikelihood = 0.0f;
145 this.tasks = new LearnerTask[dataPtr.numVar];
146 this.taskListPtr = List.list_alloc();
147 this.numTotalParent = 0;
149 this.global_insertPenalty = global_insertPenalty;
150 this.global_maxNumEdgeLearned = global_maxNumEdgeLearned;
151 this.global_operationQualityFactor = global_operationQualityFactor;
// Release learner state: drop the likelihood array so it can be collected.
155 public void learner_free() {
158 localBaseLogLikelihoods=null;
164 /* =============================================================================
165 * computeSpecificLocalLogLikelihood
166 * -- Query vectors should not contain wildcards
// Computes one term of the local log likelihood:
//   P(query) * ln( count(query) / count(parentQuery) )
// where probabilities are estimated from AD-tree counts over numRecord records.
167 * =============================================================================
170 computeSpecificLocalLogLikelihood (Adtree adtreePtr,
171 Vector_t queryVectorPtr,
172 Vector_t parentQueryVectorPtr)
174 int count = adtreePtr.adtree_getCount(queryVectorPtr);
// NOTE(review): a zero-count early return is elided from this excerpt;
// without it the log below would be -inf -- confirm against the full source.
179 double probability = (double)count / (double)adtreePtr.numRecord;
180 int parentCount = adtreePtr.adtree_getCount(parentQueryVectorPtr);
183 float fval = (float)(probability * (Math.log((double)count/ (double)parentCount)));
189 /* =============================================================================
// Splits [min, max) into numThread chunks and writes thread id's sub-range
// into lss (start/stop). Chunk size is rounded to the nearest integer.
191 * =============================================================================
194 createPartition (int min, int max, int id, int n, LocalStartStop lss)
196 int range = max - min;
197 int chunk = Math.imax(1, ((range + n/2) / n)); // rounded
198 int start = min + chunk * id;
// Clamp the stop index so the last partition does not run past max.
203 stop = Math.imin(max, (start + chunk));
211 /* =============================================================================
// Phase 1 of learning: for this thread's slice of variables, compute each
// variable's base log likelihood (no parents), then for each variable find the
// single best edge to add and push it as an INSERT task onto taskListPtr.
213 * -- baseLogLikelihoods and taskListPtr are updated
214 * =============================================================================
217 createTaskList (int myId, int numThread, Learner learnerPtr)
// Two scratch queries: queries[0]/queries[1] always held in index order.
221 Query[] queries = new Query[2];
222 queries[0] = new Query();
223 queries[1] = new Query();
225 Vector_t queryVectorPtr = new Vector_t(2);
227 status = queryVectorPtr.vector_pushBack(queries[0]);
229 Query parentQuery = new Query();
230 Vector_t parentQueryVectorPtr = new Vector_t(1);
232 int numVar = learnerPtr.adtreePtr.numVar;
233 int numRecord = learnerPtr.adtreePtr.numRecord;
234 float baseLogLikelihood = 0.0f;
235 float penalty = (float)(-0.5f * Math.log((double)numRecord)); // only add 1 edge
237 LocalStartStop lss = new LocalStartStop();
238 learnerPtr.createPartition(0, numVar, myId, numThread, lss);
241 * Compute base log likelihood for each variable and total base loglikelihood
// Variables are binary: sum the likelihood terms for value 0 and value 1.
244 for (int v = lss.i_start; v < lss.i_stop; v++) {
246 float localBaseLogLikelihood = 0.0f;
247 queries[0].index = v;
249 queries[0].value = 0;
250 localBaseLogLikelihood +=
251 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
253 parentQueryVectorPtr);
255 queries[0].value = 1;
256 localBaseLogLikelihood +=
257 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
259 parentQueryVectorPtr);
261 learnerPtr.localBaseLogLikelihoods[v] = localBaseLogLikelihood;
262 baseLogLikelihood += localBaseLogLikelihood;
264 } // for each variable
// Accumulate this thread's partial sum into the shared global total.
267 float globalBaseLogLikelihood =
268 learnerPtr.baseLogLikelihood;
269 learnerPtr.baseLogLikelihood = (baseLogLikelihood + globalBaseLogLikelihood);
273 * For each variable, find if the addition of any edge _to_ it is better
276 status = parentQueryVectorPtr.vector_pushBack(parentQuery);
278 for (int v = lss.i_start; v < lss.i_stop; v++) {
280 //Compute base log likelihood for this variable
282 queries[0].index = v;
283 int bestLocalIndex = v;
284 float bestLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[v];
286 status = queryVectorPtr.vector_pushBack(queries[1]);
// Try every other variable vv as the candidate single parent of v.
288 for (int vv = 0; vv < numVar; vv++) {
293 parentQuery.index = vv;
// Keep (v, vv) in queries[] sorted by index, as the AD-tree requires.
295 queries[0].index = v;
296 queries[1].index = vv;
298 queries[0].index = vv;
299 queries[1].index = v;
// Sum over all four joint assignments of (v, vv); the parent value must
// track whichever slot holds vv, hence the (vv < v) selections below.
302 float newLocalLogLikelihood = 0.0f;
304 queries[0].value = 0;
305 queries[1].value = 0;
306 parentQuery.value = 0;
307 newLocalLogLikelihood +=
308 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
310 parentQueryVectorPtr);
312 queries[0].value = 0;
313 queries[1].value = 1;
314 parentQuery.value = ((vv < v) ? 0 : 1);
315 newLocalLogLikelihood +=
316 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
318 parentQueryVectorPtr);
320 queries[0].value = 1;
321 queries[1].value = 0;
322 parentQuery.value = ((vv < v) ? 1 : 0);
323 newLocalLogLikelihood +=
324 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
326 parentQueryVectorPtr);
328 queries[0].value = 1;
329 queries[1].value = 1;
330 parentQuery.value = 1;
331 newLocalLogLikelihood +=
332 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
334 parentQueryVectorPtr);
336 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
338 bestLocalLogLikelihood = newLocalLogLikelihood;
341 } // foreach other variable
343 queryVectorPtr.vector_popBack();
// If some parent improved on the no-parent likelihood, record an INSERT task.
345 if (bestLocalIndex != v) {
346 float logLikelihood = numRecord * (baseLogLikelihood +
347 + bestLocalLogLikelihood
348 - learnerPtr.localBaseLogLikelihoods[v]);
349 float score = penalty + logLikelihood;
351 learnerPtr.tasks[v] = new LearnerTask();
352 LearnerTask taskPtr = learnerPtr.tasks[v];
353 taskPtr.op = OPERATION_INSERT;
354 taskPtr.fromId = bestLocalIndex;
356 taskPtr.score = score;
358 status = learnerPtr.taskListPtr.list_insert(taskPtr);
363 } // for each variable
366 queryVectorPtr.clear();
367 parentQueryVectorPtr.clear();
// Debug dump of the task list (compiled only under TEST_LEARNER).
370 ListNode it = learnerPtr.taskListPtr.head;
372 while (it.nextPtr!=null) {
374 LearnerTask taskPtr = it.dataPtr;
375 System.out.println("[task] op= "+ taskPtr.op +" from= "+taskPtr.fromId+" to= " +taskPtr.toId+
376 " score= " + taskPtr.score);
378 #endif // TEST_LEARNER
382 /* =============================================================================
// Pops (and removes) the first task from the list; the list head is a sentinel,
// so the first real element is head.nextPtr.
384 * -- Returns null is list is empty
385 * =============================================================================
387 public LearnerTask TMpopTask (List taskListPtr)
389 LearnerTask taskPtr = null;
391 ListNode it = taskListPtr.head;
392 if (it.nextPtr!=null) {
394 taskPtr = it.dataPtr;
395 boolean status = taskListPtr.list_remove(taskPtr);
402 /* =============================================================================
403 * populateParentQuery
404 * -- Modifies contents of parentQueryVectorPtr
// Rebuilds parentQueryVectorPtr to hold the Query objects of id's current
// parents in the net (sentinel-headed parent id list).
405 * =============================================================================
408 populateParentQueryVector (Net netPtr,
411 Vector_t parentQueryVectorPtr)
413 parentQueryVectorPtr.vector_clear();
415 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
416 IntListNode it = parentIdListPtr.head;
417 while (it.nextPtr!=null) {
419 int parentId = it.dataPtr;
420 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
425 /* =============================================================================
426 * TMpopulateParentQuery
427 * -- Modifies contents of parentQueryVectorPtr
// Transactional (TM) variant of populateParentQueryVector; identical logic in
// this excerpt -- presumably the TM accessors differ in the full source.
428 * =============================================================================
431 TMpopulateParentQueryVector (Net netPtr,
434 Vector_t parentQueryVectorPtr)
436 parentQueryVectorPtr.vector_clear();
438 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
439 IntListNode it = parentIdListPtr.head;
441 while (it.nextPtr!=null) {
443 int parentId = it.dataPtr;
444 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
449 /* =============================================================================
450 * populateQueryVectors
451 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
// Builds parent query = parents(id); full query = parents(id) + id, sorted
// (AD-tree lookups require queries in ascending index order).
452 * =============================================================================
455 populateQueryVectors (Net netPtr,
458 Vector_t queryVectorPtr,
459 Vector_t parentQueryVectorPtr)
461 populateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
464 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
465 status = queryVectorPtr.vector_pushBack(queries[id]);
468 queryVectorPtr.vector_sort();
472 /* =============================================================================
473 * TMpopulateQueryVectors
474 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
// Transactional (TM) variant of populateQueryVectors: parent query = parents(id),
// full query = parents(id) + id, sorted for AD-tree lookups.
475 * =============================================================================
478 TMpopulateQueryVectors (Net netPtr,
481 Vector_t queryVectorPtr,
482 Vector_t parentQueryVectorPtr)
484 TMpopulateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
487 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
488 status = queryVectorPtr.vector_pushBack(queries[id]);
490 queryVectorPtr.vector_sort();
493 /* =============================================================================
494 * computeLocalLogLikelihoodHelper
495 * -- Recursive helper routine
// Recurses over parents i..numParent-1, enumerating every 0/1 assignment of the
// parent values; at the leaves it sums the specific local log likelihood terms.
496 * =============================================================================
499 computeLocalLogLikelihoodHelper (int i,
503 Vector_t queryVectorPtr,
504 Vector_t parentQueryVectorPtr)
// Base case: all parents assigned -- evaluate this full assignment.
506 if (i >= numParent) {
507 return computeSpecificLocalLogLikelihood(adtreePtr,
509 parentQueryVectorPtr);
512 float localLogLikelihood = 0.0f;
514 Query parentQueryPtr = (Query) (parentQueryVectorPtr.vector_at(i));
515 int parentIndex = parentQueryPtr.index;
// Branch on parent value 0, then 1; restore the wildcard before returning so
// the shared queries[] array is left unchanged for the caller.
517 queries[parentIndex].value = 0;
518 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
523 parentQueryVectorPtr);
525 queries[parentIndex].value = 1;
526 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
531 parentQueryVectorPtr);
533 queries[parentIndex].value = QUERY_VALUE_WILDCARD;
535 return localLogLikelihood;
539 /* =============================================================================
540 * computeLocalLogLikelihood
541 * -- Populate the query vectors before passing as args
// Sums the local log likelihood for variable id over both of its values (0, 1)
// and all parent assignments, via the recursive helper. Restores the wildcard
// on queries[id] before returning.
542 * =============================================================================
545 computeLocalLogLikelihood (int id,
549 Vector_t queryVectorPtr,
550 Vector_t parentQueryVectorPtr)
552 int numParent = parentQueryVectorPtr.vector_getSize();
553 float localLogLikelihood = 0.0f;
555 queries[id].value = 0;
556 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
561 parentQueryVectorPtr);
563 queries[id].value = 1;
564 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
569 parentQueryVectorPtr);
571 queries[id].value = QUERY_VALUE_WILDCARD;
573 return localLogLikelihood;
577 /* =============================================================================
578 * TMfindBestInsertTask
// Finds the best new edge (fromId -> toId) to INSERT for the given target
// variable: tries every candidate parent not already a parent and not creating
// a cycle, and returns a LearnerTask scored with the penalized log likelihood.
// If no insertion improves the likelihood, the returned task has fromId == toId.
579 * =============================================================================
582 TMfindBestInsertTask (FindBestTaskArg argPtr)
584 int toId = argPtr.toId;
585 Learner learnerPtr = argPtr.learnerPtr;
586 Query[] queries = argPtr.queries;
587 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
588 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
589 int numTotalParent = argPtr.numTotalParent;
590 float basePenalty = argPtr.basePenalty;
591 float baseLogLikelihood = argPtr.baseLogLikelihood;
592 BitMap invalidBitmapPtr = argPtr.bitmapPtr;
593 Queue workQueuePtr = argPtr.workQueuePtr;
594 Vector_t baseParentQueryVectorPtr = argPtr.aQueryVectorPtr;
595 Vector_t baseQueryVectorPtr = argPtr.bQueryVectorPtr;
598 Adtree adtreePtr = learnerPtr.adtreePtr;
599 Net netPtr = learnerPtr.netPtr;
601 TMpopulateParentQueryVector(netPtr, toId, queries, parentQueryVectorPtr);
604 * Create base query and parentQuery
// Base vectors hold toId's current parents (and toId itself for the full query),
// so each candidate can be appended without rebuilding from the net.
607 status = Vector_t.vector_copy(baseParentQueryVectorPtr, parentQueryVectorPtr);
609 status = Vector_t.vector_copy(baseQueryVectorPtr, baseParentQueryVectorPtr);
611 status = baseQueryVectorPtr.vector_pushBack(queries[toId]);
613 queryVectorPtr.vector_sort();
616 * Search all possible valid operations for better local log likelihood
619 int bestFromId = toId; // flag for not found
620 float oldLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
621 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Descendants of toId would create a cycle if made parents, so mark them invalid.
623 status = netPtr.net_findDescendants(toId, invalidBitmapPtr, workQueuePtr);
627 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(toId);
629 int maxNumEdgeLearned = global_maxNumEdgeLearned;
// Only search if under the per-variable edge cap (negative cap means unlimited).
631 if ((maxNumEdgeLearned < 0) ||
632 (parentIdListPtr.list_getSize() <= maxNumEdgeLearned))
635 IntListNode it = parentIdListPtr.head;
637 while(it.nextPtr!=null) {
639 int parentId = it.dataPtr;
640 invalidBitmapPtr.bitmap_set(parentId); // invalid since already have edge
// Iterate over every clear bit = every valid candidate parent.
643 while ((fromId = invalidBitmapPtr.bitmap_findClear((fromId + 1))) >= 0) {
645 if (fromId == toId) {
649 status = Vector_t.vector_copy(queryVectorPtr, baseQueryVectorPtr);
651 status = queryVectorPtr.vector_pushBack(queries[fromId]);
653 queryVectorPtr.vector_sort();
655 status = Vector_t.vector_copy(parentQueryVectorPtr, baseParentQueryVectorPtr);
656 status = parentQueryVectorPtr.vector_pushBack(queries[fromId]);
658 parentQueryVectorPtr.vector_sort();
660 float newLocalLogLikelihood =
661 computeLocalLogLikelihood(toId,
666 parentQueryVectorPtr);
668 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
669 bestLocalLogLikelihood = newLocalLogLikelihood;
673 } // foreach valid parent
675 } // if have not exceeded max number of edges to learn
678 * Return best task; Note: if none is better, fromId will equal toId
681 LearnerTask bestTask = new LearnerTask();
682 bestTask.op = OPERATION_INSERT;
683 bestTask.fromId = bestFromId;
684 bestTask.toId = toId;
685 bestTask.score = 0.0f;
687 if (bestFromId != toId) {
688 int numRecord = adtreePtr.numRecord;
689 int numParent = parentIdListPtr.list_getSize() + 1;
// Penalty grows with the parent count; insertPenalty weights the new edges.
691 (numTotalParent + numParent * global_insertPenalty) * basePenalty;
692 float logLikelihood = numRecord * (baseLogLikelihood +
693 + bestLocalLogLikelihood
694 - oldLocalLogLikelihood);
695 float bestScore = penalty + logLikelihood;
696 bestTask.score = bestScore;
702 #ifdef LEARNER_TRY_REMOVE
703 /* =============================================================================
704 * TMfindBestRemoveTask
// Finds the best existing edge (fromId -> toId) to REMOVE for the target
// variable: for each current parent, rebuilds the parent set without it and
// tests whether the local log likelihood improves. Returned task has
// fromId == toId when no removal helps.
705 * =============================================================================
708 TMfindBestRemoveTask (FindBestTaskArg argPtr)
710 int toId = argPtr.toId;
711 Learner learnerPtr = argPtr.learnerPtr;
712 Query[] queries = argPtr.queries;
713 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
714 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
715 int numTotalParent = argPtr.numTotalParent;
716 float basePenalty = argPtr.basePenalty;
717 float baseLogLikelihood = argPtr.baseLogLikelihood;
718 Vector_t origParentQueryVectorPtr = argPtr.aQueryVectorPtr;
721 Adtree adtreePtr = learnerPtr.adtreePtr;
722 Net netPtr = learnerPtr.netPtr;
723 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
725 TMpopulateParentQueryVector(netPtr, toId, queries, origParentQueryVectorPtr);
726 int numParent = origParentQueryVectorPtr.vector_getSize();
729 * Search all possible valid operations for better local log likelihood
732 int bestFromId = toId; // flag for not found
733 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
734 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Try removing each current parent in turn.
737 for (i = 0; i < numParent; i++) {
739 Query queryPtr = (Query) (origParentQueryVectorPtr.vector_at(i));
740 int fromId = queryPtr.index;
743 * Create parent query (subset of parents since remove an edge)
746 parentQueryVectorPtr.vector_clear();
// Copy all parents except the one being removed (the skip condition for
// index i is elided from this excerpt -- presumably p != i; confirm).
748 for (int p = 0; p < numParent; p++) {
750 Query tmpqueryPtr = (Query) (origParentQueryVectorPtr.vector_at(p));
751 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
753 } // create new parent query
759 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
760 status = queryVectorPtr.vector_pushBack(queries[toId]);
761 queryVectorPtr.vector_sort();
764 * See if removing parent is better
767 float newLocalLogLikelihood =
768 computeLocalLogLikelihood(toId,
773 parentQueryVectorPtr);
775 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
776 bestLocalLogLikelihood = newLocalLogLikelihood;
783 * Return best task; Note: if none is better, fromId will equal toId
786 LearnerTask bestTask = new LearnerTask();
787 bestTask.op = OPERATION_REMOVE;
788 bestTask.fromId = bestFromId;
789 bestTask.toId = toId;
790 bestTask.score = 0.0f;
792 if (bestFromId != toId) {
793 int numRecord = adtreePtr.numRecord;
// One fewer parent overall, hence (numTotalParent - 1) in the penalty term.
794 float penalty = (numTotalParent - 1) * basePenalty;
795 float logLikelihood = numRecord * (baseLogLikelihood +
796 + bestLocalLogLikelihood
797 - oldLocalLogLikelihood);
798 float bestScore = penalty + logLikelihood;
799 bestTask.score = bestScore;
804 #endif /* LEARNER_TRY_REMOVE */
807 #ifdef LEARNER_TRY_REVERSE
808 /* =============================================================================
809 * TMfindBestReverseTask
// Finds the best edge (fromId -> toId) to REVERSE: for each current parent,
// scores removing the edge from toId plus adding the reversed edge to fromId.
// Validity (no cycle after reversal) is checked by temporarily removing the
// edge from the net. Returned task has fromId == toId when nothing helps.
810 * =============================================================================
813 TMfindBestReverseTask (FindBestTaskArg argPtr)
815 int toId = argPtr.toId;
816 Learner learnerPtr = argPtr.learnerPtr;
817 Query[] queries = argPtr.queries;
818 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
819 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
820 int numTotalParent = argPtr.numTotalParent;
821 float basePenalty = argPtr.basePenalty;
822 float baseLogLikelihood = argPtr.baseLogLikelihood;
823 BitMap visitedBitmapPtr = argPtr.bitmapPtr;
824 Queue workQueuePtr = argPtr.workQueuePtr;
825 Vector_t toOrigParentQueryVectorPtr = argPtr.aQueryVectorPtr;
826 Vector_t fromOrigParentQueryVectorPtr = argPtr.bQueryVectorPtr;
829 Adtree adtreePtr = learnerPtr.adtreePtr;
830 Net netPtr = learnerPtr.netPtr;
831 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
833 TMpopulateParentQueryVector(netPtr, toId, queries, toOrigParentQueryVectorPtr);
834 int numParent = toOrigParentQueryVectorPtr.vector_getSize();
837 * Search all possible valid operations for better local log likelihood
840 int bestFromId = toId; // flag for not found
841 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
842 float bestLocalLogLikelihood = oldLocalLogLikelihood;
845 for (int i = 0; i < numParent; i++) {
847 Query queryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(i));
848 fromId = queryPtr.index;
// Baseline for comparison covers BOTH endpoints, since a reversal changes the
// local likelihood of toId and fromId.
850 bestLocalLogLikelihood =
851 oldLocalLogLikelihood + localBaseLogLikelihoods[fromId];
853 TMpopulateParentQueryVector(netPtr,
856 fromOrigParentQueryVectorPtr);
859 * Create parent query (subset of parents since remove an edge)
862 parentQueryVectorPtr.vector_clear();
// Copy toId's parents except the edge being reversed (skip condition elided
// from this excerpt -- presumably p != i; confirm against the full source).
864 for (int p = 0; p < numParent; p++) {
866 Query tmpqueryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(p));
867 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
869 } // create new parent query
875 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
876 status = queryVectorPtr.vector_pushBack(queries[toId]);
878 queryVectorPtr.vector_sort();
881 * Get log likelihood for removing parent from toId
884 float newLocalLogLikelihood =
885 computeLocalLogLikelihood(toId,
890 parentQueryVectorPtr);
893 * Get log likelihood for adding parent to fromId
896 status = Vector_t.vector_copy(parentQueryVectorPtr, fromOrigParentQueryVectorPtr);
897 status = parentQueryVectorPtr.vector_pushBack(queries[toId]);
899 parentQueryVectorPtr.vector_sort();
901 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
903 status = queryVectorPtr.vector_pushBack(queries[fromId]);
905 queryVectorPtr.vector_sort();
907 newLocalLogLikelihood +=
908 computeLocalLogLikelihood(fromId,
913 parentQueryVectorPtr);
919 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
920 bestLocalLogLikelihood = newLocalLogLikelihood;
927 * Check validity of best
// Temporarily remove the edge, test for a remaining path (which would make the
// reversal cyclic), then restore the edge.
930 if (bestFromId != toId) {
931 boolean isTaskValid = true;
932 netPtr.net_applyOperation(OPERATION_REMOVE, bestFromId, toId);
933 if (netPtr.net_isPath(bestFromId,
940 netPtr.net_applyOperation(OPERATION_INSERT, bestFromId, toId);
947 * Return best task; Note: if none is better, fromId will equal toId
950 LearnerTask bestTask = new LearnerTask();
951 bestTask.op = OPERATION_REVERSE;
952 bestTask.fromId = bestFromId;
953 bestTask.toId = toId;
954 bestTask.score = 0.0f;
956 if (bestFromId != toId) {
957 float fromLocalLogLikelihood = localBaseLogLikelihoods[bestFromId];
958 int numRecord = adtreePtr.numRecord;
// Reversal keeps the total parent count unchanged.
959 float penalty = numTotalParent * basePenalty;
960 float logLikelihood = numRecord * (baseLogLikelihood +
961 + bestLocalLogLikelihood
962 - oldLocalLogLikelihood
963 - fromLocalLogLikelihood);
964 float bestScore = penalty + logLikelihood;
965 bestTask.score = bestScore;
971 #endif /* LEARNER_TRY_REVERSE */
974 /* =============================================================================
// Phase 2 of learning (main loop): repeatedly pop the best pending task,
// re-validate it against the current graph, apply it, update the likelihood
// bookkeeping, then search insert/remove/reverse for a replacement task for
// the same target variable and re-enqueue the best one found.
977 * Note it is okay if the score is not exact, as we are relaxing the greedy
978 * search. This means we do not need to communicate baseLogLikelihood across
980 * =============================================================================
983 learnStructure (int myId, int numThread, Learner learnerPtr)
986 int numRecord = learnerPtr.adtreePtr.numRecord;
988 float operationQualityFactor = learnerPtr.global_operationQualityFactor;
990 BitMap visitedBitmapPtr = BitMap.bitmap_alloc(learnerPtr.adtreePtr.numVar);
992 Queue workQueuePtr = Queue.queue_alloc(-1);
// One Query per variable, initialised to the wildcard value (match any).
994 int numVar = learnerPtr.adtreePtr.numVar;
995 Query[] queries = new Query[numVar];
997 for (int v = 0; v < numVar; v++) {
998 queries[v] = new Query();
999 queries[v].index = v;
1000 queries[v].value = QUERY_VALUE_WILDCARD;
// Base penalty per parent: -ln(R)/2 (see file header for the scoring formula).
1003 float basePenalty = (float)(-0.5 * Math.log((double)numRecord));
1005 Vector_t queryVectorPtr = new Vector_t(1);
1006 Vector_t parentQueryVectorPtr = new Vector_t(1);
1007 Vector_t aQueryVectorPtr = new Vector_t(1);
1008 Vector_t bQueryVectorPtr = new Vector_t(1);
// Shared argument bundle reused by all TMfindBest*Task calls.
1010 FindBestTaskArg arg = new FindBestTaskArg();
1011 arg.learnerPtr = learnerPtr;
1012 arg.queries = queries;
1013 arg.queryVectorPtr = queryVectorPtr;
1014 arg.parentQueryVectorPtr = parentQueryVectorPtr;
1015 arg.bitmapPtr = visitedBitmapPtr;
1016 arg.workQueuePtr = workQueuePtr;
1017 arg.aQueryVectorPtr = aQueryVectorPtr;
1018 arg.bQueryVectorPtr = bQueryVectorPtr;
1022 LearnerTask taskPtr;
1025 taskPtr = learnerPtr.TMpopTask(learnerPtr.taskListPtr);
// Empty task list terminates the loop.
1028 if (taskPtr == null) {
1032 int op = taskPtr.op;
1033 int fromId = taskPtr.fromId;
1034 int toId = taskPtr.toId;
1036 boolean isTaskValid;
1040 * Check if task is still valid
// The graph may have changed since the task was scored; re-check that the
// operation neither duplicates an edge nor introduces a cycle.
1044 if(op == OPERATION_INSERT) {
1045 if(learnerPtr.netPtr.net_hasEdge(fromId, toId) ||
1046 learnerPtr.netPtr.net_isPath(toId,
1051 isTaskValid = false;
1053 } else if (op == OPERATION_REMOVE) {
1054 // Can never create cycle, so always valid
1056 } else if (op == OPERATION_REVERSE) {
1057 // Temporarily remove edge for check
1058 learnerPtr.netPtr.net_applyOperation(OPERATION_REMOVE, fromId, toId);
1059 if(learnerPtr.netPtr.net_isPath(fromId,
1064 isTaskValid = false;
1066 learnerPtr.netPtr.net_applyOperation(OPERATION_INSERT, fromId, toId);
1071 System.out.println("[task] op= " + taskPtr.op + " from= " + taskPtr.fromId + " to= " +
1072 taskPtr.toId + " score= " + taskPtr.score + " valid= " + (isTaskValid ? "yes" : "no"));
1076 * Perform task: update graph and probabilities
1080 learnerPtr.netPtr.net_applyOperation(op, fromId, toId);
// Track how the applied operation changes the total base log likelihood.
1085 float deltaLogLikelihood = 0.0f;
1088 float newBaseLogLikelihood;
1089 if(op == OPERATION_INSERT) {
// Recompute toId's local likelihood with its new parent set.
1091 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1095 parentQueryVectorPtr);
1096 newBaseLogLikelihood =
1097 learnerPtr.computeLocalLogLikelihood(toId,
1098 learnerPtr.adtreePtr,
1102 parentQueryVectorPtr);
1103 float toLocalBaseLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
1104 deltaLogLikelihood +=
1105 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1106 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1110 int numTotalParent = learnerPtr.numTotalParent;
1111 learnerPtr.numTotalParent = numTotalParent + 1;
1114 #ifdef LEARNER_TRY_REMOVE
1115 } else if(op == OPERATION_REMOVE) {
// Recompute fromId's local likelihood after the edge removal.
// NOTE(review): the C original updates fromId here; verify id choice against
// the elided TMpopulateQueryVectors arguments.
1117 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1121 parentQueryVectorPtr);
1122 newBaseLogLikelihood =
1123 learnerPtr. computeLocalLogLikelihood(fromId,
1124 learnerPtr.adtreePtr,
1128 parentQueryVectorPtr);
1129 float fromLocalBaseLogLikelihood =
1130 learnerPtr.localBaseLogLikelihoods[fromId];
1131 deltaLogLikelihood +=
1132 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1133 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1137 int numTotalParent = learnerPtr.numTotalParent;
1138 learnerPtr.numTotalParent = numTotalParent - 1;
1141 #endif // LEARNER_TRY_REMOVE
1142 #ifdef LEARNER_TRY_REVERSE
1143 } else if(op == OPERATION_REVERSE) {
// Reversal changes both endpoints: update fromId first, then toId.
1145 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1149 parentQueryVectorPtr);
1150 newBaseLogLikelihood =
1151 learnerPtr.computeLocalLogLikelihood(fromId,
1152 learnerPtr.adtreePtr,
1156 parentQueryVectorPtr);
1157 float fromLocalBaseLogLikelihood =
1158 learnerPtr.localBaseLogLikelihoods[fromId];
1159 deltaLogLikelihood +=
1160 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1161 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1165 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1169 parentQueryVectorPtr);
1170 newBaseLogLikelihood =
1171 learnerPtr.computeLocalLogLikelihood(toId,
1172 learnerPtr.adtreePtr,
1176 parentQueryVectorPtr);
1177 float toLocalBaseLogLikelihood =
1178 learnerPtr.localBaseLogLikelihoods[toId];
1179 deltaLogLikelihood +=
1180 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1181 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1184 #endif // LEARNER_TRY_REVERSE
1190 * Update/read globals
1193 float baseLogLikelihood;
1197 float oldBaseLogLikelihood = learnerPtr.baseLogLikelihood;
1198 float newBaseLogLikelihood = oldBaseLogLikelihood + deltaLogLikelihood;
1199 learnerPtr.baseLogLikelihood = newBaseLogLikelihood;
1200 baseLogLikelihood = newBaseLogLikelihood;
1201 numTotalParent = learnerPtr.numTotalParent;
// Score of the current network; candidate tasks must beat this (scaled by
// the operation quality factor) to be enqueued.
1209 float baseScore = ((float)numTotalParent * basePenalty)
1210 + (numRecord * baseLogLikelihood);
1212 LearnerTask bestTask = new LearnerTask();
1213 bestTask.op = NUM_OPERATION;
1215 bestTask.fromId = -1;
1216 bestTask.score = baseScore;
1218 LearnerTask newTask = new LearnerTask();
1221 arg.numTotalParent = numTotalParent;
1222 arg.basePenalty = basePenalty;
1223 arg.baseLogLikelihood = baseLogLikelihood;
// Try each operation type in turn and keep the best-scoring candidate.
1226 newTask = learnerPtr.TMfindBestInsertTask(arg);
1229 if ((newTask.fromId != newTask.toId) &&
1230 (newTask.score > (bestTask.score / operationQualityFactor)))
1235 #ifdef LEARNER_TRY_REMOVE
1237 newTask = learnerPtr.TMfindBestRemoveTask(arg);
1240 if ((newTask.fromId != newTask.toId) &&
1241 (newTask.score > (bestTask.score / operationQualityFactor)))
1245 #endif // LEARNER_TRY_REMOVE
1247 #ifdef LEARNER_TRY_REVERSE
1249 newTask = learnerPtr.TMfindBestReverseTask(arg);
1252 if ((newTask.fromId != newTask.toId) &&
1253 (newTask.score > (bestTask.score / operationQualityFactor)))
1257 #endif // LEARNER_TRY_REVERSE
// Re-enqueue the winning replacement task (toId sentinel -1 means none found).
1259 if (bestTask.toId != -1) {
1260 LearnerTask[] tasks = learnerPtr.tasks;
1261 tasks[toId] = bestTask;
1263 learnerPtr.taskListPtr.list_insert(tasks[toId]);
1266 System.out.println("[new] op= " + bestTask.op + " from= "+ bestTask.fromId + " to= "+ bestTask.toId +
1267 " score= " + bestTask.score);
// Cleanup of loop-local scratch structures.
1273 visitedBitmapPtr.bitmap_free();
1274 workQueuePtr.queue_free();
1275 bQueryVectorPtr.clear();
1276 aQueryVectorPtr.clear();
1277 queryVectorPtr.clear();
1278 parentQueryVectorPtr.clear();
1283 /* =============================================================================
// Driver: builds the initial task list, then runs the structure-learning loop.
1285 * -- Call adtree_make before this
1286 * =============================================================================
1288 //Is not called anywhere now parallel code
1290 learner_run (int myId, int numThread, Learner learnerPtr)
1293 createTaskList(myId, numThread, learnerPtr);
1296 learnStructure(myId, numThread, learnerPtr);
1300 /* =============================================================================
// Scores the ENTIRE current network: sums each variable's local log likelihood
// given its learned parents, then applies the parameter-count penalty
// -0.5 * numTotalParent * ln(numRecord). (Method signature elided from this
// excerpt -- presumably "float score ()"; confirm against the full source.)
1302 * -- Score entire network
1303 * =============================================================================
1309 Vector_t queryVectorPtr = new Vector_t(1);
1310 Vector_t parentQueryVectorPtr = new Vector_t(1);
1312 int numVar = adtreePtr.numVar;
1313 Query[] queries = new Query[numVar];
1315 for (int v = 0; v < numVar; v++) {
1316 queries[v] = new Query();
1317 queries[v].index = v;
1318 queries[v].value = QUERY_VALUE_WILDCARD;
1321 int numTotalParent = 0;
1322 float logLikelihood = 0.0f;
1324 for (int v = 0; v < numVar; v++) {
1326 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(v);
1327 numTotalParent += parentIdListPtr.list_getSize();
1329 populateQueryVectors(netPtr,
1333 parentQueryVectorPtr);
1334 float localLogLikelihood = computeLocalLogLikelihood(v,
1339 parentQueryVectorPtr);
1340 logLikelihood += localLogLikelihood;
1343 queryVectorPtr.clear();
1344 parentQueryVectorPtr.clear();
1348 int numRecord = adtreePtr.numRecord;
1349 float penalty = (float)(-0.5f * (double)numTotalParent * Math.log((double)numRecord));
1350 float score = penalty + (float)numRecord * logLikelihood;
1356 /* =============================================================================
1358 * End of learner.java
1360 * =============================================================================