Reformat.
[oota-llvm.git] / lib / Transforms / Scalar / StraightLineStrengthReduce.cpp
index 2fc93681e9835ddb3823d6c0dbcb2b0d6085a1ad..453503ab61da66a433d472a5475b087bb53a161c 100644 (file)
@@ -61,6 +61,7 @@
 #include "llvm/ADT/FoldingSet.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
@@ -68,6 +69,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -403,20 +405,37 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
   }
 }
 
+// Returns true if A matches B + C where C is constant.
+static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) {
+  return (match(A, m_Add(m_Value(B), m_ConstantInt(C))) ||
+          match(A, m_Add(m_ConstantInt(C), m_Value(B))));
+}
+
+// Returns true if A matches B | C where C is constant.
+static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) {
+  return (match(A, m_Or(m_Value(B), m_ConstantInt(C))) ||
+          match(A, m_Or(m_ConstantInt(C), m_Value(B))));
+}
+
 void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul(
     Value *LHS, Value *RHS, Instruction *I) {
   Value *B = nullptr;
   ConstantInt *Idx = nullptr;
-  // Only handle the canonical operand ordering.
-  if (match(LHS, m_Add(m_Value(B), m_ConstantInt(Idx)))) {
+  if (matchesAdd(LHS, B, Idx)) {
     // If LHS is in the form of "Base + Index", then I is in the form of
     // "(Base + Index) * RHS".
     allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
+  } else if (matchesOr(LHS, B, Idx) && haveNoCommonBitsSet(B, Idx, *DL)) {
+    // If LHS is in the form of "Base | Index" and Base and Index have no common
+    // bits set, then
+    //   Base | Index = Base + Index
+    // and I is thus in the form of "(Base + Index) * RHS".
+    allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I);
   } else {
     // Otherwise, at least try the form (LHS + 0) * RHS.
     ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0);
     allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS,
-                                  I);
+                                   I);
   }
 }
 
@@ -490,31 +509,34 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP(
   if (GEP->getType()->isVectorTy())
     return;
 
-  const SCEV *GEPExpr = SE->getSCEV(GEP);
-  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
+  SmallVector<const SCEV *, 4> IndexExprs;
+  for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I)
+    IndexExprs.push_back(SE->getSCEV(*I));
 
   gep_type_iterator GTI = gep_type_begin(GEP);
-  for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) {
+  for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) {
     if (!isa<SequentialType>(*GTI++))
       continue;
-    Value *ArrayIdx = *I;
-    // Compute the byte offset of this index.
+
+    const SCEV *OrigIndexExpr = IndexExprs[I - 1];
+    IndexExprs[I - 1] = SE->getConstant(OrigIndexExpr->getType(), 0);
+
+    // The base of this candidate is GEP's base plus the offsets of all
+    // indices except this current one.
+    const SCEV *BaseExpr = SE->getGEPExpr(GEP->getSourceElementType(),
+                                          SE->getSCEV(GEP->getPointerOperand()),
+                                          IndexExprs, GEP->isInBounds());
+    Value *ArrayIdx = GEP->getOperand(I);
     uint64_t ElementSize = DL->getTypeAllocSize(*GTI);
-    const SCEV *ElementSizeExpr = SE->getSizeOfExpr(IntPtrTy, *GTI);
-    const SCEV *ArrayIdxExpr = SE->getSCEV(ArrayIdx);
-    ArrayIdxExpr = SE->getTruncateOrSignExtend(ArrayIdxExpr, IntPtrTy);
-    const SCEV *LocalOffset =
-        SE->getMulExpr(ArrayIdxExpr, ElementSizeExpr, SCEV::FlagNSW);
-    // The base of this candidate equals GEPExpr less the byte offset of this
-    // index.
-    const SCEV *Base = SE->getMinusSCEV(GEPExpr, LocalOffset);
-    factorArrayIndex(ArrayIdx, Base, ElementSize, GEP);
+    factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP);
     // When ArrayIdx is the sext of a value, we try to factor that value as
     // well.  Handling this case is important because array indices are
     // typically sign-extended to the pointer size.
     Value *TruncatedArrayIdx = nullptr;
     if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))))
-      factorArrayIndex(TruncatedArrayIdx, Base, ElementSize, GEP);
+      factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP);
+
+    IndexExprs[I - 1] = OrigIndexExpr;
   }
 }
 
@@ -599,9 +621,14 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis(
   switch (C.CandidateKind) {
   case Candidate::Add:
   case Candidate::Mul:
+    // C = Basis + Bump
     if (BinaryOperator::isNeg(Bump)) {
+      // If Bump is a neg instruction, emit C = Basis - (-Bump).
       Reduced =
           Builder.CreateSub(Basis.Ins, BinaryOperator::getNegArgument(Bump));
+      // We only use the negative argument of Bump, and Bump itself may be
+      // trivially dead.
+      RecursivelyDeleteTriviallyDeadInstructions(Bump);
     } else {
       Reduced = Builder.CreateAdd(Basis.Ins, Bump);
     }
@@ -637,7 +664,6 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis(
   };
   Reduced->takeName(C.Ins);
   C.Ins->replaceAllUsesWith(Reduced);
-  C.Ins->dropAllReferences();
   // Unlink C.Ins so that we can skip other candidates also corresponding to
   // C.Ins. The actual deletion is postponed to the end of runOnFunction.
   C.Ins->removeFromParent();
@@ -670,8 +696,13 @@ bool StraightLineStrengthReduce::runOnFunction(Function &F) {
   }
 
   // Delete all unlink instructions.
-  for (auto I : UnlinkedInstructions) {
-    delete I;
+  for (auto *UnlinkedInst : UnlinkedInstructions) {
+    for (unsigned I = 0, E = UnlinkedInst->getNumOperands(); I != E; ++I) {
+      Value *Op = UnlinkedInst->getOperand(I);
+      UnlinkedInst->setOperand(I, nullptr);
+      RecursivelyDeleteTriviallyDeadInstructions(Op);
+    }
+    delete UnlinkedInst;
   }
   bool Ret = !UnlinkedInstructions.empty();
   UnlinkedInstructions.clear();