- // Merge probability data into PredBlock's branch.
- APInt A, B, C, D;
- if (PBI->isConditional() && BI->isConditional() &&
- ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
- // Given IR which does:
- // bbA:
- // br i1 %x, label %bbB, label %bbC
- // bbB:
- // br i1 %y, label %bbD, label %bbC
- // Let's call the probability that we take the edge from %bbA to %bbB
- // 'a', from %bbA to %bbC, 'b', from %bbB to %bbD 'c' and from %bbB to
- // %bbC probability 'd'.
- //
- // We transform the IR into:
- // bbA:
- // br i1 %z, label %bbD, label %bbC
- // where the probability of going to %bbD is (a*c) and going to bbC is
- // (b+a*d).
- //
- // Probabilities aren't stored as ratios directly. Using branch weights,
- // we get:
- // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
-
- // In the event of overflow, we want to drop the LSB of the input
- // probabilities.
- unsigned BitsLost;
-
- // Ignore overflow result on ProbTrue.
- APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
-
- APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
- if (BitsLost) {
- ProbTrue = ProbTrue.lshr(BitsLost*2);
- }
-
- APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
- if (BitsLost) {
- ProbTrue = ProbTrue.lshr(BitsLost*2);
- Tmp1 = Tmp1.lshr(BitsLost*2);
- }
-
- APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
- if (BitsLost) {
- ProbTrue = ProbTrue.lshr(BitsLost*2);
- Tmp1 = Tmp1.lshr(BitsLost*2);
- Tmp2 = Tmp2.lshr(BitsLost*2);
- }
-
- bool Overflow1 = false, Overflow2 = false;
- APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
- APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
-
- if (Overflow1 || Overflow2) {
- ProbTrue = ProbTrue.lshr(1);
- Tmp1 = Tmp1.lshr(1);
- Tmp2 = Tmp2.lshr(1);
- Tmp3 = Tmp3.lshr(1);
- Tmp4 = Tmp2 + Tmp3;
- ProbFalse = Tmp4 + Tmp1;
- }
-
- // The sum of branch weights must fit in 32-bits.
- if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
- ProbTrue = ProbTrue.lshr(1);
- ProbFalse = ProbFalse.lshr(1);
- }
-
- if (ProbTrue != ProbFalse) {
- // Normalize the result.
- APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
- ProbTrue = ProbTrue.udiv(GCD);
- ProbFalse = ProbFalse.udiv(GCD);
-
- MDBuilder MDB(BI->getContext());
- MDNode *N = MDB.createBranchWeights(ProbTrue.getZExtValue(),
- ProbFalse.getZExtValue());
- PBI->setMetadata(LLVMContext::MD_prof, N);
- } else {
- PBI->setMetadata(LLVMContext::MD_prof, NULL);
- }
- } else {
- PBI->setMetadata(LLVMContext::MD_prof, NULL);
- }
-