[llvm-size] Fix time to check if time of use bug.
[oota-llvm.git] / lib / Analysis / VectorUtils.cpp
index 1ebff0f7c056ced889e780881f6865b49cb0e5e1..93720857662f988a8e0a046cd10cdfeb7afbe165 100644 (file)
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
+#include "llvm/IR/Constants.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
 
 /// \brief Identify if the intrinsic is trivially vectorizable.
 /// This method returns true if the intrinsic's argument types are all
@@ -79,7 +83,7 @@ bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
 /// d) call should only reads memory.
 /// If all these condition is met then return ValidIntrinsicID
 /// else return not_intrinsic.
-llvm::Intrinsic::ID
+Intrinsic::ID
 llvm::checkUnaryFloatSignature(const CallInst &I,
                                Intrinsic::ID ValidIntrinsicID) {
   if (I.getNumArgOperands() != 1 ||
@@ -98,7 +102,7 @@ llvm::checkUnaryFloatSignature(const CallInst &I,
 /// d) call should only reads memory.
 /// If all these condition is met then return ValidIntrinsicID
 /// else return not_intrinsic.
-llvm::Intrinsic::ID
+Intrinsic::ID
 llvm::checkBinaryFloatSignature(const CallInst &I,
                                 Intrinsic::ID ValidIntrinsicID) {
   if (I.getNumArgOperands() != 2 ||
@@ -114,8 +118,8 @@ llvm::checkBinaryFloatSignature(const CallInst &I,
 /// \brief Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
 /// its ID, in case it does not found it return not_intrinsic.
-llvm::Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
-                                                const TargetLibraryInfo *TLI) {
+Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
+                                          const TargetLibraryInfo *TLI) {
   // If we have an intrinsic call, check if it is trivially vectorizable.
   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
     Intrinsic::ID ID = II->getIntrinsicID();
@@ -228,8 +232,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
       cast<PointerType>(Gep->getType()->getScalarType())->getElementType());
 
   // Walk backwards and try to peel off zeros.
-  while (LastOperand > 1 &&
-         match(Gep->getOperand(LastOperand), llvm::PatternMatch::m_Zero())) {
+  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
     // Find the type we're currently indexing into.
     gep_type_iterator GEPTI = gep_type_begin(Gep);
     std::advance(GEPTI, LastOperand - 1);
@@ -247,8 +250,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
 /// \brief If the argument is a GEP, then returns the operand identified by
 /// getGEPInductionOperand. However, if there is some other non-loop-invariant
 /// operand, it returns that instead.
-llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE,
-                                      Loop *Lp) {
+Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
   if (!GEP)
     return Ptr;
@@ -265,8 +267,8 @@ llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE,
 }
 
 /// \brief If a value has only one user that is a CastInst, return it.
-llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) {
-  llvm::Value *UniqueCast = nullptr;
+Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+  Value *UniqueCast = nullptr;
   for (User *U : Ptr->users()) {
     CastInst *CI = dyn_cast<CastInst>(U);
     if (CI && CI->getType() == Ty) {
@@ -281,8 +283,7 @@ llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) {
 
 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic
 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
-llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
-                                        Loop *Lp) {
+Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
   auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
   if (!PtrTy || PtrTy->isAggregateType())
     return nullptr;
@@ -290,7 +291,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
   // Try to remove a gep instruction to make the pointer (actually index at this
   // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the
   // pointer, otherwise, we are analyzing the index.
-  llvm::Value *OrigPtr = Ptr;
+  Value *OrigPtr = Ptr;
 
   // The size of the pointer access.
   int64_t PtrAccessSize = 1;
@@ -346,7 +347,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
   if (!U)
     return nullptr;
 
-  llvm::Value *Stride = U->getValue();
+  Value *Stride = U->getValue();
   if (!Lp->isLoopInvariant(Stride))
     return nullptr;
 
@@ -361,7 +362,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
 /// \brief Given a vector and an element number, see if the scalar value is
 /// already around as a register, for example if it were inserted then extracted
 /// from the vector.
-llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
+Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
   assert(V->getType()->isVectorTy() && "Not looking at a vector?");
   VectorType *VTy = cast<VectorType>(V->getType());
   unsigned Width = VTy->getNumElements();
@@ -399,13 +400,37 @@ llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
 
   // Extract a value from a vector add operation with a constant zero.
   Value *Val = nullptr; Constant *Con = nullptr;
-  if (match(V,
-            llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
-                                      llvm::PatternMatch::m_Constant(Con)))) {
-    if (Con->getAggregateElement(EltNo)->isNullValue())
-      return findScalarElement(Val, EltNo);
-  }
+  if (match(V, m_Add(m_Value(Val), m_Constant(Con))))
+    if (Constant *Elt = Con->getAggregateElement(EltNo))
+      if (Elt->isNullValue())
+        return findScalarElement(Val, EltNo);
 
   // Otherwise, we don't know.
   return nullptr;
 }
+
+/// \brief Get splat value if the input is a splat vector or return nullptr.
+/// This function is not fully general. It checks only 2 cases:
+/// the input value is (1) a splat constants vector or (2) a sequence
+/// of instructions that broadcast a single value into a vector.
+///
+llvm::Value *llvm::getSplatValue(Value *V) {
+  if (auto *CV = dyn_cast<ConstantDataVector>(V))
+    return CV->getSplatValue();
+
+  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
+  if (!ShuffleInst)
+    return nullptr;
+  // All-zero (or undef) shuffle mask elements.
+  for (int MaskElt : ShuffleInst->getShuffleMask())
+    if (MaskElt != 0 && MaskElt != -1)
+      return nullptr;
+  // The first shuffle source is 'insertelement' with index 0.
+  auto *InsertEltInst =
+    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
+  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
+      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue())
+    return nullptr;
+
+  return InsertEltInst->getOperand(1);
+}