[LoopVectorize] Shrink integer operations into the smallest type possible
authorJames Molloy <james.molloy@arm.com>
Mon, 12 Oct 2015 12:34:45 +0000 (12:34 +0000)
committerJames Molloy <james.molloy@arm.com>
Mon, 12 Oct 2015 12:34:45 +0000 (12:34 +0000)
C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
type (e.g. i32) whenever arithmetic is performed on them.

For targets with native i8 or i16 operations, usually InstCombine can shrink
the arithmetic type down again. However InstCombine refuses to create illegal
types, so for targets without i8 or i16 registers, the lengthening and
shrinking remains.

Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
their scalar equivalents do not, so during vectorization it is important to
remove these lengthens and truncates when deciding the profitability of
vectorization.

The algorithm this uses starts at truncs and icmps, trawling their use-def
chains until they terminate or instructions outside the loop are found (or
unsafe instructions like inttoptr casts are found). If the use-def chains
starting from different root instructions (truncs/icmps) meet, they are
unioned. The demanded bits of each node in the graph are ORed together to form
an overall mask of the demanded bits in the entire graph. The minimum bitwidth
that graph can be truncated to is the bitwidth minus the number of leading
zeroes in the overall mask.

The intention is that this algorithm should "first do no harm", so it will
never insert extra cast instructions. This is why the use-def graphs are
unioned, so that subgraphs with different minimum bitwidths do not need casts
inserted between them.

This algorithm works hard to reduce compile time impact. DemandedBits are only
queried if there are extends of illegal types and if a truncate to an illegal
type is seen. In the general case, this results in a simple linear scan of the
instructions in the loop.

No non-noise compile time impact was seen on a clang bootstrap build.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@250032 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/VectorUtils.h
lib/Analysis/VectorUtils.cpp
lib/Transforms/Vectorize/LoopVectorize.cpp
test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll [new file with mode: 0644]

index 125a56a530bc0ea10202eb0690200889a45a46d7..71c2ebe2136344163b95b7ee6c3e4a9aa684305a 100644 (file)
 #ifndef LLVM_TRANSFORMS_UTILS_VECTORUTILS_H
 #define LLVM_TRANSFORMS_UTILS_VECTORUTILS_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 
 namespace llvm {
 
+struct DemandedBits;
 class GetElementPtrInst;
 class Loop;
 class ScalarEvolution;
+class TargetTransformInfo;
 class Type;
 class Value;
 
@@ -84,6 +87,45 @@ Value *findScalarElement(Value *V, unsigned EltNo);
 /// a sequence of instructions that broadcast a single value into a vector.
 Value *getSplatValue(Value *V);
 
+/// \brief Compute a map of integer instructions to their minimum legal type
+/// size.
+///
+/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
+/// type (e.g. i32) whenever arithmetic is performed on them.
+///
+/// For targets with native i8 or i16 operations, usually InstCombine can shrink
+/// the arithmetic type down again. However InstCombine refuses to create
+/// illegal types, so for targets without i8 or i16 registers, the lengthening
+/// and shrinking remains.
+///
+/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
+/// their scalar equivalents do not, so during vectorization it is important to
+/// remove these lengthens and truncates when deciding the profitability of
+/// vectorization.
+///
+/// This function analyzes the given range of instructions and determines the
+/// minimum type size each can be converted to. It attempts to remove or
+/// minimize type size changes across each def-use chain, so for example in the
+/// following code:
+///
+///   %1 = load i8, i8*
+///   %2 = add i8 %1, 2
+///   %3 = load i16, i16*
+///   %4 = zext i8 %2 to i32
+///   %5 = zext i16 %3 to i32
+///   %6 = add i32 %4, %5
+///   %7 = trunc i32 %6 to i16
+///
+/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
+/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
+///
+/// If the optional TargetTransformInfo is provided, this function tries harder
+/// to do less work by only looking at illegal types.
+DenseMap<Instruction*, uint64_t>
+computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
+                         DemandedBits &DB,
+                         const TargetTransformInfo *TTI=nullptr);
+    
 } // llvm namespace
 
 #endif
index 93720857662f988a8e0a046cd10cdfeb7afbe165..4153c843c40eba11f98bfc70ded54514e20b04a9 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
@@ -434,3 +437,130 @@ llvm::Value *llvm::getSplatValue(Value *V) {
 
   return InsertEltInst->getOperand(1);
 }
+
+DenseMap<Instruction*, uint64_t> llvm::computeMinimumValueSizes(
+  ArrayRef<BasicBlock*> Blocks, DemandedBits &DB,
+  const TargetTransformInfo *TTI) {
+
+  // DemandedBits will give us every value's live-out bits. But we want
+  // to ensure no extra casts would need to be inserted, so every DAG
+  // of connected values must have the same minimum bitwidth.
+  EquivalenceClasses<Value*> ECs;
+  SmallVector<Value*,16> Worklist;
+  SmallPtrSet<Value*,4> Roots;
+  SmallPtrSet<Value*,16> Visited;
+  DenseMap<Value*,uint64_t> DBits;
+  SmallPtrSet<Instruction*,4> InstructionSet;
+  DenseMap<Instruction*, uint64_t> MinBWs;
+  
+  // Determine the roots. We work bottom-up, from truncs or icmps.
+  bool SeenExtFromIllegalType = false;
+  for (auto *BB : Blocks)
+    for (auto &I : *BB) {
+      InstructionSet.insert(&I);
+
+      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
+          !TTI->isTypeLegal(I.getOperand(0)->getType()))
+        SeenExtFromIllegalType = true;
+    
+      // Only deal with non-vector integers up to 64-bits wide.
+      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
+          !I.getType()->isVectorTy() &&
+          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
+        // Don't make work for ourselves. If we know the loaded type is legal,
+        // don't add it to the worklist.
+        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
+          continue;
+      
+        Worklist.push_back(&I);
+        Roots.insert(&I);
+      }
+    }
+  // Early exit.
+  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
+    return MinBWs;
+  
+  // Now proceed breadth-first, unioning values together.
+  while (!Worklist.empty()) {
+    Value *Val = Worklist.pop_back_val();
+    Value *Leader = ECs.getOrInsertLeaderValue(Val);
+    
+    if (Visited.count(Val))
+      continue;
+    Visited.insert(Val);
+
+    // Non-instructions terminate a chain successfully.
+    if (!isa<Instruction>(Val))
+      continue;
+    Instruction *I = cast<Instruction>(Val);
+
+    // If we encounter a type that is larger than 64 bits, we can't represent
+    // it so bail out.
+    if (DB.getDemandedBits(I).getBitWidth() > 64)
+      return DenseMap<Instruction*,uint64_t>();
+    
+    uint64_t V = DB.getDemandedBits(I).getZExtValue();
+    DBits[Leader] |= V;
+    
+    // Casts, loads and instructions outside of our range terminate a chain
+    // successfully.
+    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
+        !InstructionSet.count(I))
+      continue;
+
+    // Unsafe casts terminate a chain unsuccessfully. We can't do anything
+    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
+    // transform anything that relies on them.
+    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
+        !I->getType()->isIntegerTy()) {
+      DBits[Leader] |= ~0ULL;
+      continue;
+    }
+
+    // We don't modify the types of PHIs. Reductions will already have been
+    // truncated if possible, and inductions' sizes will have been chosen by
+    // indvars.
+    if (isa<PHINode>(I))
+      continue;
+
+    if (DBits[Leader] == ~0ULL)
+      // All bits demanded, no point continuing.
+      continue;
+
+    for (Value *O : cast<User>(I)->operands()) {
+      ECs.unionSets(Leader, O);
+      Worklist.push_back(O);
+    }
+  }
+
+  // Now we've discovered all values, walk them to see if there are
+  // any users we didn't see. If there are, we can't optimize that
+  // chain.
+  for (auto &I : DBits)
+    for (auto *U : I.first->users())
+      if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
+        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;
+  
+  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
+    uint64_t LeaderDemandedBits = 0;
+    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
+      LeaderDemandedBits |= DBits[*MI];
+
+    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
+                     llvm::countLeadingZeros(LeaderDemandedBits);
+    // Round up to a power of 2
+    if (!isPowerOf2_64((uint64_t)MinBW))
+      MinBW = NextPowerOf2(MinBW);
+    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
+      if (!isa<Instruction>(*MI))
+        continue;
+      Type *Ty = (*MI)->getType();
+      if (Roots.count(*MI))
+        Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
+      if (MinBW < Ty->getScalarSizeInBits())
+        MinBWs[cast<Instruction>(*MI)] = MinBW;
+    }
+  }
+
+  return MinBWs;
+}
index bd381d31de7568eb5c5529c61d4f9fe29434ea8e..ec91e138e8642babd168141cfc8d2f5e5c9d14e0 100644 (file)
@@ -48,7 +48,6 @@
 
 #include "llvm/Transforms/Vectorize.h"
 #include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
@@ -63,6 +62,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
 #include <algorithm>
+#include <functional>
 #include <map>
 #include <tuple>
 
@@ -280,7 +281,12 @@ public:
         AddedSafetyChecks(false) {}
 
   // Perform the actual loop widening (vectorization).
-  void vectorize(LoopVectorizationLegality *L) {
+  // MinimumBitWidths maps scalar integer values to the smallest bitwidth they
+  // can be validly truncated to. The cost model has assumed this truncation
+  // will happen when vectorizing.
+  void vectorize(LoopVectorizationLegality *L,
+                 DenseMap<Instruction*,uint64_t> MinimumBitWidths) {
+    MinBWs = MinimumBitWidths;
     Legal = L;
     // Create a new empty loop. Unlink the old loop and connect the new one.
     createEmptyLoop();
@@ -329,6 +335,9 @@ protected:
   /// See PR14725.
   void fixLCSSAPHIs();
 
+  /// Shrinks vector element sizes based on information in "MinBWs".
+  void truncateToMinimalBitwidths();
+  
   /// A helper function that computes the predicate of the block BB, assuming
   /// that the header block of the loop is set to True. It returns the *entry*
   /// mask for the block BB.
@@ -339,7 +348,7 @@ protected:
 
   /// A helper function to vectorize a single BB within the innermost loop.
   void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV);
-
+  
   /// Vectorize a single PHINode in a block. This method handles the induction
   /// variable canonicalization. It supports both VF = 1 for unrolled loops and
   /// arbitrary length vectors.
@@ -499,6 +508,10 @@ protected:
   /// Trip count of the widened loop (TripCount - TripCount % (VF*UF))
   Value *VectorTripCount;
 
+  /// Map of scalar integer values to the smallest bitwidth they can be legally
+  /// represented as. The vector equivalents of these values should be truncated
+  /// to this type.
+  DenseMap<Instruction*,uint64_t> MinBWs;
   LoopVectorizationLegality *Legal;
 
   // Record whether runtime check is added.
@@ -1346,10 +1359,11 @@ public:
   LoopVectorizationCostModel(Loop *L, ScalarEvolution *SE, LoopInfo *LI,
                              LoopVectorizationLegality *Legal,
                              const TargetTransformInfo &TTI,
-                             const TargetLibraryInfo *TLI, AssumptionCache *AC,
+                             const TargetLibraryInfo *TLI, DemandedBits *DB,
+                             AssumptionCache *AC,
                              const Function *F, const LoopVectorizeHints *Hints,
                              SmallPtrSetImpl<const Value *> &ValuesToIgnore)
-      : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI),
+      : TheLoop(L), SE(SE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB),
         TheFunction(F), Hints(Hints), ValuesToIgnore(ValuesToIgnore) {}
 
   /// Information about vectorization costs
@@ -1419,6 +1433,12 @@ private:
     emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message);
   }
 
+public:
+  /// Map of scalar integer values to the smallest bitwidth they can be legally
+  /// represented as. The vector equivalents of these values should be truncated
+  /// to this type.
+  DenseMap<Instruction*,uint64_t> MinBWs;
+
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
@@ -1431,6 +1451,8 @@ private:
   const TargetTransformInfo &TTI;
   /// Target Library Info.
   const TargetLibraryInfo *TLI;
+  /// Demanded bits analysis
+  DemandedBits *DB;
   const Function *TheFunction;
   // Loop Vectorize Hint.
   const LoopVectorizeHints *Hints;
@@ -1523,6 +1545,7 @@ struct LoopVectorize : public FunctionPass {
   DominatorTree *DT;
   BlockFrequencyInfo *BFI;
   TargetLibraryInfo *TLI;
+  DemandedBits *DB;
   AliasAnalysis *AA;
   AssumptionCache *AC;
   LoopAccessAnalysis *LAA;
@@ -1542,6 +1565,7 @@ struct LoopVectorize : public FunctionPass {
     AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
     AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
     LAA = &getAnalysis<LoopAccessAnalysis>();
+    DB = &getAnalysis<DemandedBits>();
 
     // Compute some weights outside of the loop over the loops. Compute this
     // using a BranchProbability to re-use its scaling math.
@@ -1687,7 +1711,7 @@ struct LoopVectorize : public FunctionPass {
     }
 
     // Use the cost model.
-    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, AC, F, &Hints,
+    LoopVectorizationCostModel CM(L, SE, LI, &LVL, *TTI, TLI, DB, AC, F, &Hints,
                                   ValuesToIgnore);
 
     // Check the function attributes to find out if this function should be
@@ -1800,7 +1824,7 @@ struct LoopVectorize : public FunctionPass {
       // If we decided that it is not legal to vectorize the loop then
       // interleave it.
       InnerLoopUnroller Unroller(L, SE, LI, DT, TLI, TTI, IC);
-      Unroller.vectorize(&LVL);
+      Unroller.vectorize(&LVL, CM.MinBWs);
 
       emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(),
                              Twine("interleaved loop (interleaved count: ") +
@@ -1808,7 +1832,7 @@ struct LoopVectorize : public FunctionPass {
     } else {
       // If we decided that it is *legal* to vectorize the loop then do it.
       InnerLoopVectorizer LB(L, SE, LI, DT, TLI, TTI, VF.Width, IC);
-      LB.vectorize(&LVL);
+      LB.vectorize(&LVL, CM.MinBWs);
       ++LoopsVectorized;
 
       // Add metadata to disable runtime unrolling scalar loop when there's no
@@ -1842,6 +1866,7 @@ struct LoopVectorize : public FunctionPass {
     AU.addRequired<TargetTransformInfoWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
     AU.addRequired<LoopAccessAnalysis>();
+    AU.addRequired<DemandedBits>();
     AU.addPreserved<LoopInfoWrapperPass>();
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<BasicAAWrapperPass>();
@@ -2009,6 +2034,7 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
   // If this scalar is unknown, assume that it is a constant or that it is
   // loop invariant. Broadcast V and save the value for future uses.
   Value *B = getBroadcastInstrs(V);
+
   return WidenMap.splat(V, B);
 }
 
@@ -3102,6 +3128,117 @@ static unsigned getVectorIntrinsicCost(CallInst *CI, unsigned VF,
   return TTI.getIntrinsicInstrCost(ID, RetTy, Tys);
 }
 
+static Type *smallestIntegerVectorType(Type *T1, Type *T2) {
+  IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType());
+  IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType());
+  return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2;
+}
+static Type *largestIntegerVectorType(Type *T1, Type *T2) {
+  IntegerType *I1 = cast<IntegerType>(T1->getVectorElementType());
+  IntegerType *I2 = cast<IntegerType>(T2->getVectorElementType());
+  return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2;
+}
+
+void InnerLoopVectorizer::truncateToMinimalBitwidths() {
+  // For every instruction `I` in MinBWs, truncate the operands, create a
+  // truncated version of `I` and reextend its result. InstCombine runs
+  // later and will remove any ext/trunc pairs.
+  //
+  for (auto &KV : MinBWs) {
+    VectorParts &Parts = WidenMap.get(KV.first);
+    for (Value *&I : Parts) {
+      if (I->use_empty())
+        continue;
+      Type *OriginalTy = I->getType();
+      Type *ScalarTruncatedTy = IntegerType::get(OriginalTy->getContext(),
+                                                 KV.second);
+      Type *TruncatedTy = VectorType::get(ScalarTruncatedTy,
+                                          OriginalTy->getVectorNumElements());
+      if (TruncatedTy == OriginalTy)
+        continue;
+
+      IRBuilder<> B(cast<Instruction>(I));
+      auto ShrinkOperand = [&](Value *V) -> Value* {
+        if (auto *ZI = dyn_cast<ZExtInst>(V))
+          if (ZI->getSrcTy() == TruncatedTy)
+            return ZI->getOperand(0);
+        return B.CreateZExtOrTrunc(V, TruncatedTy);
+      };
+
+      // The actual instruction modification depends on the instruction type,
+      // unfortunately.
+      Value *NewI = nullptr;
+      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
+        NewI = B.CreateBinOp(BO->getOpcode(),
+                             ShrinkOperand(BO->getOperand(0)),
+                             ShrinkOperand(BO->getOperand(1)));
+        cast<BinaryOperator>(NewI)->copyIRFlags(I);
+      } else if (ICmpInst *CI = dyn_cast<ICmpInst>(I)) {
+        NewI = B.CreateICmp(CI->getPredicate(),
+                            ShrinkOperand(CI->getOperand(0)),
+                            ShrinkOperand(CI->getOperand(1)));
+      } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
+        NewI = B.CreateSelect(SI->getCondition(),
+                              ShrinkOperand(SI->getTrueValue()),
+                              ShrinkOperand(SI->getFalseValue()));
+      } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
+        switch (CI->getOpcode()) {
+        default: llvm_unreachable("Unhandled cast!");
+        case Instruction::Trunc:
+          NewI = ShrinkOperand(CI->getOperand(0));
+          break;
+        case Instruction::SExt:
+          NewI = B.CreateSExtOrTrunc(CI->getOperand(0),
+                                     smallestIntegerVectorType(OriginalTy,
+                                                               TruncatedTy));
+          break;
+        case Instruction::ZExt:
+          NewI = B.CreateZExtOrTrunc(CI->getOperand(0),
+                                     smallestIntegerVectorType(OriginalTy,
+                                                               TruncatedTy));
+          break;
+        }
+      } else if (ShuffleVectorInst *SI = dyn_cast<ShuffleVectorInst>(I)) {
+        auto Elements0 = SI->getOperand(0)->getType()->getVectorNumElements();
+        auto *O0 =
+          B.CreateZExtOrTrunc(SI->getOperand(0),
+                              VectorType::get(ScalarTruncatedTy, Elements0));
+        auto Elements1 = SI->getOperand(1)->getType()->getVectorNumElements();
+        auto *O1 =
+          B.CreateZExtOrTrunc(SI->getOperand(1),
+                              VectorType::get(ScalarTruncatedTy, Elements1));
+
+        NewI = B.CreateShuffleVector(O0, O1, SI->getMask());
+      } else if (isa<LoadInst>(I)) {
+        // Don't do anything with the operands, just extend the result.
+        continue;
+      } else {
+        llvm_unreachable("Unhandled instruction type!");
+      }
+
+      // Lastly, extend the result.
+      NewI->takeName(cast<Instruction>(I));
+      Value *Res = B.CreateZExtOrTrunc(NewI, OriginalTy);
+      I->replaceAllUsesWith(Res);
+      cast<Instruction>(I)->eraseFromParent();
+      I = Res;
+    }
+  }
+
+  // We'll have created a bunch of ZExts that are now parentless. Clean up.
+  for (auto &KV : MinBWs) {
+    VectorParts &Parts = WidenMap.get(KV.first);
+    for (Value *&I : Parts) {
+      ZExtInst *Inst = dyn_cast<ZExtInst>(I);
+      if (Inst && Inst->use_empty()) {
+        Value *NewI = Inst->getOperand(0);
+        Inst->eraseFromParent();
+        I = NewI;
+      }
+    }
+  }
+}
+
 void InnerLoopVectorizer::vectorizeLoop() {
   //===------------------------------------------------===//
   //
@@ -3132,6 +3269,11 @@ void InnerLoopVectorizer::vectorizeLoop() {
        be = DFS.endRPO(); bb != be; ++bb)
     vectorizeBlockInLoop(*bb, &RdxPHIsToFix);
 
+  // Insert truncates and extends for any truncated instructions as hints to
+  // InstCombine.
+  if (VF > 1)
+    truncateToMinimalBitwidths();
+  
   // At this point every instruction in the original loop is widened to
   // a vector form. We are almost done. Now, we need to fix the PHI nodes
   // that we vectorized. The PHI nodes are currently empty because we did
@@ -3565,6 +3707,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
   // For each instruction in the old loop.
   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
     VectorParts &Entry = WidenMap.get(it);
+
     switch (it->getOpcode()) {
     case Instruction::Br:
       // Nothing to do for PHIs and BR, since we already took care of the
@@ -3628,7 +3771,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
       VectorParts &Cond = getVectorValue(it->getOperand(0));
       VectorParts &Op0  = getVectorValue(it->getOperand(1));
       VectorParts &Op1  = getVectorValue(it->getOperand(2));
-
+      
       Value *ScalarCond = (VF == 1) ? Cond[0] :
         Builder.CreateExtractElement(Cond[0], Builder.getInt32(0));
 
@@ -4563,6 +4706,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) {
   unsigned TC = SE->getSmallConstantTripCount(TheLoop);
   DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
 
+  MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI);
   unsigned WidestType = getWidestType();
   unsigned WidestRegister = TTI.getRegisterBitWidth(true);
   unsigned MaxSafeDepDist = -1U;
@@ -5086,6 +5230,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
     VF = 1;
 
   Type *RetTy = I->getType();
+  if (VF > 1 && MinBWs.count(I))
+    RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]);
   Type *VectorTy = ToVectorTy(RetTy, VF);
 
   // TODO: We need to estimate the cost of intrinsic calls.
@@ -5168,6 +5314,8 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
   case Instruction::ICmp:
   case Instruction::FCmp: {
     Type *ValTy = I->getOperand(0)->getType();
+    if (VF > 1 && MinBWs.count(dyn_cast<Instruction>(I->getOperand(0))))
+      ValTy = IntegerType::get(ValTy->getContext(), MinBWs[I]);
     VectorTy = ToVectorTy(ValTy, VF);
     return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy);
   }
@@ -5291,8 +5439,28 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
         Legal->isInductionVariable(I->getOperand(0)))
       return TTI.getCastInstrCost(I->getOpcode(), I->getType(),
                                   I->getOperand(0)->getType());
-
-    Type *SrcVecTy = ToVectorTy(I->getOperand(0)->getType(), VF);
+    
+    Type *SrcScalarTy = I->getOperand(0)->getType();
+    Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF);
+    if (VF > 1 && MinBWs.count(I)) {
+      // This cast is going to be shrunk. This may remove the cast or it might
+      // turn it into slightly different cast. For example, if MinBW == 16,
+      // "zext i8 %1 to i32" becomes "zext i8 %1 to i16".
+      //
+      // Calculate the modified src and dest types.
+      Type *MinVecTy = VectorTy;
+      if (I->getOpcode() == Instruction::Trunc) {
+        SrcVecTy = smallestIntegerVectorType(SrcVecTy, MinVecTy);
+        VectorTy = largestIntegerVectorType(ToVectorTy(I->getType(), VF),
+                                            MinVecTy);
+      } else if (I->getOpcode() == Instruction::ZExt ||
+                 I->getOpcode() == Instruction::SExt) {
+        SrcVecTy = largestIntegerVectorType(SrcVecTy, MinVecTy);
+        VectorTy = smallestIntegerVectorType(ToVectorTy(I->getType(), VF),
+                                             MinVecTy);
+      }
+    }
+    
     return TTI.getCastInstrCost(I->getOpcode(), VectorTy, SrcVecTy);
   }
   case Instruction::Call: {
@@ -5343,6 +5511,7 @@ INITIALIZE_PASS_DEPENDENCY(LCSSA)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_DEPENDENCY(LoopAccessAnalysis)
+INITIALIZE_PASS_DEPENDENCY(DemandedBits)
 INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false)
 
 namespace llvm {
diff --git a/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll b/test/Transforms/LoopVectorize/AArch64/loop-vectorization-factors.ll
new file mode 100644 (file)
index 0000000..f5b6a64
--- /dev/null
@@ -0,0 +1,243 @@
+; RUN: opt -S < %s -basicaa -loop-vectorize -simplifycfg -instsimplify -instcombine -licm -force-vector-interleave=1 2>&1 | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64"
+
+; CHECK-LABEL: @add_a(
+; CHECK: load <16 x i8>, <16 x i8>*
+; CHECK: add nuw nsw <16 x i8>
+; CHECK: store <16 x i8>
+; Function Attrs: nounwind
+define void @add_a(i8* noalias nocapture readonly %p, i8* noalias nocapture %q, i32 %len) #0 {
+entry:
+  %cmp8 = icmp sgt i32 %len, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i8, i8* %p, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx
+  %conv = zext i8 %0 to i32
+  %add = add nuw nsw i32 %conv, 2
+  %conv1 = trunc i32 %add to i8
+  %arrayidx3 = getelementptr inbounds i8, i8* %q, i64 %indvars.iv
+  store i8 %conv1, i8* %arrayidx3
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_b(
+; CHECK: load <8 x i16>, <8 x i16>*
+; CHECK: add nuw nsw <8 x i16>
+; CHECK: store <8 x i16>
+; Function Attrs: nounwind
+define void @add_b(i16* noalias nocapture readonly %p, i16* noalias nocapture %q, i32 %len) #0 {
+entry:
+  %cmp9 = icmp sgt i32 %len, 0
+  br i1 %cmp9, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, i16* %p, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx
+  %conv8 = zext i16 %0 to i32
+  %add = add nuw nsw i32 %conv8, 2
+  %conv1 = trunc i32 %add to i16
+  %arrayidx3 = getelementptr inbounds i16, i16* %q, i64 %indvars.iv
+  store i16 %conv1, i16* %arrayidx3
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_c(
+; CHECK: load <8 x i8>, <8 x i8>*
+; CHECK: add nuw nsw <8 x i16>
+; CHECK: store <8 x i16>
+; Function Attrs: nounwind
+define void @add_c(i8* noalias nocapture readonly %p, i16* noalias nocapture %q, i32 %len) #0 {
+entry:
+  %cmp8 = icmp sgt i32 %len, 0
+  br i1 %cmp8, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i8, i8* %p, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx
+  %conv = zext i8 %0 to i32
+  %add = add nuw nsw i32 %conv, 2
+  %conv1 = trunc i32 %add to i16
+  %arrayidx3 = getelementptr inbounds i16, i16* %q, i64 %indvars.iv
+  store i16 %conv1, i16* %arrayidx3
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_d(
+; CHECK: load <4 x i16>
+; CHECK: add nsw <4 x i32>
+; CHECK: store <4 x i32>
+define void @add_d(i16* noalias nocapture readonly %p, i32* noalias nocapture %q, i32 %len) #0 {
+entry:
+  %cmp7 = icmp sgt i32 %len, 0
+  br i1 %cmp7, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i16, i16* %p, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx
+  %conv = sext i16 %0 to i32
+  %add = add nsw i32 %conv, 2
+  %arrayidx2 = getelementptr inbounds i32, i32* %q, i64 %indvars.iv
+  store i32 %add, i32* %arrayidx2
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_e(
+; CHECK: load <16 x i8>
+; CHECK: shl <16 x i8>
+; CHECK: add nuw nsw <16 x i8>
+; CHECK: or <16 x i8>
+; CHECK: mul nuw nsw <16 x i8>
+; CHECK: and <16 x i8>
+; CHECK: xor <16 x i8>
+; CHECK: mul nuw nsw <16 x i8>
+; CHECK: store <16 x i8>
+define void @add_e(i8* noalias nocapture readonly %p, i8* noalias nocapture %q, i8 %arg1, i8 %arg2, i32 %len) #0 {
+entry:
+  %cmp.32 = icmp sgt i32 %len, 0
+  br i1 %cmp.32, label %for.body.lr.ph, label %for.cond.cleanup
+
+for.body.lr.ph:                                   ; preds = %entry
+  %conv11 = zext i8 %arg2 to i32
+  %conv13 = zext i8 %arg1 to i32
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body, %for.body.lr.ph
+  %indvars.iv = phi i64 [ 0, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds i8, i8* %p, i64 %indvars.iv
+  %0 = load i8, i8* %arrayidx
+  %conv = zext i8 %0 to i32
+  %add = shl i32 %conv, 4
+  %conv2 = add nuw nsw i32 %add, 32
+  %or = or i32 %conv, 51
+  %mul = mul nuw nsw i32 %or, 60
+  %and = and i32 %conv2, %conv13
+  %mul.masked = and i32 %mul, 252
+  %conv17 = xor i32 %mul.masked, %conv11
+  %mul18 = mul nuw nsw i32 %conv17, %and
+  %conv19 = trunc i32 %mul18 to i8
+  %arrayidx21 = getelementptr inbounds i8, i8* %q, i64 %indvars.iv
+  store i8 %conv19, i8* %arrayidx21
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_f
+; CHECK: load <8 x i16>
+; CHECK: trunc <8 x i16>
+; CHECK: shl <8 x i8>
+; CHECK: add nsw <8 x i8>
+; CHECK: or <8 x i8>
+; CHECK: mul nuw nsw <8 x i8>
+; CHECK: and <8 x i8>
+; CHECK: xor <8 x i8>
+; CHECK: mul nuw nsw <8 x i8>
+; CHECK: store <8 x i8>
+define void @add_f(i16* noalias nocapture readonly %p, i8* noalias nocapture %q, i8 %arg1, i8 %arg2, i32 %len) #0 {
+entry:
+  %cmp.32 = icmp sgt i32 %len, 0
+  br i1 %cmp.32, label %for.body.lr.ph, label %for.cond.cleanup
+
+for.body.lr.ph:                                   ; preds = %entry
+  %conv11 = zext i8 %arg2 to i32
+  %conv13 = zext i8 %arg1 to i32
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body, %for.body.lr.ph
+  %indvars.iv = phi i64 [ 0, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds i16, i16* %p, i64 %indvars.iv
+  %0 = load i16, i16* %arrayidx
+  %conv = sext i16 %0 to i32
+  %add = shl i32 %conv, 4
+  %conv2 = add nsw i32 %add, 32
+  %or = and i32 %conv, 204
+  %conv8 = or i32 %or, 51
+  %mul = mul nuw nsw i32 %conv8, 60
+  %and = and i32 %conv2, %conv13
+  %mul.masked = and i32 %mul, 252
+  %conv17 = xor i32 %mul.masked, %conv11
+  %mul18 = mul nuw nsw i32 %conv17, %and
+  %conv19 = trunc i32 %mul18 to i8
+  %arrayidx21 = getelementptr inbounds i8, i8* %q, i64 %indvars.iv
+  store i8 %conv19, i8* %arrayidx21
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %for.cond.cleanup, label %for.body
+}
+
+; CHECK-LABEL: @add_g
+; CHECK: load <16 x i8>
+; CHECK: xor <16 x i8>
+; CHECK: icmp ult <16 x i8>
+; CHECK: select <16 x i1> {{.*}}, <16 x i8>
+; CHECK: store <16 x i8>
+define void @add_g(i8* noalias nocapture readonly %p, i8* noalias nocapture readonly %q, i8* noalias nocapture %r, i8 %arg1, i32 %len) #0 {
+  %1 = icmp sgt i32 %len, 0
+  br i1 %1, label %.lr.ph, label %._crit_edge
+
+.lr.ph:                                           ; preds = %0
+  %2 = sext i8 %arg1 to i64
+  br label %3
+
+._crit_edge:                                      ; preds = %3, %0
+  ret void
+
+; <label>:3                                       ; preds = %3, %.lr.ph
+  %indvars.iv = phi i64 [ 0, %.lr.ph ], [ %indvars.iv.next, %3 ]
+  %x4 = getelementptr inbounds i8, i8* %p, i64 %indvars.iv
+  %x5 = load i8, i8* %x4
+  %x7 = getelementptr inbounds i8, i8* %q, i64 %indvars.iv
+  %x8 = load i8, i8* %x7
+  %x9 = zext i8 %x5 to i32
+  %x10 = xor i32 %x9, 255
+  %x11 = icmp ult i32 %x10, 24
+  %x12 = select i1 %x11, i32 %x10, i32 24
+  %x13 = trunc i32 %x12 to i8
+  store i8 %x13, i8* %x4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %len
+  br i1 %exitcond, label %._crit_edge, label %3
+}
+
+attributes #0 = { nounwind }