1 /* =============================================================================
4 * -- Learns structure of Bayesian net from data
6 * =============================================================================
8 * Copyright (C) Stanford University, 2006. All Rights Reserved.
10 * Ported to Java June 2009 Alokika Dash
11 * University of California, Irvine
14 * =============================================================================
16 * The penalized log-likelihood score (Friedman & Yakhini, 1996) is used to
17 * evaluate the "goodness" of a Bayesian net:
21 * -N_params * ln(R) / 2 + R * SUM_j SUM_v SUM_{X_j} P((a_j = v), X_j) ln P(a_j = v | X_j)
27 * N_params total number of parents across all variables
29 * M number of variables
30 * X_j parents of the jth variable
31 * n_j number of attributes of the jth variable
34 * The second summation of X_j varies across all possible assignments to the
35 * values of the parents X_j.
39 * "local log likelihood" is P((a_j = v), X_j) ln P(a_j = v | X_j)
40 * "log likelihood" is everything to the right of the '+', i.e., "R ... X_j)"
41 * "base penalty" is -ln(R) / 2
42 * "penalty" is N_params * -ln(R) / 2
43 * "score" is the entire expression
45 * For more notes, refer to:
47 * A. Moore and M.-S. Lee. Cached sufficient statistics for efficient machine
48 * learning with large datasets. Journal of Artificial Intelligence Research 8
51 * =============================================================================
53 * The search strategy uses a combination of local and global structure search.
54 * Similar to the technique described in:
56 * D. M. Chickering, D. Heckerman, and C. Meek. A Bayesian approach to learning
57 * Bayesian networks with local structure. In Proceedings of Thirteenth
58 * Conference on Uncertainty in Artificial Intelligence (1997), pp. 80-89.
60 * =============================================================================
62 * For the license of bayes/sort.h and bayes/sort.c, please see the header
65 * ------------------------------------------------------------------------
67 * Unless otherwise noted, the following license applies to STAMP files:
69 * Copyright (c) 2007, Stanford University
70 * All rights reserved.
72 * Redistribution and use in source and binary forms, with or without
73 * modification, are permitted provided that the following conditions are
76 * * Redistributions of source code must retain the above copyright
77 * notice, this list of conditions and the following disclaimer.
79 * * Redistributions in binary form must reproduce the above copyright
80 * notice, this list of conditions and the following disclaimer in
81 * the documentation and/or other materials provided with the
84 * * Neither the name of Stanford University nor the names of its
85 * contributors may be used to endorse or promote products derived
86 * from this software without specific prior written permission.
88 * THIS SOFTWARE IS PROVIDED BY STANFORD UNIVERSITY ``AS IS'' AND ANY
89 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
90 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
91 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL STANFORD UNIVERSITY BE LIABLE
92 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
93 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
94 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
95 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
96 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
97 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
98 * THE POSSIBILITY OF SUCH DAMAGE.
100 * =============================================================================
// C-preprocessor constants retained from the original STAMP C source; this
// file mixes Java syntax with #define/#ifdef and assumes a preprocessing pass.
104 #define CACHE_LINE_SIZE 64
// Wildcard query value: the query matches any value of its variable.
105 #define QUERY_VALUE_WILDCARD -1
// Edge operations applied to the net during structure search.
106 #define OPERATION_INSERT 0
107 #define OPERATION_REMOVE 1
108 #define OPERATION_REVERSE 2
// Count of operations; also used as a placeholder op for a "no-change" task
// (see bestTask.op = NUM_OPERATION in learnStructure).
109 #define NUM_OPERATION 3
111 public class Learner {
// Per-variable local log likelihood computed with the variable's current
// parent set (initially: no parents).
114 float[] localBaseLogLikelihoods;
// Running sum of localBaseLogLikelihoods over all variables.
115 float baseLogLikelihood;
// Multiplier charged per parent of the target variable when scoring an
// insertion (see TMfindBestInsertTask).
119 int global_insertPenalty;
// Cap on edges learned per variable; a negative value means unlimited
// (see the maxNumEdgeLearned < 0 check in TMfindBestInsertTask).
120 int global_maxNumEdgeLearned;
// Divisor applied to the current best score when deciding whether a newly
// found task is good enough to queue (see learnStructure).
121 float global_operationQualityFactor;
// Default tuning values. NOTE(review): these assignments are a fragment of a
// no-argument initializer whose signature is elided in this chunk, and the
// declarations of adtreePtr/netPtr/tasks/taskListPtr/numTotalParent used by
// the constructor below are likewise elided.
125 global_maxNumEdgeLearned = -1;
126 global_insertPenalty = 1;
127 global_operationQualityFactor = 1.0F;
131 /* =============================================================================
133 * =============================================================================
// Constructor: binds the learner to its data/AD-tree and records the search
// tuning knobs. NOTE(review): parameter lines 136-137 (presumably
// "Adtree adtreePtr,") are elided in this chunk -- implied by the
// "this.adtreePtr = adtreePtr" assignment below.
135 public Learner(Data dataPtr,
138 int global_insertPenalty,
139 int global_maxNumEdgeLearned,
140 float global_operationQualityFactor) {
141 this.adtreePtr = adtreePtr;
// Start from an empty net over all variables, with no learned likelihoods.
142 this.netPtr = new Net(dataPtr.numVar);
143 this.localBaseLogLikelihoods = new float[dataPtr.numVar];
144 this.baseLogLikelihood = 0.0f;
// One candidate-task slot per variable, plus the shared task work list.
145 this.tasks = new LearnerTask[dataPtr.numVar];
146 this.taskListPtr = List.list_alloc();
147 this.numTotalParent = 0;
149 this.global_insertPenalty = global_insertPenalty;
150 this.global_maxNumEdgeLearned = global_maxNumEdgeLearned;
151 this.global_operationQualityFactor = global_operationQualityFactor;
// Release the per-variable likelihood array; remaining members are left to
// the garbage collector.
155 public void learner_free() {
158 localBaseLogLikelihoods=null;
164 /* =============================================================================
165 * computeSpecificLocalLogLikelihood
166 * -- Query vectors should not contain wildcards
167 * =============================================================================
170 computeSpecificLocalLogLikelihood (Adtree adtreePtr,
171 Vector_t queryVectorPtr,
172 Vector_t parentQueryVectorPtr)
174 int count = adtreePtr.adtree_getCount(queryVectorPtr);
179 double probability = (double)count / (double)adtreePtr.numRecord;
180 int parentCount = adtreePtr.adtree_getCount(parentQueryVectorPtr);
183 float fval = (float)(probability * (Math.log((double)count/ (double)parentCount)));
189 /* =============================================================================
191 * =============================================================================
194 createPartition (int min, int max, int id, int n, LocalStartStop lss)
196 int range = max - min;
197 int chunk = Math.imax(1, ((range + n/2) / n)); // rounded
198 int start = min + chunk * id;
203 stop = Math.imin(max, (start + chunk));
210 /* =============================================================================
// createTaskList
// -- Computes each variable's parent-free log likelihood, then seeds the task
//    list with the best single-edge insertion found for each variable.
212 * -- baseLogLikelihoods and taskListPtr are updated
213 * =============================================================================
216 createTaskList (int myId, int numThread, Learner learnerPtr)
// Two reusable query slots; they are re-pointed at the pair of variables
// being scored on each iteration below.
220 Query[] queries = new Query[2];
221 queries[0] = new Query();
222 queries[1] = new Query();
224 Vector_t queryVectorPtr = new Vector_t(2);
226 status = queryVectorPtr.vector_pushBack(queries[0]);
228 Query parentQuery = new Query();
229 Vector_t parentQueryVectorPtr = new Vector_t(1);
231 int numVar = learnerPtr.adtreePtr.numVar;
232 int numRecord = learnerPtr.adtreePtr.numRecord;
233 float baseLogLikelihood = 0.0f;
234 float penalty = (float)(-0.5f * Math.log((double)numRecord)); // only add 1 edge
// Each thread processes the contiguous slice [lss.i_start, lss.i_stop).
236 LocalStartStop lss = new LocalStartStop();
237 learnerPtr.createPartition(0, numVar, myId, numThread, lss);
240 * Compute base log likelihood for each variable and total base loglikelihood
243 for (int v = lss.i_start; v < lss.i_stop; v++) {
245 float localBaseLogLikelihood = 0.0f;
246 queries[0].index = v;
// Sum over both values of v -- variables only ever take values 0 and 1 here.
248 queries[0].value = 0;
249 localBaseLogLikelihood +=
250 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
252 parentQueryVectorPtr);
254 queries[0].value = 1;
255 localBaseLogLikelihood +=
256 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
258 parentQueryVectorPtr);
260 learnerPtr.localBaseLogLikelihoods[v] = localBaseLogLikelihood;
261 baseLogLikelihood += localBaseLogLikelihood;
263 } // for each variable
// Fold this thread's partial sum into the shared total.
266 float globalBaseLogLikelihood =
267 learnerPtr.baseLogLikelihood;
268 learnerPtr.baseLogLikelihood = (baseLogLikelihood + globalBaseLogLikelihood);
272 * For each variable, find if the addition of any edge _to_ it is better
275 status = parentQueryVectorPtr.vector_pushBack(parentQuery);
277 for (int v = lss.i_start; v < lss.i_stop; v++) {
279 //Compute base log likelihood for this variable
281 queries[0].index = v;
282 int bestLocalIndex = v;
283 float bestLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[v];
285 status = queryVectorPtr.vector_pushBack(queries[1]);
287 for (int vv = 0; vv < numVar; vv++) {
// Keep the (v, vv) query pair ordered by variable index; which slot holds
// the candidate parent vv depends on whether vv < v.
292 parentQuery.index = vv;
294 queries[0].index = v;
295 queries[1].index = vv;
297 queries[0].index = vv;
298 queries[1].index = v;
// Evaluate all four joint assignments of (v, vv); parentQuery mirrors vv's
// value in whichever slot holds it per the ordering above.
301 float newLocalLogLikelihood = 0.0f;
303 queries[0].value = 0;
304 queries[1].value = 0;
305 parentQuery.value = 0;
306 newLocalLogLikelihood +=
307 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
309 parentQueryVectorPtr);
311 queries[0].value = 0;
312 queries[1].value = 1;
313 parentQuery.value = ((vv < v) ? 0 : 1);
314 newLocalLogLikelihood +=
315 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
317 parentQueryVectorPtr);
319 queries[0].value = 1;
320 queries[1].value = 0;
321 parentQuery.value = ((vv < v) ? 1 : 0);
322 newLocalLogLikelihood +=
323 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
325 parentQueryVectorPtr);
327 queries[0].value = 1;
328 queries[1].value = 1;
329 parentQuery.value = 1;
330 newLocalLogLikelihood +=
331 learnerPtr.computeSpecificLocalLogLikelihood(learnerPtr.adtreePtr,
333 parentQueryVectorPtr);
// Track the candidate parent that most improves v's local log likelihood.
// NOTE(review): the bestLocalIndex = vv assignment (line ~336/338) appears
// to be elided in this chunk.
335 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
337 bestLocalLogLikelihood = newLocalLogLikelihood;
340 } // foreach other variable
342 queryVectorPtr.vector_popBack();
// If some parent beat the parent-free likelihood, queue an INSERT task scored
// as penalty + R * (total log likelihood with the new edge included).
344 if (bestLocalIndex != v) {
345 float logLikelihood = numRecord * (baseLogLikelihood +
346 + bestLocalLogLikelihood
347 - learnerPtr.localBaseLogLikelihoods[v]);
348 float score = penalty + logLikelihood;
350 learnerPtr.tasks[v] = new LearnerTask();
351 LearnerTask taskPtr = learnerPtr.tasks[v];
352 taskPtr.op = OPERATION_INSERT;
353 taskPtr.fromId = bestLocalIndex;
355 taskPtr.score = score;
357 status = learnerPtr.taskListPtr.list_insert(taskPtr);
362 } // for each variable
365 queryVectorPtr.clear();
366 parentQueryVectorPtr.clear();
// Debug dump of the seeded task list (TEST_LEARNER builds only).
369 ListNode it = learnerPtr.taskListPtr.head;
371 while (it.nextPtr!=null) {
373 LearnerTask taskPtr = it.dataPtr;
374 System.out.println("[task] op= "+ taskPtr.op +" from= "+taskPtr.fromId+" to= " +taskPtr.toId+
375 " score= " + taskPtr.score);
377 #endif // TEST_LEARNER
381 /* =============================================================================
383 * -- Returns null is list is empty
384 * =============================================================================
386 public LearnerTask TMpopTask (List taskListPtr)
388 LearnerTask taskPtr = null;
390 ListNode it = taskListPtr.head;
391 if (it.nextPtr!=null) {
393 taskPtr = it.dataPtr;
394 boolean status = taskListPtr.list_remove(taskPtr);
401 /* =============================================================================
402 * populateParentQuery
403 * -- Modifies contents of parentQueryVectorPtr
// Rebuilds parentQueryVectorPtr to hold the query objects of id's current
// parents, as recorded in the net.
404 * =============================================================================
407 populateParentQueryVector (Net netPtr,
410 Vector_t parentQueryVectorPtr)
412 parentQueryVectorPtr.vector_clear();
414 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
// Walk the parent id list; head appears to be a sentinel (emptiness checked
// via nextPtr). NOTE(review): the list-advance statement (line 417) is
// elided in this chunk.
415 IntListNode it = parentIdListPtr.head;
416 while (it.nextPtr!=null) {
418 int parentId = it.dataPtr;
419 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
424 /* =============================================================================
425 * TMpopulateParentQuery
426 * -- Modifies contents of parentQueryVectorPtr
// Transactional-memory variant of populateParentQueryVector; same contract:
// rebuilds parentQueryVectorPtr with the query objects of id's parents.
427 * =============================================================================
430 TMpopulateParentQueryVector (Net netPtr,
433 Vector_t parentQueryVectorPtr)
435 parentQueryVectorPtr.vector_clear();
437 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(id);
// NOTE(review): as in the non-TM variant, the list-advance statement is
// elided in this chunk (numbering gap at line 441).
438 IntListNode it = parentIdListPtr.head;
440 while (it.nextPtr!=null) {
442 int parentId = it.dataPtr;
443 boolean status = parentQueryVectorPtr.vector_pushBack(queries[parentId]);
448 /* =============================================================================
449 * populateQueryVectors
450 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
// After this call: parentQueryVectorPtr = queries of id's parents, and
// queryVectorPtr = those same queries plus queries[id], sorted.
451 * =============================================================================
// NOTE(review): the id and queries parameter lines (455-456) are elided in
// this chunk; they are implied by the calls below.
454 populateQueryVectors (Net netPtr,
457 Vector_t queryVectorPtr,
458 Vector_t parentQueryVectorPtr)
460 populateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
463 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
464 status = queryVectorPtr.vector_pushBack(queries[id]);
// Sorted order is maintained on every query vector handed to the AD-tree --
// presumably required by adtree_getCount; confirm against Adtree.
467 queryVectorPtr.vector_sort();
471 /* =============================================================================
472 * TMpopulateQueryVectors
473 * -- Modifies contents of queryVectorPtr and parentQueryVectorPtr
// Transactional-memory variant of populateQueryVectors; same postcondition:
// parent queries in parentQueryVectorPtr, parents + id (sorted) in
// queryVectorPtr.
474 * =============================================================================
477 TMpopulateQueryVectors (Net netPtr,
480 Vector_t queryVectorPtr,
481 Vector_t parentQueryVectorPtr)
483 TMpopulateParentQueryVector(netPtr, id, queries, parentQueryVectorPtr);
486 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
487 status = queryVectorPtr.vector_pushBack(queries[id]);
489 queryVectorPtr.vector_sort();
492 /* =============================================================================
493 * computeLocalLogLikelihoodHelper
494 * -- Recursive helper routine
// Enumerates every assignment to parents i..numParent-1 (each parent set to
// 0 then 1), summing computeSpecificLocalLogLikelihood over the complete
// assignments. Mutates the shared query objects in place and restores them.
495 * =============================================================================
498 computeLocalLogLikelihoodHelper (int i,
502 Vector_t queryVectorPtr,
503 Vector_t parentQueryVectorPtr)
// Base case: all parents have concrete values; score this full assignment.
505 if (i >= numParent) {
506 return computeSpecificLocalLogLikelihood(adtreePtr,
508 parentQueryVectorPtr);
511 float localLogLikelihood = 0.0f;
// Branch on both values of the i-th parent, recursing on the rest.
513 Query parentQueryPtr = (Query) (parentQueryVectorPtr.vector_at(i));
514 int parentIndex = parentQueryPtr.index;
516 queries[parentIndex].value = 0;
517 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
522 parentQueryVectorPtr);
524 queries[parentIndex].value = 1;
525 localLogLikelihood += computeLocalLogLikelihoodHelper((i + 1),
530 parentQueryVectorPtr);
// Restore the wildcard so the shared query object is left unchanged.
532 queries[parentIndex].value = QUERY_VALUE_WILDCARD;
534 return localLogLikelihood;
538 /* =============================================================================
539 * computeLocalLogLikelihood
540 * -- Populate the query vectors before passing as args
// Sums the local log likelihood of variable `id` over both of its values and
// (via the recursive helper) every assignment to its parents.
541 * =============================================================================
544 computeLocalLogLikelihood (int id,
548 Vector_t queryVectorPtr,
549 Vector_t parentQueryVectorPtr)
551 int numParent = parentQueryVectorPtr.vector_getSize();
552 float localLogLikelihood = 0.0f;
554 queries[id].value = 0;
555 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
560 parentQueryVectorPtr);
562 queries[id].value = 1;
563 localLogLikelihood += computeLocalLogLikelihoodHelper(0,
568 parentQueryVectorPtr);
// Reset id's query to wildcard before returning (the query objects are
// shared across callers).
570 queries[id].value = QUERY_VALUE_WILDCARD;
572 return localLogLikelihood;
576 /* =============================================================================
577 * TMfindBestInsertTask
// For variable toId, finds the single new parent whose insertion most
// improves toId's local log likelihood and returns it as an OPERATION_INSERT
// task; fromId == toId (and score 0) signals that no insertion helps.
578 * =============================================================================
581 TMfindBestInsertTask (FindBestTaskArg argPtr)
583 int toId = argPtr.toId;
584 Learner learnerPtr = argPtr.learnerPtr;
585 Query[] queries = argPtr.queries;
586 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
587 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
588 int numTotalParent = argPtr.numTotalParent;
589 float basePenalty = argPtr.basePenalty;
590 float baseLogLikelihood = argPtr.baseLogLikelihood;
591 BitMap invalidBitmapPtr = argPtr.bitmapPtr;
592 Queue workQueuePtr = argPtr.workQueuePtr;
593 Vector_t baseParentQueryVectorPtr = argPtr.aQueryVectorPtr;
594 Vector_t baseQueryVectorPtr = argPtr.bQueryVectorPtr;
597 Adtree adtreePtr = learnerPtr.adtreePtr;
598 Net netPtr = learnerPtr.netPtr;
600 TMpopulateParentQueryVector(netPtr, toId, queries, parentQueryVectorPtr);
603 * Create base query and parentQuery
// Snapshot toId's current parent query (base*) so each candidate below can
// be formed by appending a single extra parent.
606 status = Vector_t.vector_copy(baseParentQueryVectorPtr, parentQueryVectorPtr);
608 status = Vector_t.vector_copy(baseQueryVectorPtr, baseParentQueryVectorPtr);
610 status = baseQueryVectorPtr.vector_pushBack(queries[toId]);
612 queryVectorPtr.vector_sort();
615 * Search all possible valid operations for better local log likelihood
618 int bestFromId = toId; // flag for not found
619 float oldLocalLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
620 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Mark every descendant of toId invalid: an edge from a descendant would
// close a cycle.
622 status = netPtr.net_findDescendants(toId, invalidBitmapPtr, workQueuePtr);
626 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(toId);
628 int maxNumEdgeLearned = global_maxNumEdgeLearned;
// Respect the per-variable edge budget; a negative budget means unlimited.
630 if ((maxNumEdgeLearned < 0) ||
631 (parentIdListPtr.list_getSize() <= maxNumEdgeLearned))
// Also mark current parents invalid -- those edges already exist.
// NOTE(review): the list-advance statement in this walk is elided.
634 IntListNode it = parentIdListPtr.head;
636 while(it.nextPtr!=null) {
638 int parentId = it.dataPtr;
639 invalidBitmapPtr.bitmap_set(parentId); // invalid since already have edge
// Scan every variable still clear in the bitmap as a candidate parent.
// NOTE(review): fromId's declaration/initialization (line ~641) is elided.
642 while ((fromId = invalidBitmapPtr.bitmap_findClear((fromId + 1))) >= 0) {
644 if (fromId == toId) {
648 status = Vector_t.vector_copy(queryVectorPtr, baseQueryVectorPtr);
650 status = queryVectorPtr.vector_pushBack(queries[fromId]);
652 queryVectorPtr.vector_sort();
654 status = Vector_t.vector_copy(parentQueryVectorPtr, baseParentQueryVectorPtr);
655 status = parentQueryVectorPtr.vector_pushBack(queries[fromId]);
657 parentQueryVectorPtr.vector_sort();
659 float newLocalLogLikelihood =
660 computeLocalLogLikelihood(toId,
665 parentQueryVectorPtr);
// NOTE(review): the bestFromId = fromId assignment appears to be elided
// here (numbering gap after line 668).
667 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
668 bestLocalLogLikelihood = newLocalLogLikelihood;
672 } // foreach valid parent
674 } // if have not exceeded max number of edges to learn
677 * Return best task; Note: if none is better, fromId will equal toId
680 LearnerTask bestTask = new LearnerTask();
681 bestTask.op = OPERATION_INSERT;
682 bestTask.fromId = bestFromId;
683 bestTask.toId = toId;
684 bestTask.score = 0.0f;
686 if (bestFromId != toId) {
687 int numRecord = adtreePtr.numRecord;
// Parent count after the insertion; global_insertPenalty scales how heavily
// toId's parents are charged. NOTE(review): the "float penalty =" prefix
// (line 689) is elided before the expression below.
688 int numParent = parentIdListPtr.list_getSize() + 1;
690 (numTotalParent + numParent * global_insertPenalty) * basePenalty;
691 float logLikelihood = numRecord * (baseLogLikelihood +
692 + bestLocalLogLikelihood
693 - oldLocalLogLikelihood);
694 float bestScore = penalty + logLikelihood;
695 bestTask.score = bestScore;
701 #ifdef LEARNER_TRY_REMOVE
702 /* =============================================================================
703 * TMfindBestRemoveTask
// For variable toId, finds the existing parent whose removal most improves
// toId's local log likelihood and returns it as an OPERATION_REMOVE task;
// fromId == toId (and score 0) signals that no removal helps.
704 * =============================================================================
707 TMfindBestRemoveTask (FindBestTaskArg argPtr)
709 int toId = argPtr.toId;
710 Learner learnerPtr = argPtr.learnerPtr;
711 Query[] queries = argPtr.queries;
712 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
713 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
714 int numTotalParent = argPtr.numTotalParent;
715 float basePenalty = argPtr.basePenalty;
716 float baseLogLikelihood = argPtr.baseLogLikelihood;
717 Vector_t origParentQueryVectorPtr = argPtr.aQueryVectorPtr;
720 Adtree adtreePtr = learnerPtr.adtreePtr;
721 Net netPtr = learnerPtr.netPtr;
722 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
724 TMpopulateParentQueryVector(netPtr, toId, queries, origParentQueryVectorPtr);
725 int numParent = origParentQueryVectorPtr.vector_getSize();
728 * Search all possible valid operations for better local log likelihood
731 int bestFromId = toId; // flag for not found
732 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
733 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Try deleting each current parent in turn.
736 for (i = 0; i < numParent; i++) {
738 Query queryPtr = (Query) (origParentQueryVectorPtr.vector_at(i));
739 int fromId = queryPtr.index;
742 * Create parent query (subset of parents since remove an edge)
745 parentQueryVectorPtr.vector_clear();
// NOTE(review): the guard skipping p == i (the removed parent) appears to be
// elided in this chunk -- as shown, every parent would be copied back.
747 for (int p = 0; p < numParent; p++) {
749 Query tmpqueryPtr = (Query) (origParentQueryVectorPtr.vector_at(p));
750 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
752 } // create new parent query
758 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
759 status = queryVectorPtr.vector_pushBack(queries[toId]);
760 queryVectorPtr.vector_sort();
763 * See if removing parent is better
766 float newLocalLogLikelihood =
767 computeLocalLogLikelihood(toId,
772 parentQueryVectorPtr);
// NOTE(review): the bestFromId = fromId assignment appears to be elided
// here (numbering gap after line 775).
774 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
775 bestLocalLogLikelihood = newLocalLogLikelihood;
782 * Return best task; Note: if none is better, fromId will equal toId
785 LearnerTask bestTask = new LearnerTask();
786 bestTask.op = OPERATION_REMOVE;
787 bestTask.fromId = bestFromId;
788 bestTask.toId = toId;
789 bestTask.score = 0.0f;
791 if (bestFromId != toId) {
792 int numRecord = adtreePtr.numRecord;
// One fewer parameter after the removal, hence (numTotalParent - 1).
793 float penalty = (numTotalParent - 1) * basePenalty;
794 float logLikelihood = numRecord * (baseLogLikelihood +
795 + bestLocalLogLikelihood
796 - oldLocalLogLikelihood);
797 float bestScore = penalty + logLikelihood;
798 bestTask.score = bestScore;
803 #endif /* LEARNER_TRY_REMOVE */
806 #ifdef LEARNER_TRY_REVERSE
807 /* =============================================================================
808 * TMfindBestReverseTask
// For variable toId, tries reversing each incoming edge (parent -> toId
// becomes toId -> parent) and returns the best valid reversal as an
// OPERATION_REVERSE task; fromId == toId signals none is better, and a
// best candidate is discarded if the reversal would create a cycle.
809 * =============================================================================
812 TMfindBestReverseTask (FindBestTaskArg argPtr)
814 int toId = argPtr.toId;
815 Learner learnerPtr = argPtr.learnerPtr;
816 Query[] queries = argPtr.queries;
817 Vector_t queryVectorPtr = argPtr.queryVectorPtr;
818 Vector_t parentQueryVectorPtr = argPtr.parentQueryVectorPtr;
819 int numTotalParent = argPtr.numTotalParent;
820 float basePenalty = argPtr.basePenalty;
821 float baseLogLikelihood = argPtr.baseLogLikelihood;
822 BitMap visitedBitmapPtr = argPtr.bitmapPtr;
823 Queue workQueuePtr = argPtr.workQueuePtr;
824 Vector_t toOrigParentQueryVectorPtr = argPtr.aQueryVectorPtr;
825 Vector_t fromOrigParentQueryVectorPtr = argPtr.bQueryVectorPtr;
828 Adtree adtreePtr = learnerPtr.adtreePtr;
829 Net netPtr = learnerPtr.netPtr;
830 float[] localBaseLogLikelihoods = learnerPtr.localBaseLogLikelihoods;
832 TMpopulateParentQueryVector(netPtr, toId, queries, toOrigParentQueryVectorPtr);
833 int numParent = toOrigParentQueryVectorPtr.vector_getSize();
836 * Search all possible valid operations for better local log likelihood
839 int bestFromId = toId; // flag for not found
840 float oldLocalLogLikelihood = localBaseLogLikelihoods[toId];
841 float bestLocalLogLikelihood = oldLocalLogLikelihood;
// Consider reversing each existing parent edge fromId -> toId.
844 for (int i = 0; i < numParent; i++) {
846 Query queryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(i));
847 fromId = queryPtr.index;
// A reversal changes both endpoints' local likelihoods, so the bar to beat
// is the sum of both current contributions.
849 bestLocalLogLikelihood =
850 oldLocalLogLikelihood + localBaseLogLikelihoods[fromId];
852 TMpopulateParentQueryVector(netPtr,
855 fromOrigParentQueryVectorPtr);
858 * Create parent query (subset of parents since remove an edge)
861 parentQueryVectorPtr.vector_clear();
// NOTE(review): the guard skipping p == i (the reversed parent) appears to
// be elided in this chunk -- as shown, every parent would be copied back.
863 for (int p = 0; p < numParent; p++) {
865 Query tmpqueryPtr = (Query) (toOrigParentQueryVectorPtr.vector_at(p));
866 status = parentQueryVectorPtr.vector_pushBack(queries[tmpqueryPtr.index]);
868 } // create new parent query
874 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
875 status = queryVectorPtr.vector_pushBack(queries[toId]);
877 queryVectorPtr.vector_sort();
880 * Get log likelihood for removing parent from toId
883 float newLocalLogLikelihood =
884 computeLocalLogLikelihood(toId,
889 parentQueryVectorPtr);
892 * Get log likelihood for adding parent to fromId
895 status = Vector_t.vector_copy(parentQueryVectorPtr, fromOrigParentQueryVectorPtr);
896 status = parentQueryVectorPtr.vector_pushBack(queries[toId]);
898 parentQueryVectorPtr.vector_sort();
900 status = Vector_t.vector_copy(queryVectorPtr, parentQueryVectorPtr);
902 status = queryVectorPtr.vector_pushBack(queries[fromId]);
904 queryVectorPtr.vector_sort();
906 newLocalLogLikelihood +=
907 computeLocalLogLikelihood(fromId,
912 parentQueryVectorPtr);
// NOTE(review): the bestFromId = fromId assignment appears to be elided
// here (numbering gap after line 919).
918 if (newLocalLogLikelihood > bestLocalLogLikelihood) {
919 bestLocalLogLikelihood = newLocalLogLikelihood;
926 * Check validity of best
// With fromId -> toId temporarily removed, any remaining path between the
// endpoints means the reversed edge would close a cycle.
929 if (bestFromId != toId) {
930 boolean isTaskValid = true;
931 netPtr.net_applyOperation(OPERATION_REMOVE, bestFromId, toId);
932 if (netPtr.net_isPath(bestFromId,
// Restore the original edge regardless of the path-check outcome.
939 netPtr.net_applyOperation(OPERATION_INSERT, bestFromId, toId);
946 * Return best task; Note: if none is better, fromId will equal toId
949 LearnerTask bestTask = new LearnerTask();
950 bestTask.op = OPERATION_REVERSE;
951 bestTask.fromId = bestFromId;
952 bestTask.toId = toId;
953 bestTask.score = 0.0f;
955 if (bestFromId != toId) {
956 float fromLocalLogLikelihood = localBaseLogLikelihoods[bestFromId];
957 int numRecord = adtreePtr.numRecord;
// A reversal leaves the total parent count unchanged.
958 float penalty = numTotalParent * basePenalty;
959 float logLikelihood = numRecord * (baseLogLikelihood +
960 + bestLocalLogLikelihood
961 - oldLocalLogLikelihood
962 - fromLocalLogLikelihood);
963 float bestScore = penalty + logLikelihood;
964 bestTask.score = bestScore;
970 #endif /* LEARNER_TRY_REVERSE */
973 /* =============================================================================
// learnStructure
// -- Main greedy search loop: repeatedly pop the best queued task,
//    re-validate it against the current net, apply it, update the cached
//    likelihoods, then search for this variable's next best operation and
//    re-queue it.
976 * Note it is okay if the score is not exact, as we are relaxing the greedy
977 * search. This means we do not need to communicate baseLogLikelihood across
979 * =============================================================================
982 learnStructure (int myId, int numThread, Learner learnerPtr)
985 int numRecord = learnerPtr.adtreePtr.numRecord;
987 float operationQualityFactor = learnerPtr.global_operationQualityFactor;
// Scratch structures shared with the find-best-task helpers via `arg`.
989 BitMap visitedBitmapPtr = BitMap.bitmap_alloc(learnerPtr.adtreePtr.numVar);
991 Queue workQueuePtr = Queue.queue_alloc(-1);
993 int numVar = learnerPtr.adtreePtr.numVar;
994 Query[] queries = new Query[numVar];
// One query object per variable, initialized to wildcard (match any value).
996 for (int v = 0; v < numVar; v++) {
997 queries[v] = new Query();
998 queries[v].index = v;
999 queries[v].value = QUERY_VALUE_WILDCARD;
1002 float basePenalty = (float)(-0.5 * Math.log((double)numRecord));
1004 Vector_t queryVectorPtr = new Vector_t(1);
1005 Vector_t parentQueryVectorPtr = new Vector_t(1);
1006 Vector_t aQueryVectorPtr = new Vector_t(1);
1007 Vector_t bQueryVectorPtr = new Vector_t(1);
1009 FindBestTaskArg arg = new FindBestTaskArg();
1010 arg.learnerPtr = learnerPtr;
1011 arg.queries = queries;
1012 arg.queryVectorPtr = queryVectorPtr;
1013 arg.parentQueryVectorPtr = parentQueryVectorPtr;
1014 arg.bitmapPtr = visitedBitmapPtr;
1015 arg.workQueuePtr = workQueuePtr;
1016 arg.aQueryVectorPtr = aQueryVectorPtr;
1017 arg.bQueryVectorPtr = bQueryVectorPtr;
// NOTE(review): the surrounding loop header is elided in this chunk
// (numbering gap); the popped-task handling below is its body.
1021 LearnerTask taskPtr;
1024 taskPtr = learnerPtr.TMpopTask(learnerPtr.taskListPtr);
1027 if (taskPtr == null) {
1031 int op = taskPtr.op;
1032 int fromId = taskPtr.fromId;
1033 int toId = taskPtr.toId;
1035 boolean isTaskValid;
1039 * Check if task is still valid
// Another task applied since this one was scored may have changed the net.
1043 if(op == OPERATION_INSERT) {
// Insert is invalid if the edge already exists or if a path toId ~> fromId
// exists (the new edge would close a cycle).
1044 if(learnerPtr.netPtr.net_hasEdge(fromId, toId) ||
1045 learnerPtr.netPtr.net_isPath(toId,
1050 isTaskValid = false;
1052 } else if (op == OPERATION_REMOVE) {
1053 // Can never create cycle, so always valid
1055 } else if (op == OPERATION_REVERSE) {
1056 // Temporarily remove edge for check
1057 learnerPtr.netPtr.net_applyOperation(OPERATION_REMOVE, fromId, toId);
1058 if(learnerPtr.netPtr.net_isPath(fromId,
1063 isTaskValid = false;
// Restore the edge; the reversal itself is applied below if valid.
1065 learnerPtr.netPtr.net_applyOperation(OPERATION_INSERT, fromId, toId);
1070 System.out.println("[task] op= " + taskPtr.op + " from= " + taskPtr.fromId + " to= " +
1071 taskPtr.toId + " score= " + taskPtr.score + " valid= " + (isTaskValid ? "yes" : "no"));
1075 * Perform task: update graph and probabilities
1079 learnerPtr.netPtr.net_applyOperation(op, fromId, toId);
1084 float deltaLogLikelihood = 0.0f;
1087 float newBaseLogLikelihood;
// Recompute the affected variable's local likelihood under its new parent
// set; deltaLogLikelihood accumulates (old - new) for the global update.
1088 if(op == OPERATION_INSERT) {
1090 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1094 parentQueryVectorPtr);
1095 newBaseLogLikelihood =
1096 learnerPtr.computeLocalLogLikelihood(toId,
1097 learnerPtr.adtreePtr,
1101 parentQueryVectorPtr);
1102 float toLocalBaseLogLikelihood = learnerPtr.localBaseLogLikelihoods[toId];
1103 deltaLogLikelihood +=
1104 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1105 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1109 int numTotalParent = learnerPtr.numTotalParent;
1110 learnerPtr.numTotalParent = numTotalParent + 1;
1113 #ifdef LEARNER_TRY_REMOVE
1114 } else if(op == OPERATION_REMOVE) {
1116 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1120 parentQueryVectorPtr);
1121 newBaseLogLikelihood =
1122 learnerPtr. computeLocalLogLikelihood(fromId,
1123 learnerPtr.adtreePtr,
1127 parentQueryVectorPtr);
1128 float fromLocalBaseLogLikelihood =
1129 learnerPtr.localBaseLogLikelihoods[fromId];
1130 deltaLogLikelihood +=
1131 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1132 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1136 int numTotalParent = learnerPtr.numTotalParent;
1137 learnerPtr.numTotalParent = numTotalParent - 1;
1140 #endif // LEARNER_TRY_REMOVE
1141 #ifdef LEARNER_TRY_REVERSE
// A reversal updates both endpoints: fromId first, then toId.
1142 } else if(op == OPERATION_REVERSE) {
1144 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1148 parentQueryVectorPtr);
1149 newBaseLogLikelihood =
1150 learnerPtr.computeLocalLogLikelihood(fromId,
1151 learnerPtr.adtreePtr,
1155 parentQueryVectorPtr);
1156 float fromLocalBaseLogLikelihood =
1157 learnerPtr.localBaseLogLikelihoods[fromId];
1158 deltaLogLikelihood +=
1159 fromLocalBaseLogLikelihood - newBaseLogLikelihood;
1160 learnerPtr.localBaseLogLikelihoods[fromId] = newBaseLogLikelihood;
1164 learnerPtr.TMpopulateQueryVectors(learnerPtr.netPtr,
1168 parentQueryVectorPtr);
1169 newBaseLogLikelihood =
1170 learnerPtr.computeLocalLogLikelihood(toId,
1171 learnerPtr.adtreePtr,
1175 parentQueryVectorPtr);
1176 float toLocalBaseLogLikelihood =
1177 learnerPtr.localBaseLogLikelihoods[toId];
1178 deltaLogLikelihood +=
1179 toLocalBaseLogLikelihood - newBaseLogLikelihood;
1180 learnerPtr.localBaseLogLikelihoods[toId] = newBaseLogLikelihood;
1183 #endif // LEARNER_TRY_REVERSE
1189 * Update/read globals
1192 float baseLogLikelihood;
1196 float oldBaseLogLikelihood = learnerPtr.baseLogLikelihood;
1197 float newBaseLogLikelihood = oldBaseLogLikelihood + deltaLogLikelihood;
1198 learnerPtr.baseLogLikelihood = newBaseLogLikelihood;
1199 baseLogLikelihood = newBaseLogLikelihood;
1200 numTotalParent = learnerPtr.numTotalParent;
// Score of the current net with no further change; a candidate task must
// beat baseScore / operationQualityFactor to be queued below.
1208 float baseScore = ((float)numTotalParent * basePenalty)
1209 + (numRecord * baseLogLikelihood);
1211 LearnerTask bestTask = new LearnerTask();
1212 bestTask.op = NUM_OPERATION;
1214 bestTask.fromId = -1;
1215 bestTask.score = baseScore;
1217 LearnerTask newTask = new LearnerTask();
1220 arg.numTotalParent = numTotalParent;
1221 arg.basePenalty = basePenalty;
1222 arg.baseLogLikelihood = baseLogLikelihood;
// Pick the best of insert / remove / reverse for this variable.
// NOTE(review): the bodies that copy newTask into bestTask are elided in
// this chunk (numbering gaps after each comparison).
1225 newTask = learnerPtr.TMfindBestInsertTask(arg);
1228 if ((newTask.fromId != newTask.toId) &&
1229 (newTask.score > (bestTask.score / operationQualityFactor)))
1234 #ifdef LEARNER_TRY_REMOVE
1236 newTask = learnerPtr.TMfindBestRemoveTask(arg);
1239 if ((newTask.fromId != newTask.toId) &&
1240 (newTask.score > (bestTask.score / operationQualityFactor)))
1244 #endif // LEARNER_TRY_REMOVE
1246 #ifdef LEARNER_TRY_REVERSE
1248 newTask = learnerPtr.TMfindBestReverseTask(arg);
1251 if ((newTask.fromId != newTask.toId) &&
1252 (newTask.score > (bestTask.score / operationQualityFactor)))
1256 #endif // LEARNER_TRY_REVERSE
// Queue the winning follow-up task for this variable, if any, and loop.
1258 if (bestTask.toId != -1) {
1259 LearnerTask[] tasks = learnerPtr.tasks;
1260 tasks[toId] = bestTask;
1262 learnerPtr.taskListPtr.list_insert(tasks[toId]);
1265 System.out.println("[new] op= " + bestTask.op + " from= "+ bestTask.fromId + " to= "+ bestTask.toId +
1266 " score= " + bestTask.score);
// Tear down all per-call scratch structures.
1272 visitedBitmapPtr.bitmap_free();
1273 workQueuePtr.queue_free();
1274 bQueryVectorPtr.clear();
1275 aQueryVectorPtr.clear();
1276 queryVectorPtr.clear();
1277 parentQueryVectorPtr.clear();
1282 /* =============================================================================
// learner_run
// -- Entry point: seed the task list, then run the greedy structure search.
1284 * -- Call adtree_make before this
1285 * =============================================================================
1287 //Is not called anywhere now parallel code
1289 learner_run (int myId, int numThread, Learner learnerPtr)
// Phase 1: per-variable base likelihoods + initial insertion tasks.
1292 createTaskList(myId, numThread, learnerPtr);
// Phase 2: hill-climbing over insert/remove/reverse operations.
1295 learnStructure(myId, numThread, learnerPtr);
1299 /* =============================================================================
1301 * -- Score entire network
// Computes the penalized log-likelihood of the current net:
// -0.5 * N_params * ln(R) + R * (sum of local log likelihoods); see the file
// header for the scoring formula. NOTE(review): the method signature
// (around original line 1306) is elided in this chunk.
1302 * =============================================================================
1308 Vector_t queryVectorPtr = new Vector_t(1);
1309 Vector_t parentQueryVectorPtr = new Vector_t(1);
1311 int numVar = adtreePtr.numVar;
1312 Query[] queries = new Query[numVar];
// One wildcard query per variable; concrete values are filled in during the
// likelihood evaluation.
1314 for (int v = 0; v < numVar; v++) {
1315 queries[v] = new Query();
1316 queries[v].index = v;
1317 queries[v].value = QUERY_VALUE_WILDCARD;
1320 int numTotalParent = 0;
1321 float logLikelihood = 0.0f;
// Accumulate every variable's local log likelihood under its learned
// parents, counting parameters (edges) along the way.
1323 for (int v = 0; v < numVar; v++) {
1325 IntList parentIdListPtr = netPtr.net_getParentIdListPtr(v);
1326 numTotalParent += parentIdListPtr.list_getSize();
1328 populateQueryVectors(netPtr,
1332 parentQueryVectorPtr);
1333 float localLogLikelihood = computeLocalLogLikelihood(v,
1338 parentQueryVectorPtr);
1339 logLikelihood += localLogLikelihood;
1342 queryVectorPtr.clear();
1343 parentQueryVectorPtr.clear();
1347 int numRecord = adtreePtr.numRecord;
1348 float penalty = (float)(-0.5f * (double)numTotalParent * Math.log((double)numRecord));
1349 float score = penalty + (float)numRecord * logLikelihood;
1355 /* =============================================================================
1357 * End of learner.java
1359 * =============================================================================