[SCEV] Don't create SCEV expressions that break LCSSA
[oota-llvm.git] / lib / Analysis / VectorUtils.cpp
index 92a880c3762b501c808efc440cc18df4eca26efa..4153c843c40eba11f98bfc70ded54514e20b04a9 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
@@ -410,25 +413,154 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
 }
 
 /// \brief Get splat value if the input is a splat vector or return nullptr.
-/// The value may be extracted from a splat constants vector or from
-/// a sequence of instructions that broadcast a single value into a vector.
+/// This function is not fully general. It checks only 2 cases:
+/// the input value is (1) a splat constants vector or (2) a sequence
+/// of instructions that broadcast a single value into a vector.
+///
 llvm::Value *llvm::getSplatValue(Value *V) {
-  llvm::ConstantDataVector *CV = dyn_cast<llvm::ConstantDataVector>(V);
-  if (CV)
+  if (auto *CV = dyn_cast<ConstantDataVector>(V))
     return CV->getSplatValue();
-  llvm::ShuffleVectorInst *ShuffleInst = dyn_cast<llvm::ShuffleVectorInst>(V);
+
+  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
   if (!ShuffleInst)
     return nullptr;
-  // All-zero (our undef) shuffle mask elements.
-  for (int i : ShuffleInst->getShuffleMask())
-    if (i != 0 && i != -1)
+  // All-zero (or undef) shuffle mask elements.
+  for (int MaskElt : ShuffleInst->getShuffleMask())
+    if (MaskElt != 0 && MaskElt != -1)
       return nullptr;
   // The first shuffle source is 'insertelement' with index 0.
-  llvm::InsertElementInst *InsertEltInst =
-    dyn_cast<llvm::InsertElementInst>(ShuffleInst->getOperand(0));
+  auto *InsertEltInst =
+    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
   if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
       !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue())
     return nullptr;
 
   return InsertEltInst->getOperand(1);
 }
+
+DenseMap<Instruction*, uint64_t> llvm::computeMinimumValueSizes(
+  ArrayRef<BasicBlock*> Blocks, DemandedBits &DB,
+  const TargetTransformInfo *TTI) {
+
+  // DemandedBits will give us every value's live-out bits. But we want
+  // to ensure no extra casts would need to be inserted, so every DAG
+  // of connected values must have the same minimum bitwidth.
+  EquivalenceClasses<Value*> ECs;
+  SmallVector<Value*,16> Worklist;
+  SmallPtrSet<Value*,4> Roots;
+  SmallPtrSet<Value*,16> Visited;
+  DenseMap<Value*,uint64_t> DBits;
+  SmallPtrSet<Instruction*,4> InstructionSet;
+  DenseMap<Instruction*, uint64_t> MinBWs;
+  
+  // Determine the roots. We work bottom-up, from truncs or icmps.
+  bool SeenExtFromIllegalType = false;
+  for (auto *BB : Blocks)
+    for (auto &I : *BB) {
+      InstructionSet.insert(&I);
+
+      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) &&
+          !TTI->isTypeLegal(I.getOperand(0)->getType()))
+        SeenExtFromIllegalType = true;
+    
+      // Only deal with non-vector integers up to 64-bits wide.
+      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) &&
+          !I.getType()->isVectorTy() &&
+          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) {
+        // Don't make work for ourselves. If we know the loaded type is legal,
+        // don't add it to the worklist.
+        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType()))
+          continue;
+      
+        Worklist.push_back(&I);
+        Roots.insert(&I);
+      }
+    }
+  // Early exit.
+  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType))
+    return MinBWs;
+  
+  // Now proceed breadth-first, unioning values together.
+  while (!Worklist.empty()) {
+    Value *Val = Worklist.pop_back_val();
+    Value *Leader = ECs.getOrInsertLeaderValue(Val);
+    
+    if (Visited.count(Val))
+      continue;
+    Visited.insert(Val);
+
+    // Non-instructions terminate a chain successfully.
+    if (!isa<Instruction>(Val))
+      continue;
+    Instruction *I = cast<Instruction>(Val);
+
+    // If we encounter a type that is larger than 64 bits, we can't represent
+    // it so bail out.
+    if (DB.getDemandedBits(I).getBitWidth() > 64)
+      return DenseMap<Instruction*,uint64_t>();
+    
+    uint64_t V = DB.getDemandedBits(I).getZExtValue();
+    DBits[Leader] |= V;
+    
+    // Casts, loads and instructions outside of our range terminate a chain
+    // successfully.
+    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) ||
+        !InstructionSet.count(I))
+      continue;
+
+    // Unsafe casts terminate a chain unsuccessfully. We can't do anything
+    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to
+    // transform anything that relies on them.
+    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) ||
+        !I->getType()->isIntegerTy()) {
+      DBits[Leader] |= ~0ULL;
+      continue;
+    }
+
+    // We don't modify the types of PHIs. Reductions will already have been
+    // truncated if possible, and inductions' sizes will have been chosen by
+    // indvars.
+    if (isa<PHINode>(I))
+      continue;
+
+    if (DBits[Leader] == ~0ULL)
+      // All bits demanded, no point continuing.
+      continue;
+
+    for (Value *O : cast<User>(I)->operands()) {
+      ECs.unionSets(Leader, O);
+      Worklist.push_back(O);
+    }
+  }
+
+  // Now we've discovered all values, walk them to see if there are
+  // any users we didn't see. If there are, we can't optimize that
+  // chain.
+  for (auto &I : DBits)
+    for (auto *U : I.first->users())
+      if (U->getType()->isIntegerTy() && DBits.count(U) == 0)
+        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL;
+  
+  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) {
+    uint64_t LeaderDemandedBits = 0;
+    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI)
+      LeaderDemandedBits |= DBits[*MI];
+
+    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) -
+                     llvm::countLeadingZeros(LeaderDemandedBits);
+    // Round up to a power of 2
+    if (!isPowerOf2_64((uint64_t)MinBW))
+      MinBW = NextPowerOf2(MinBW);
+    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) {
+      if (!isa<Instruction>(*MI))
+        continue;
+      Type *Ty = (*MI)->getType();
+      if (Roots.count(*MI))
+        Ty = cast<Instruction>(*MI)->getOperand(0)->getType();
+      if (MinBW < Ty->getScalarSizeInBits())
+        MinBWs[cast<Instruction>(*MI)] = MinBW;
+    }
+  }
+
+  return MinBWs;
+}