X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FIPO%2FLoopExtractor.cpp;h=e1ce290526c89c47647b9987d8939a85a3d1e72b;hb=551ccae044b0ff658fe629dd67edd5ffe75d10e8;hp=f0d79e1a06ef6b8179739f96c69bc29a51bb5d48;hpb=97836fad2c3d705c90855bf2fbb79696c129a64f;p=oota-llvm.git diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index f0d79e1a06e..e1ce290526c 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -15,46 +15,170 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO.h" +#include "llvm/Instructions.h" #include "llvm/Module.h" #include "llvm/Pass.h" +#include "llvm/Analysis/Dominators.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/FunctionUtils.h" +#include "llvm/ADT/Statistic.h" using namespace llvm; namespace { - // FIXME: PassManager should allow Module passes to require FunctionPasses + Statistic<> NumExtracted("loop-extract", "Number of loops extracted"); + + // FIXME: This is not a function pass, but the PassManager doesn't allow + // Module passes to require FunctionPasses, so we can't get loop info if we're + // not a function pass. struct LoopExtractor : public FunctionPass { + unsigned NumLoops; + + LoopExtractor(unsigned numLoops = ~0) : NumLoops(numLoops) {} + virtual bool runOnFunction(Function &F); virtual void getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired(); + AU.addRequiredID(BreakCriticalEdgesID); AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addRequired(); } }; RegisterOpt X("loop-extract", "Extract loops into new functions"); + + /// SingleLoopExtractor - For bugpoint. + struct SingleLoopExtractor : public LoopExtractor { + SingleLoopExtractor() : LoopExtractor(1) {} + }; + + RegisterOpt + Y("loop-extract-single", "Extract at most one loop into a new function"); } // End anonymous namespace bool LoopExtractor::runOnFunction(Function &F) { LoopInfo &LI = getAnalysis(); - // We don't want to keep extracting the only loop of a function into a new one - if (LI.begin() == LI.end() || LI.begin() + 1 == LI.end()) + // If this function has no loops, there is nothing to do. + if (LI.begin() == LI.end()) return false; - bool Changed = false; + DominatorSet &DS = getAnalysis(); - // Try to move each loop out of the code into separate function - for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) - Changed |= (ExtractLoop(*i) != 0); + // If there is more than one top-level loop in this function, extract all of + // the loops. + bool Changed = false; + if (LI.end()-LI.begin() > 1) { + for (LoopInfo::iterator i = LI.begin(), e = LI.end(); i != e; ++i) { + if (NumLoops == 0) return Changed; + --NumLoops; + Changed |= ExtractLoop(DS, *i) != 0; + ++NumExtracted; + } + } else { + // Otherwise there is exactly one top-level loop. If this function is more + // than a minimal wrapper around the loop, extract the loop. + Loop *TLL = *LI.begin(); + bool ShouldExtractLoop = false; + + // Extract the loop if the entry block doesn't branch to the loop header. + TerminatorInst *EntryTI = F.getEntryBlock().getTerminator(); + if (!isa(EntryTI) || + !cast(EntryTI)->isUnconditional() || + EntryTI->getSuccessor(0) != TLL->getHeader()) + ShouldExtractLoop = true; + else { + // Check to see if any exits from the loop are more than just return + // blocks. + std::vector ExitBlocks; + TLL->getExitBlocks(ExitBlocks); + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (!isa(ExitBlocks[i]->getTerminator())) { + ShouldExtractLoop = true; + break; + } + } + + if (ShouldExtractLoop) { + if (NumLoops == 0) return Changed; + --NumLoops; + Changed |= ExtractLoop(DS, TLL) != 0; + ++NumExtracted; + } else { + // Okay, this function is a minimal container around the specified loop. + // If we extract the loop, we will continue to just keep extracting it + // infinitely... so don't extract it. However, if the loop contains any + // subloops, extract them. + for (Loop::iterator i = TLL->begin(), e = TLL->end(); i != e; ++i) { + if (NumLoops == 0) return Changed; + --NumLoops; + Changed |= ExtractLoop(DS, *i) != 0; + ++NumExtracted; + } + } + } return Changed; } -/// createLoopExtractorPass -/// -Pass* llvm::createLoopExtractorPass() { - return new LoopExtractor(); +// createSingleLoopExtractorPass - This pass extracts one natural loop from the +// program into a function if it can. This is used by bugpoint. +// +Pass *llvm::createSingleLoopExtractorPass() { + return new SingleLoopExtractor(); +} + + +namespace { + /// BlockExtractorPass - This pass is used by bugpoint to extract all blocks + /// from the module into their own functions except for those specified by the + /// BlocksToNotExtract list. + class BlockExtractorPass : public Pass { + std::vector BlocksToNotExtract; + public: + BlockExtractorPass(std::vector &B) : BlocksToNotExtract(B) {} + BlockExtractorPass() {} + + bool run(Module &M); + }; + RegisterOpt + XX("extract-blocks", "Extract Basic Blocks From Module (for bugpoint use)"); +} + +// createBlockExtractorPass - This pass extracts all blocks (except those +// specified in the argument list) from the functions in the module. +// +Pass *llvm::createBlockExtractorPass(std::vector &BTNE) { + return new BlockExtractorPass(BTNE); +} + +bool BlockExtractorPass::run(Module &M) { + std::set TranslatedBlocksToNotExtract; + for (unsigned i = 0, e = BlocksToNotExtract.size(); i != e; ++i) { + BasicBlock *BB = BlocksToNotExtract[i]; + Function *F = BB->getParent(); + + // Map the corresponding function in this module. + Function *MF = M.getFunction(F->getName(), F->getFunctionType()); + + // Figure out which index the basic block is in its function. + Function::iterator BBI = MF->begin(); + std::advance(BBI, std::distance(F->begin(), Function::iterator(BB))); + TranslatedBlocksToNotExtract.insert(BBI); + } + + // Now that we know which blocks to not extract, figure out which ones we WANT + // to extract. + std::vector BlocksToExtract; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + if (!TranslatedBlocksToNotExtract.count(BB)) + BlocksToExtract.push_back(BB); + + for (unsigned i = 0, e = BlocksToExtract.size(); i != e; ++i) + ExtractBasicBlock(BlocksToExtract[i]); + + return !BlocksToExtract.empty(); }