Fix a bug that caused SimplifyCFG to drop DebugLocs.
[oota-llvm.git] / lib / Transforms / Vectorize / BBVectorize.cpp
index 29fb01f1b2ec79444cace5c61d50f1cb22699634..1eed84fa90ff53be6410b5e479fa91e3c6d33179 100644 (file)
@@ -27,6 +27,7 @@
 #include "llvm/Analysis/AliasSetTracker.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constants.h"
@@ -206,7 +207,8 @@ namespace {
       : BasicBlockPass(ID), Config(C) {
       AA = &P->getAnalysis<AliasAnalysis>();
       DT = &P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-      SE = &P->getAnalysis<ScalarEvolution>();
+      SE = &P->getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+      TLI = &P->getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
       TTI = IgnoreTargetInfo
                 ? nullptr
                 : &P->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
@@ -221,6 +223,7 @@ namespace {
     AliasAnalysis *AA;
     DominatorTree *DT;
     ScalarEvolution *SE;
+    const TargetLibraryInfo *TLI;
     const TargetTransformInfo *TTI;
 
     // FIXME: const correct?
@@ -439,7 +442,8 @@ namespace {
 
       AA = &getAnalysis<AliasAnalysis>();
       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-      SE = &getAnalysis<ScalarEvolution>();
+      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
       TTI = IgnoreTargetInfo
                 ? nullptr
                 : &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
@@ -452,11 +456,12 @@ namespace {
       BasicBlockPass::getAnalysisUsage(AU);
       AU.addRequired<AliasAnalysis>();
       AU.addRequired<DominatorTreeWrapperPass>();
-      AU.addRequired<ScalarEvolution>();
+      AU.addRequired<ScalarEvolutionWrapperPass>();
+      AU.addRequired<TargetLibraryInfoWrapperPass>();
       AU.addRequired<TargetTransformInfoWrapperPass>();
       AU.addPreserved<AliasAnalysis>();
       AU.addPreserved<DominatorTreeWrapperPass>();
-      AU.addPreserved<ScalarEvolution>();
+      AU.addPreserved<ScalarEvolutionWrapperPass>();
       AU.setPreservesCFG();
     }
 
@@ -662,7 +667,7 @@ namespace {
       Function *F = I->getCalledFunction();
       if (!F) return false;
 
-      Intrinsic::ID IID = (Intrinsic::ID) F->getIntrinsicID();
+      Intrinsic::ID IID = F->getIntrinsicID();
       if (!IID) return false;
 
       switch(IID) {
@@ -842,7 +847,7 @@ namespace {
 
     // It is important to cleanup here so that future iterations of this
     // function have less work to do.
-    (void)SimplifyInstructionsInBlock(&BB, AA->getTargetLibraryInfo());
+    (void)SimplifyInstructionsInBlock(&BB, TLI);
     return true;
   }
 
@@ -1098,7 +1103,7 @@ namespace {
     CallInst *CI = dyn_cast<CallInst>(I);
     Function *FI;
     if (CI && (FI = CI->getCalledFunction())) {
-      Intrinsic::ID IID = (Intrinsic::ID) FI->getIntrinsicID();
+      Intrinsic::ID IID = FI->getIntrinsicID();
       if (IID == Intrinsic::powi || IID == Intrinsic::ctlz ||
           IID == Intrinsic::cttz) {
         Value *A1I = CI->getArgOperand(1),
@@ -2770,7 +2775,7 @@ namespace {
         continue;
       } else if (isa<CallInst>(I)) {
         Function *F = cast<CallInst>(I)->getCalledFunction();
-        Intrinsic::ID IID = (Intrinsic::ID) F->getIntrinsicID();
+        Intrinsic::ID IID = F->getIntrinsicID();
         if (o == NumOperands-1) {
           BasicBlock &BB = *I->getParent();
 
@@ -2806,55 +2811,51 @@ namespace {
                      Instruction *J, Instruction *K,
                      Instruction *&InsertionPt,
                      Instruction *&K1, Instruction *&K2) {
-    if (isa<StoreInst>(I)) {
-      AA->replaceWithNewValue(I, K);
-      AA->replaceWithNewValue(J, K);
-    } else {
-      Type *IType = I->getType();
-      Type *JType = J->getType();
+    if (isa<StoreInst>(I))
+      return;
 
-      VectorType *VType = getVecTypeForPair(IType, JType);
-      unsigned numElem = VType->getNumElements();
+    Type *IType = I->getType();
+    Type *JType = J->getType();
 
-      unsigned numElemI = getNumScalarElements(IType);
-      unsigned numElemJ = getNumScalarElements(JType);
+    VectorType *VType = getVecTypeForPair(IType, JType);
+    unsigned numElem = VType->getNumElements();
 
-      if (IType->isVectorTy()) {
-        std::vector<Constant*> Mask1(numElemI), Mask2(numElemI);
-        for (unsigned v = 0; v < numElemI; ++v) {
-          Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
-          Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemJ+v);
-        }
+    unsigned numElemI = getNumScalarElements(IType);
+    unsigned numElemJ = getNumScalarElements(JType);
 
-        K1 = new ShuffleVectorInst(K, UndefValue::get(VType),
-                                   ConstantVector::get( Mask1),
-                                   getReplacementName(K, false, 1));
-      } else {
-        Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0);
-        K1 = ExtractElementInst::Create(K, CV0,
-                                          getReplacementName(K, false, 1));
+    if (IType->isVectorTy()) {
+      std::vector<Constant *> Mask1(numElemI), Mask2(numElemI);
+      for (unsigned v = 0; v < numElemI; ++v) {
+        Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
+        Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemJ + v);
       }
 
-      if (JType->isVectorTy()) {
-        std::vector<Constant*> Mask1(numElemJ), Mask2(numElemJ);
-        for (unsigned v = 0; v < numElemJ; ++v) {
-          Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
-          Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemI+v);
-        }
+      K1 = new ShuffleVectorInst(K, UndefValue::get(VType),
+                                 ConstantVector::get(Mask1),
+                                 getReplacementName(K, false, 1));
+    } else {
+      Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0);
+      K1 = ExtractElementInst::Create(K, CV0, getReplacementName(K, false, 1));
+    }
 
-        K2 = new ShuffleVectorInst(K, UndefValue::get(VType),
-                                   ConstantVector::get( Mask2),
-                                   getReplacementName(K, false, 2));
-      } else {
-        Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), numElem-1);
-        K2 = ExtractElementInst::Create(K, CV1,
-                                          getReplacementName(K, false, 2));
+    if (JType->isVectorTy()) {
+      std::vector<Constant *> Mask1(numElemJ), Mask2(numElemJ);
+      for (unsigned v = 0; v < numElemJ; ++v) {
+        Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
+        Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemI + v);
       }
 
-      K1->insertAfter(K);
-      K2->insertAfter(K1);
-      InsertionPt = K2;
+      K2 = new ShuffleVectorInst(K, UndefValue::get(VType),
+                                 ConstantVector::get(Mask2),
+                                 getReplacementName(K, false, 2));
+    } else {
+      Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), numElem - 1);
+      K2 = ExtractElementInst::Create(K, CV1, getReplacementName(K, false, 2));
     }
+
+    K1->insertAfter(K);
+    K2->insertAfter(K1);
+    InsertionPt = K2;
   }
 
   // Move all uses of the function I (including pairing-induced uses) after J.
@@ -3103,10 +3104,21 @@ namespace {
       else if (H->hasName())
         K->takeName(H);
 
-      if (!isa<StoreInst>(K))
+      if (auto CS = CallSite(K)) {
+        SmallVector<Type *, 3> Tys;
+        FunctionType *Old = CS.getFunctionType();
+        unsigned NumOld = Old->getNumParams();
+        assert(NumOld <= ReplacedOperands.size());
+        for (unsigned i = 0; i != NumOld; ++i)
+          Tys.push_back(ReplacedOperands[i]->getType());
+        CS.mutateFunctionType(
+            FunctionType::get(getVecTypeForPair(L->getType(), H->getType()),
+                              Tys, Old->isVarArg()));
+      } else if (!isa<StoreInst>(K))
         K->mutateType(getVecTypeForPair(L->getType(), H->getType()));
 
       unsigned KnownIDs[] = {
+        LLVMContext::MD_dbg,
         LLVMContext::MD_tbaa,
         LLVMContext::MD_alias_scope,
         LLVMContext::MD_noalias,
@@ -3135,8 +3147,6 @@ namespace {
       if (!isa<StoreInst>(I)) {
         L->replaceAllUsesWith(K1);
         H->replaceAllUsesWith(K2);
-        AA->replaceWithNewValue(L, K1);
-        AA->replaceWithNewValue(H, K2);
       }
 
       // Instructions that may read from memory may be in the load move set.
@@ -3188,9 +3198,10 @@ char BBVectorize::ID = 0;
 static const char bb_vectorize_name[] = "Basic-Block Vectorization";
 INITIALIZE_PASS_BEGIN(BBVectorize, BBV_NAME, bb_vectorize_name, false, false)
 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
 INITIALIZE_PASS_END(BBVectorize, BBV_NAME, bb_vectorize_name, false, false)
 
 BasicBlockPass *llvm::createBBVectorizePass(const VectorizeConfig &C) {