[PM/AA] Switch to an early-exit. NFC. This was split out of another
[oota-llvm.git] / lib / Transforms / Vectorize / BBVectorize.cpp
index 50c3fa41b1dad2136ba9ee0e78b23894d3cf9f6f..df016baafe5c9c6441115fad1078c43eab2ef057 100644 (file)
@@ -39,6 +39,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Pass.h"
@@ -201,14 +202,14 @@ namespace {
       initializeBBVectorizePass(*PassRegistry::getPassRegistry());
     }
 
-    BBVectorize(Pass *P, const VectorizeConfig &C)
+    BBVectorize(Pass *P, Function &F, const VectorizeConfig &C)
       : BasicBlockPass(ID), Config(C) {
       AA = &P->getAnalysis<AliasAnalysis>();
       DT = &P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
       SE = &P->getAnalysis<ScalarEvolution>();
-      DataLayoutPass *DLP = P->getAnalysisIfAvailable<DataLayoutPass>();
-      DL = DLP ? &DLP->getDataLayout() : nullptr;
-      TTI = IgnoreTargetInfo ? nullptr : &P->getAnalysis<TargetTransformInfo>();
+      TTI = IgnoreTargetInfo
+                ? nullptr
+                : &P->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     }
 
     typedef std::pair<Value *, Value *> ValuePair;
@@ -220,7 +221,6 @@ namespace {
     AliasAnalysis *AA;
     DominatorTree *DT;
     ScalarEvolution *SE;
-    const DataLayout *DL;
     const TargetTransformInfo *TTI;
 
     // FIXME: const correct?
@@ -440,9 +440,10 @@ namespace {
       AA = &getAnalysis<AliasAnalysis>();
       DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
       SE = &getAnalysis<ScalarEvolution>();
-      DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
-      DL = DLP ? &DLP->getDataLayout() : nullptr;
-      TTI = IgnoreTargetInfo ? nullptr : &getAnalysis<TargetTransformInfo>();
+      TTI = IgnoreTargetInfo
+                ? nullptr
+                : &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+                      *BB.getParent());
 
       return vectorizeBB(BB);
     }
@@ -452,7 +453,7 @@ namespace {
       AU.addRequired<AliasAnalysis>();
       AU.addRequired<DominatorTreeWrapperPass>();
       AU.addRequired<ScalarEvolution>();
-      AU.addRequired<TargetTransformInfo>();
+      AU.addRequired<TargetTransformInfoWrapperPass>();
       AU.addPreserved<AliasAnalysis>();
       AU.addPreserved<DominatorTreeWrapperPass>();
       AU.addPreserved<ScalarEvolution>();
@@ -637,19 +638,19 @@ namespace {
             dyn_cast<SCEVConstant>(OffsetSCEV)) {
         ConstantInt *IntOff = ConstOffSCEV->getValue();
         int64_t Offset = IntOff->getSExtValue();
-
+        const DataLayout &DL = I->getModule()->getDataLayout();
         Type *VTy = IPtr->getType()->getPointerElementType();
-        int64_t VTyTSS = (int64_t) DL->getTypeStoreSize(VTy);
+        int64_t VTyTSS = (int64_t)DL.getTypeStoreSize(VTy);
 
         Type *VTy2 = JPtr->getType()->getPointerElementType();
         if (VTy != VTy2 && Offset < 0) {
-          int64_t VTy2TSS = (int64_t) DL->getTypeStoreSize(VTy2);
+          int64_t VTy2TSS = (int64_t)DL.getTypeStoreSize(VTy2);
           OffsetInElmts = Offset/VTy2TSS;
-          return (abs64(Offset) % VTy2TSS) == 0;
+          return (std::abs(Offset) % VTy2TSS) == 0;
         }
 
         OffsetInElmts = Offset/VTyTSS;
-        return (abs64(Offset) % VTyTSS) == 0;
+        return (std::abs(Offset) % VTyTSS) == 0;
       }
 
       return false;
@@ -661,7 +662,7 @@ namespace {
       Function *F = I->getCalledFunction();
       if (!F) return false;
 
-      Intrinsic::ID IID = (Intrinsic::ID) F->getIntrinsicID();
+      Intrinsic::ID IID = F->getIntrinsicID();
       if (!IID) return false;
 
       switch(IID) {
@@ -685,6 +686,8 @@ namespace {
       case Intrinsic::trunc:
       case Intrinsic::floor:
       case Intrinsic::fabs:
+      case Intrinsic::minnum:
+      case Intrinsic::maxnum:
         return Config.VectorizeMath;
       case Intrinsic::bswap:
       case Intrinsic::ctpop:
@@ -839,7 +842,7 @@ namespace {
 
     // It is important to cleanup here so that future iterations of this
     // function have less work to do.
-    (void) SimplifyInstructionsInBlock(&BB, DL, AA->getTargetLibraryInfo());
+    (void)SimplifyInstructionsInBlock(&BB, AA->getTargetLibraryInfo());
     return true;
   }
 
@@ -893,10 +896,6 @@ namespace {
       return false;
     }
 
-    // We can't vectorize memory operations without target data
-    if (!DL && IsSimpleLoadStore)
-      return false;
-
     Type *T1, *T2;
     getInstructionTypes(I, T1, T2);
 
@@ -931,9 +930,8 @@ namespace {
     if (T2->isX86_FP80Ty() || T2->isPPC_FP128Ty() || T2->isX86_MMXTy())
       return false;
 
-    if ((!Config.VectorizePointers || !DL) &&
-        (T1->getScalarType()->isPointerTy() ||
-         T2->getScalarType()->isPointerTy()))
+    if (!Config.VectorizePointers && (T1->getScalarType()->isPointerTy() ||
+                                      T2->getScalarType()->isPointerTy()))
       return false;
 
     if (!TTI && (T1->getPrimitiveSizeInBits() >= Config.VectorBits ||
@@ -978,8 +976,8 @@ namespace {
       unsigned IAlignment, JAlignment, IAddressSpace, JAddressSpace;
       int64_t OffsetInElmts = 0;
       if (getPairPtrInfo(I, J, IPtr, JPtr, IAlignment, JAlignment,
-            IAddressSpace, JAddressSpace,
-            OffsetInElmts) && abs64(OffsetInElmts) == 1) {
+                         IAddressSpace, JAddressSpace, OffsetInElmts) &&
+          std::abs(OffsetInElmts) == 1) {
         FixedOrder = (int) OffsetInElmts;
         unsigned BottomAlignment = IAlignment;
         if (OffsetInElmts < 0) BottomAlignment = JAlignment;
@@ -994,8 +992,8 @@ namespace {
           // An aligned load or store is possible only if the instruction
           // with the lower offset has an alignment suitable for the
           // vector type.
-
-          unsigned VecAlignment = DL->getPrefTypeAlignment(VType);
+          const DataLayout &DL = I->getModule()->getDataLayout();
+          unsigned VecAlignment = DL.getPrefTypeAlignment(VType);
           if (BottomAlignment < VecAlignment)
             return false;
         }
@@ -1100,7 +1098,7 @@ namespace {
     CallInst *CI = dyn_cast<CallInst>(I);
     Function *FI;
     if (CI && (FI = CI->getCalledFunction())) {
-      Intrinsic::ID IID = (Intrinsic::ID) FI->getIntrinsicID();
+      Intrinsic::ID IID = FI->getIntrinsicID();
       if (IID == Intrinsic::powi || IID == Intrinsic::ctlz ||
           IID == Intrinsic::cttz) {
         Value *A1I = CI->getArgOperand(1),
@@ -1275,7 +1273,7 @@ namespace {
             CostSavings, FixedOrder)) continue;
 
         // J is a candidate for merging with I.
-        if (!PairableInsts.size() ||
+        if (PairableInsts.empty() ||
              PairableInsts[PairableInsts.size()-1] != I) {
           PairableInsts.push_back(I);
         }
@@ -2607,7 +2605,6 @@ namespace {
                                                      true, o, 1));
           NewI1->insertBefore(IBeforeJ ? J : I);
           I1 = NewI1;
-          I1T = I2T;
           I1Elem = I2Elem;
         } else if (I1Elem > I2Elem) {
           std::vector<Constant *> Mask(I1Elem);
@@ -2624,8 +2621,6 @@ namespace {
                                                      true, o, 1));
           NewI2->insertBefore(IBeforeJ ? J : I);
           I2 = NewI2;
-          I2T = I1T;
-          I2Elem = I1Elem;
         }
 
         // Now that both I1 and I2 are the same length we can shuffle them
@@ -2775,7 +2770,7 @@ namespace {
         continue;
       } else if (isa<CallInst>(I)) {
         Function *F = cast<CallInst>(I)->getCalledFunction();
-        Intrinsic::ID IID = (Intrinsic::ID) F->getIntrinsicID();
+        Intrinsic::ID IID = F->getIntrinsicID();
         if (o == NumOperands-1) {
           BasicBlock &BB = *I->getParent();
 
@@ -2814,52 +2809,51 @@ namespace {
     if (isa<StoreInst>(I)) {
       AA->replaceWithNewValue(I, K);
       AA->replaceWithNewValue(J, K);
-    } else {
-      Type *IType = I->getType();
-      Type *JType = J->getType();
+      return;
+    }
 
-      VectorType *VType = getVecTypeForPair(IType, JType);
-      unsigned numElem = VType->getNumElements();
+    Type *IType = I->getType();
+    Type *JType = J->getType();
 
-      unsigned numElemI = getNumScalarElements(IType);
-      unsigned numElemJ = getNumScalarElements(JType);
+    VectorType *VType = getVecTypeForPair(IType, JType);
+    unsigned numElem = VType->getNumElements();
 
-      if (IType->isVectorTy()) {
-        std::vector<Constant*> Mask1(numElemI), Mask2(numElemI);
-        for (unsigned v = 0; v < numElemI; ++v) {
-          Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
-          Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemJ+v);
-        }
+    unsigned numElemI = getNumScalarElements(IType);
+    unsigned numElemJ = getNumScalarElements(JType);
 
-        K1 = new ShuffleVectorInst(K, UndefValue::get(VType),
-                                   ConstantVector::get( Mask1),
-                                   getReplacementName(K, false, 1));
-      } else {
-        Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0);
-        K1 = ExtractElementInst::Create(K, CV0,
-                                          getReplacementName(K, false, 1));
+    if (IType->isVectorTy()) {
+      std::vector<Constant *> Mask1(numElemI), Mask2(numElemI);
+      for (unsigned v = 0; v < numElemI; ++v) {
+        Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
+        Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemJ + v);
       }
 
-      if (JType->isVectorTy()) {
-        std::vector<Constant*> Mask1(numElemJ), Mask2(numElemJ);
-        for (unsigned v = 0; v < numElemJ; ++v) {
-          Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
-          Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemI+v);
-        }
+      K1 = new ShuffleVectorInst(K, UndefValue::get(VType),
+                                 ConstantVector::get(Mask1),
+                                 getReplacementName(K, false, 1));
+    } else {
+      Value *CV0 = ConstantInt::get(Type::getInt32Ty(Context), 0);
+      K1 = ExtractElementInst::Create(K, CV0, getReplacementName(K, false, 1));
+    }
 
-        K2 = new ShuffleVectorInst(K, UndefValue::get(VType),
-                                   ConstantVector::get( Mask2),
-                                   getReplacementName(K, false, 2));
-      } else {
-        Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), numElem-1);
-        K2 = ExtractElementInst::Create(K, CV1,
-                                          getReplacementName(K, false, 2));
+    if (JType->isVectorTy()) {
+      std::vector<Constant *> Mask1(numElemJ), Mask2(numElemJ);
+      for (unsigned v = 0; v < numElemJ; ++v) {
+        Mask1[v] = ConstantInt::get(Type::getInt32Ty(Context), v);
+        Mask2[v] = ConstantInt::get(Type::getInt32Ty(Context), numElemI + v);
       }
 
-      K1->insertAfter(K);
-      K2->insertAfter(K1);
-      InsertionPt = K2;
+      K2 = new ShuffleVectorInst(K, UndefValue::get(VType),
+                                 ConstantVector::get(Mask2),
+                                 getReplacementName(K, false, 2));
+    } else {
+      Value *CV1 = ConstantInt::get(Type::getInt32Ty(Context), numElem - 1);
+      K2 = ExtractElementInst::Create(K, CV1, getReplacementName(K, false, 2));
     }
+
+    K1->insertAfter(K);
+    K2->insertAfter(K1);
+    InsertionPt = K2;
   }
 
   // Move all uses of the function I (including pairing-induced uses) after J.
@@ -3108,7 +3102,17 @@ namespace {
       else if (H->hasName())
         K->takeName(H);
 
-      if (!isa<StoreInst>(K))
+      if (auto CS = CallSite(K)) {
+        SmallVector<Type *, 3> Tys;
+        FunctionType *Old = CS.getFunctionType();
+        unsigned NumOld = Old->getNumParams();
+        assert(NumOld <= ReplacedOperands.size());
+        for (unsigned i = 0; i != NumOld; ++i)
+          Tys.push_back(ReplacedOperands[i]->getType());
+        CS.mutateFunctionType(
+            FunctionType::get(getVecTypeForPair(L->getType(), H->getType()),
+                              Tys, Old->isVarArg()));
+      } else if (!isa<StoreInst>(K))
         K->mutateType(getVecTypeForPair(L->getType(), H->getType()));
 
       unsigned KnownIDs[] = {
@@ -3193,7 +3197,7 @@ char BBVectorize::ID = 0;
 static const char bb_vectorize_name[] = "Basic-Block Vectorization";
 INITIALIZE_PASS_BEGIN(BBVectorize, BBV_NAME, bb_vectorize_name, false, false)
 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
-INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
 INITIALIZE_PASS_END(BBVectorize, BBV_NAME, bb_vectorize_name, false, false)
@@ -3204,7 +3208,7 @@ BasicBlockPass *llvm::createBBVectorizePass(const VectorizeConfig &C) {
 
 bool
 llvm::vectorizeBasicBlock(Pass *P, BasicBlock &BB, const VectorizeConfig &C) {
-  BBVectorize BBVectorizer(P, C);
+  BBVectorize BBVectorizer(P, *BB.getParent(), C);
   return BBVectorizer.vectorizeBB(BB);
 }