[Bitcode][Asm] Teach LLVM to read and write operand bundles.
[oota-llvm.git] / lib / Analysis / VectorUtils.cpp
index eab5887a17ea6720f7bec28403db8264455020d4..93720857662f988a8e0a046cd10cdfeb7afbe165 100644 (file)
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
+#include "llvm/IR/Constants.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
 
 /// \brief Identify if the intrinsic is trivially vectorizable.
 /// This method returns true if the intrinsic's argument types are all
@@ -79,7 +83,7 @@ bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
 /// d) call should only reads memory.
 /// If all these condition is met then return ValidIntrinsicID
 /// else return not_intrinsic.
-llvm::Intrinsic::ID
+Intrinsic::ID
 llvm::checkUnaryFloatSignature(const CallInst &I,
                                Intrinsic::ID ValidIntrinsicID) {
   if (I.getNumArgOperands() != 1 ||
@@ -98,7 +102,7 @@ llvm::checkUnaryFloatSignature(const CallInst &I,
 /// d) call should only reads memory.
 /// If all these condition is met then return ValidIntrinsicID
 /// else return not_intrinsic.
-llvm::Intrinsic::ID
+Intrinsic::ID
 llvm::checkBinaryFloatSignature(const CallInst &I,
                                 Intrinsic::ID ValidIntrinsicID) {
   if (I.getNumArgOperands() != 2 ||
@@ -114,8 +118,8 @@ llvm::checkBinaryFloatSignature(const CallInst &I,
 /// \brief Returns intrinsic ID for call.
 /// For the input call instruction it finds mapping intrinsic and returns
 /// its ID, in case it does not found it return not_intrinsic.
-llvm::Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
-                                                const TargetLibraryInfo *TLI) {
+Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
+                                          const TargetLibraryInfo *TLI) {
   // If we have an intrinsic call, check if it is trivially vectorizable.
   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
     Intrinsic::ID ID = II->getIntrinsicID();
@@ -228,8 +232,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
       cast<PointerType>(Gep->getType()->getScalarType())->getElementType());
 
   // Walk backwards and try to peel off zeros.
-  while (LastOperand > 1 &&
-         match(Gep->getOperand(LastOperand), llvm::PatternMatch::m_Zero())) {
+  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
     // Find the type we're currently indexing into.
     gep_type_iterator GEPTI = gep_type_begin(Gep);
     std::advance(GEPTI, LastOperand - 1);
@@ -247,8 +250,7 @@ unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
 /// \brief If the argument is a GEP, then returns the operand identified by
 /// getGEPInductionOperand. However, if there is some other non-loop-invariant
 /// operand, it returns that instead.
-llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE,
-                                      Loop *Lp) {
+Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
   if (!GEP)
     return Ptr;
@@ -265,8 +267,8 @@ llvm::Value *llvm::stripGetElementPtr(llvm::Value *Ptr, ScalarEvolution *SE,
 }
 
 /// \brief If a value has only one user that is a CastInst, return it.
-llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) {
-  llvm::Value *UniqueCast = nullptr;
+Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+  Value *UniqueCast = nullptr;
   for (User *U : Ptr->users()) {
     CastInst *CI = dyn_cast<CastInst>(U);
     if (CI && CI->getType() == Ty) {
@@ -281,16 +283,15 @@ llvm::Value *llvm::getUniqueCastUse(llvm::Value *Ptr, Loop *Lp, Type *Ty) {
 
 /// \brief Get the stride of a pointer access in a loop. Looks for symbolic
 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
-llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
-                                        Loop *Lp) {
-  const PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
   if (!PtrTy || PtrTy->isAggregateType())
     return nullptr;
 
   // Try to remove a gep instruction to make the pointer (actually index at this
   // point) easier analyzable. If OrigPtr is equal to Ptr we are analzying the
   // pointer, otherwise, we are analyzing the index.
-  llvm::Value *OrigPtr = Ptr;
+  Value *OrigPtr = Ptr;
 
   // The size of the pointer access.
   int64_t PtrAccessSize = 1;
@@ -346,7 +347,7 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
   if (!U)
     return nullptr;
 
-  llvm::Value *Stride = U->getValue();
+  Value *Stride = U->getValue();
   if (!Lp->isLoopInvariant(Stride))
     return nullptr;
 
@@ -357,3 +358,79 @@ llvm::Value *llvm::getStrideFromPointer(llvm::Value *Ptr, ScalarEvolution *SE,
 
   return Stride;
 }
+
+/// \brief Given a vector and an element number, see if the scalar value is
+/// already around as a register, for example if it were inserted then extracted
+/// from the vector.
+Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
+  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
+  VectorType *VTy = cast<VectorType>(V->getType());
+  unsigned Width = VTy->getNumElements();
+  if (EltNo >= Width)  // Out of range access.
+    return UndefValue::get(VTy->getElementType());
+
+  if (Constant *C = dyn_cast<Constant>(V))
+    return C->getAggregateElement(EltNo);
+
+  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
+    // If this is an insert to a variable element, we don't know what it is.
+    if (!isa<ConstantInt>(III->getOperand(2)))
+      return nullptr;
+    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
+
+    // If this is an insert to the element we are looking for, return the
+    // inserted value.
+    if (EltNo == IIElt)
+      return III->getOperand(1);
+
+    // Otherwise, the insertelement doesn't modify the value, recurse on its
+    // vector input.
+    return findScalarElement(III->getOperand(0), EltNo);
+  }
+
+  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
+    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
+    int InEl = SVI->getMaskValue(EltNo);
+    if (InEl < 0)
+      return UndefValue::get(VTy->getElementType());
+    if (InEl < (int)LHSWidth)
+      return findScalarElement(SVI->getOperand(0), InEl);
+    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth);
+  }
+
+  // Extract a value from a vector add operation with a constant zero.
+  Value *Val = nullptr; Constant *Con = nullptr;
+  if (match(V, m_Add(m_Value(Val), m_Constant(Con))))
+    if (Constant *Elt = Con->getAggregateElement(EltNo))
+      if (Elt->isNullValue())
+        return findScalarElement(Val, EltNo);
+
+  // Otherwise, we don't know.
+  return nullptr;
+}
+
+/// \brief Get splat value if the input is a splat vector or return nullptr.
+/// This function is not fully general. It checks only 2 cases:
+/// the input value is (1) a splat constants vector or (2) a sequence
+/// of instructions that broadcast a single value into a vector.
+///
+llvm::Value *llvm::getSplatValue(Value *V) {
+  if (auto *CV = dyn_cast<ConstantDataVector>(V))
+    return CV->getSplatValue();
+
+  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V);
+  if (!ShuffleInst)
+    return nullptr;
+  // All-zero (or undef) shuffle mask elements.
+  for (int MaskElt : ShuffleInst->getShuffleMask())
+    if (MaskElt != 0 && MaskElt != -1)
+      return nullptr;
+  // The first shuffle source is 'insertelement' with index 0.
+  auto *InsertEltInst =
+    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
+  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) ||
+      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isNullValue())
+    return nullptr;
+
+  return InsertEltInst->getOperand(1);
+}