IR: Add isUniqued() and isTemporary()
[oota-llvm.git] / lib / Transforms / Utils / CodeExtractor.cpp
index 0178c336d9e7a5fa56bc9d0e84d1509bc07a40e5..e70a7d6e76c6e5f123334a9cf8497e604f8fe153 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CodeExtractor.h"
-#include "llvm/Constants.h"
-#include "llvm/DerivedTypes.h"
-#include "llvm/Instructions.h"
-#include "llvm/Intrinsics.h"
-#include "llvm/LLVMContext.h"
-#include "llvm/Module.h"
-#include "llvm/Pass.h"
-#include "llvm/Analysis/Dominators.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/Verifier.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Analysis/RegionInfo.h"
+#include "llvm/Analysis/RegionIterator.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/ADT/StringExtras.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <algorithm>
 #include <set>
 using namespace llvm;
 
+#define DEBUG_TYPE "code-extractor"
+
 // Provide a command-line option to aggregate function arguments into a struct
 // for functions produced by the code extractor. This is useful when converting
 // extracted functions to pthread-based code, as only one argument (void*) can
@@ -63,18 +68,18 @@ static bool isBlockValidForExtraction(const BasicBlock &BB) {
 }
 
 /// \brief Build a set of blocks to extract if the input blocks are viable.
-static SetVector<BasicBlock *>
-buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) {
+template <typename IteratorT>
+static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin,
+                                                       IteratorT BBEnd) {
   SetVector<BasicBlock *> Result;
 
-  assert(!BBs.empty());
+  assert(BBBegin != BBEnd);
 
   // Loop over the blocks, adding them to our set-vector, and aborting with an
   // empty set if we encounter invalid blocks.
-  for (ArrayRef<BasicBlock *>::iterator I = BBs.begin(), E = BBs.end();
-       I != E; ++I) {
+  for (IteratorT I = BBBegin, E = BBEnd; I != E; ++I) {
     if (!Result.insert(*I))
-      continue;
+      llvm_unreachable("Repeated basic blocks in extraction input");
 
     if (!isBlockValidForExtraction(**I)) {
       Result.clear();
@@ -83,8 +88,8 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) {
   }
 
 #ifndef NDEBUG
-  for (ArrayRef<BasicBlock *>::iterator I = llvm::next(BBs.begin()),
-                                        E = BBs.end();
+  for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
+                                         E = Result.end();
        I != E; ++I)
     for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I);
          PI != PE; ++PI)
@@ -96,8 +101,26 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) {
   return Result;
 }
 
+/// \brief Helper to call buildExtractionBlockSet with an ArrayRef.
+static SetVector<BasicBlock *>
+buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) {
+  return buildExtractionBlockSet(BBs.begin(), BBs.end());
+}
+
+/// \brief Helper to call buildExtractionBlockSet with a RegionNode.
+static SetVector<BasicBlock *>
+buildExtractionBlockSet(const RegionNode &RN) {
+  if (!RN.isSubRegion())
+    // Just a single BasicBlock.
+    return buildExtractionBlockSet(RN.getNodeAs<BasicBlock>());
+
+  const Region &R = *RN.getNodeAs<Region>();
+
+  return buildExtractionBlockSet(R.block_begin(), R.block_end());
+}
+
 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(0), AggregateArgs(AggregateArgs||AggregateArgsOpt),
+  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
     Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
@@ -109,6 +132,10 @@ CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
   : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
     Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
 
+CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
+                             bool AggregateArgs)
+  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
+    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -130,6 +157,31 @@ static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
   return false;
 }
 
+void CodeExtractor::findInputsOutputs(ValueSet &Inputs,
+                                      ValueSet &Outputs) const {
+  for (SetVector<BasicBlock *>::const_iterator I = Blocks.begin(),
+                                               E = Blocks.end();
+       I != E; ++I) {
+    BasicBlock *BB = *I;
+
+    // If a used value is defined outside the region, it's an input.  If an
+    // instruction is used outside the region, it's an output.
+    for (BasicBlock::iterator II = BB->begin(), IE = BB->end();
+         II != IE; ++II) {
+      for (User::op_iterator OI = II->op_begin(), OE = II->op_end();
+           OI != OE; ++OI)
+        if (definedInCaller(Blocks, *OI))
+          Inputs.insert(*OI);
+
+      for (User *U : II->users())
+        if (!definedInRegion(Blocks, U)) {
+          Outputs.insert(II);
+          break;
+        }
+    }
+  }
+}
+
 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
 /// region, we need to split the entry block of the region so that the PHI node
 /// is easier to deal with.
@@ -226,47 +278,13 @@ void CodeExtractor::splitReturnBlocks() {
 
         DomTreeNode *NewNode = DT->addNewBlock(New, *I);
 
-        for (SmallVector<DomTreeNode*, 8>::iterator I = Children.begin(),
-               E = Children.end(); I != E; ++I) 
+        for (SmallVectorImpl<DomTreeNode *>::iterator I = Children.begin(),
+               E = Children.end(); I != E; ++I)
           DT->changeImmediateDominator(*I, NewNode);
       }
     }
 }
 
-// findInputsOutputs - Find inputs to, outputs from the code region.
-//
-void CodeExtractor::findInputsOutputs(ValueSet &inputs, ValueSet &outputs) {
-  std::set<BasicBlock*> ExitBlocks;
-  for (SetVector<BasicBlock*>::const_iterator ci = Blocks.begin(),
-       ce = Blocks.end(); ci != ce; ++ci) {
-    BasicBlock *BB = *ci;
-
-    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
-      // If a used value is defined outside the region, it's an input.  If an
-      // instruction is used outside the region, it's an output.
-      for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O)
-        if (definedInCaller(Blocks, *O))
-          inputs.insert(*O);
-
-      // Consider uses of this instruction (outputs).
-      for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
-           UI != E; ++UI)
-        if (!definedInRegion(Blocks, *UI)) {
-          outputs.insert(I);
-          break;
-        }
-    } // for: insts
-
-    // Keep track of the exit blocks from the region.
-    TerminatorInst *TI = BB->getTerminator();
-    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
-      if (!Blocks.count(TI->getSuccessor(i)))
-        ExitBlocks.insert(TI->getSuccessor(i));
-  } // for: basic blocks
-
-  NumExitBlocks = ExitBlocks.size();
-}
-
 /// constructFunction - make a function based on inputs and outputs, as follows:
 /// f(in0, ..., inN, out0, ..., outN)
 ///
@@ -330,7 +348,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
                                            header->getName(), M);
   // If the old function is no-throw, so is the new one.
   if (oldFunction->doesNotThrow())
-    newFunction->setDoesNotThrow(true);
+    newFunction->setDoesNotThrow();
   
   newFunction->getBasicBlockList().push_back(newRootNode);
 
@@ -352,7 +370,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
     } else
       RewriteVal = AI++;
 
-    std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end());
+    std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end());
     for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end();
          use != useE; ++use)
       if (Instruction* inst = dyn_cast<Instruction>(*use))
@@ -372,7 +390,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
   // within the new function. This must be done before we lose track of which
   // blocks were originally in the code region.
-  std::vector<User*> Users(header->use_begin(), header->use_end());
+  std::vector<User*> Users(header->user_begin(), header->user_end());
   for (unsigned i = 0, e = Users.size(); i != e; ++i)
     // The BasicBlock which contains the branch is not in the region
     // modify the branch target to a new block
@@ -388,14 +406,13 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
 /// that uses the value within the basic block, and return the predecessor
 /// block associated with that use, or return 0 if none is found.
 static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) {
-  for (Value::use_iterator UI = Used->use_begin(),
-       UE = Used->use_end(); UI != UE; ++UI) {
-     PHINode *P = dyn_cast<PHINode>(*UI);
+  for (Use &U : Used->uses()) {
+     PHINode *P = dyn_cast<PHINode>(U.getUser());
      if (P && P->getParent() == BB)
-       return P->getIncomingBlock(UI);
+       return P->getIncomingBlock(U);
   }
-  
-  return 0;
+
+  return nullptr;
 }
 
 /// emitCallAndSwitchStatement - This method sets up the caller side by adding
@@ -423,14 +440,14 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
       StructValues.push_back(*i);
     } else {
       AllocaInst *alloca =
-        new AllocaInst((*i)->getType(), 0, (*i)->getName()+".loc",
+        new AllocaInst((*i)->getType(), nullptr, (*i)->getName()+".loc",
                        codeReplacer->getParent()->begin()->begin());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
     }
   }
 
-  AllocaInst *Struct = 0;
+  AllocaInst *Struct = nullptr;
   if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
     std::vector<Type*> ArgTypes;
     for (ValueSet::iterator v = StructValues.begin(),
@@ -440,7 +457,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
     // Allocate a struct at the beginning of this function
     Type *StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
     Struct =
-      new AllocaInst(StructArgTy, 0, "structArg",
+      new AllocaInst(StructArgTy, nullptr, "structArg",
                      codeReplacer->getParent()->begin()->begin());
     params.push_back(Struct);
 
@@ -469,7 +486,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
 
   // Reload the outputs passed in by reference
   for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
-    Value *Output = 0;
+    Value *Output = nullptr;
     if (AggregateArgs) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
@@ -485,7 +502,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
     LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
     Reloads.push_back(load);
     codeReplacer->getInstList().push_back(load);
-    std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end());
+    std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end());
     for (unsigned u = 0, e = Users.size(); u != e; ++u) {
       Instruction *inst = cast<Instruction>(Users[u]);
       if (!Blocks.count(inst->getParent()))
@@ -522,7 +539,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
                                          newFunction);
           unsigned SuccNum = switchVal++;
 
-          Value *brVal = 0;
+          Value *brVal = nullptr;
           switch (NumExitBlocks) {
           case 0:
           case 1: break;  // No value needed.
@@ -618,7 +635,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
 
     // Check if the function should return a value
     if (OldFnRetTy->isVoidTy()) {
-      ReturnInst::Create(Context, 0, TheSwitch);  // Return void
+      ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
     } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
       // return what we have
       ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
@@ -670,7 +687,7 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
 
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
-    return 0;
+    return nullptr;
 
   ValueSet inputs, outputs;
 
@@ -701,6 +718,14 @@ Function *CodeExtractor::extractCodeRegion() {
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
+  for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end();
+       I != E; ++I)
+    for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI)
+      if (!Blocks.count(*SI))
+        ExitBlocks.insert(*SI);
+  NumExitBlocks = ExitBlocks.size();
+
   // Construct new function based on inputs/outputs & add allocas for all defs.
   Function *newFunction = constructFunction(inputs, outputs, header,
                                             newFuncRoot,