X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FUtils%2FLowerSwitch.cpp;h=4acd988691d22f99caaba4c65e98fef1565e026f;hb=cd52a7a381a73c53ec4ef517ad87f19808cb1a28;hp=8141049a8179fb4bb7e223637a2ff30ecec9ff91;hpb=5243154a6a08bcda3d6ad9184c5552509c12f94d;p=oota-llvm.git diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 8141049a817..4acd988691d 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -14,17 +14,17 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/CFG.h" #include "llvm/Pass.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" #include using namespace llvm; @@ -67,11 +67,11 @@ namespace { } struct CaseRange { - Constant* Low; - Constant* High; + ConstantInt* Low; + ConstantInt* High; BasicBlock* BB; - CaseRange(Constant *low, Constant *high, BasicBlock *bb) + CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) : Low(low), High(high), BB(bb) {} }; @@ -175,11 +175,16 @@ static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, // Remove additional occurences coming from condensed cases and keep the // number of incoming values equal to the number of branches to SuccBB. + SmallVector Indices; for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) if (PN->getIncomingBlock(Idx) == OrigBB) { - PN->removeIncomingValue(Idx); + Indices.push_back(Idx); LocalNumMergedCases--; } + // Remove incoming values in the reverse order to prevent invalidating + // *successive* index. + for (auto III = Indices.rbegin(), IIE = Indices.rend(); III != IIE; ++III) + PN->removeIncomingValue(*III); } } @@ -220,14 +225,14 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, CaseRange &Pivot = *(Begin + Mid); DEBUG(dbgs() << "Pivot ==> " - << cast(Pivot.Low)->getValue() - << " -" << cast(Pivot.High)->getValue() << "\n"); + << Pivot.Low->getValue() + << " -" << Pivot.High->getValue() << "\n"); // NewLowerBound here should never be the integer minimal value. // This is because it is computed from a case range that is never // the smallest, so there is always a case range that has at least // a smaller value. - ConstantInt *NewLowerBound = cast(Pivot.Low); + ConstantInt *NewLowerBound = Pivot.Low; // Because NewLowerBound is never the smallest representable integer // it is safe here to subtract one. @@ -236,16 +241,16 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, if (!UnreachableRanges.empty()) { // Check if the gap between LHS's highest and NewLowerBound is unreachable. - int64_t GapLow = cast(LHS.back().High)->getSExtValue() + 1; + int64_t GapLow = LHS.back().High->getSExtValue() + 1; int64_t GapHigh = NewLowerBound->getSExtValue() - 1; IntRange Gap = { GapLow, GapHigh }; if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) - NewUpperBound = cast(LHS.back().High); + NewUpperBound = LHS.back().High; } DEBUG(dbgs() << "LHS Bounds ==> "; if (LowerBound) { - dbgs() << cast(LowerBound)->getSExtValue(); + dbgs() << LowerBound->getSExtValue(); } else { dbgs() << "NONE"; } @@ -253,7 +258,7 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, dbgs() << "RHS Bounds ==> "; dbgs() << NewLowerBound->getSExtValue() << " - "; if (UpperBound) { - dbgs() << cast(UpperBound)->getSExtValue() << "\n"; + dbgs() << UpperBound->getSExtValue() << "\n"; } else { dbgs() << "NONE\n"; }); @@ -304,11 +309,11 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, Leaf.Low, "SwitchLeaf"); } else { // Make range comparison - if (cast(Leaf.Low)->isMinValue(true /*isSigned*/)) { + if (Leaf.Low->isMinValue(true /*isSigned*/)) { // Val >= Min && Val <= Hi --> Val <= Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, "SwitchLeaf"); - } else if (cast(Leaf.Low)->isZero()) { + } else if (Leaf.Low->isZero()) { // Val >= 0 && Val <= Hi --> Val <=u Hi Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, "SwitchLeaf"); @@ -333,8 +338,8 @@ BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, for (BasicBlock::iterator I = Succ->begin(); isa(I); ++I) { PHINode* PN = cast(I); // Remove all but one incoming entries from the cluster - uint64_t Range = cast(Leaf.High)->getSExtValue() - - cast(Leaf.Low)->getSExtValue(); + uint64_t Range = Leaf.High->getSExtValue() - + Leaf.Low->getSExtValue(); for (uint64_t j = 0; j < Range; ++j) { PN->removeIncomingValue(OrigBlock); } @@ -359,23 +364,26 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { std::sort(Cases.begin(), Cases.end(), CaseCmp()); // Merge case into clusters - if (Cases.size()>=2) - for (CaseItr I = Cases.begin(), J = std::next(Cases.begin()); - J != Cases.end();) { - int64_t nextValue = cast(J->Low)->getSExtValue(); - int64_t currentValue = cast(I->High)->getSExtValue(); + if (Cases.size() >= 2) { + CaseItr I = Cases.begin(); + for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { + int64_t nextValue = J->Low->getSExtValue(); + int64_t currentValue = I->High->getSExtValue(); BasicBlock* nextBB = J->BB; BasicBlock* currentBB = I->BB; // If the two neighboring cases go to the same destination, merge them // into a single case. - if ((nextValue-currentValue==1) && (currentBB == nextBB)) { + assert(nextValue > currentValue && "Cases should be strictly ascending"); + if ((nextValue == currentValue + 1) && (currentBB == nextBB)) { I->High = J->High; - J = Cases.erase(J); - } else { - I = J++; + // FIXME: Combine branch weights. + } else if (++I != J) { + *I = *J; } } + Cases.erase(std::next(I), Cases.end()); + } for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { if (I->Low != I->High) @@ -420,8 +428,8 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { // know that the value passed to the switch must be exactly one of the case // values. assert(!Cases.empty()); - LowerBound = cast(Cases.front().Low); - UpperBound = cast(Cases.back().High); + LowerBound = Cases.front().Low; + UpperBound = Cases.back().High; DenseMap Popularity; unsigned MaxPop = 0; @@ -430,8 +438,8 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { IntRange R = { INT64_MIN, INT64_MAX }; UnreachableRanges.push_back(R); for (const auto &I : Cases) { - int64_t Low = cast(I.Low)->getSExtValue(); - int64_t High = cast(I.High)->getSExtValue(); + int64_t Low = I.Low->getSExtValue(); + int64_t High = I.High->getSExtValue(); IntRange &LastRange = UnreachableRanges.back(); if (LastRange.Low == Low) { @@ -471,12 +479,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) { // cases. assert(MaxPop > 0 && PopSucc); Default = PopSucc; - for (CaseItr I = Cases.begin(); I != Cases.end();) { - if (I->BB == PopSucc) - I = Cases.erase(I); - else - ++I; - } + Cases.erase(std::remove_if( + Cases.begin(), Cases.end(), + [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), + Cases.end()); // If there are no cases left, just branch. if (Cases.empty()) {