[LIR] Move all the helpers to be private and re-order the methods in
[oota-llvm.git] / lib / Transforms / Scalar / NaryReassociate.cpp
index 6ac5ff85e32cef855520c39ae66be431dc95eefd..58b9c9d092db6137832754a99cf666e3f34b9595 100644 (file)
 // 1) We only considers n-ary adds for now. This should be extended and
 // generalized.
 //
-// 2) Besides arithmetic operations, similar reassociation can be applied to
-// GEPs. For example, if
-//   X = &arr[a]
-// dominates
-//   Y = &arr[a + b]
-// we may rewrite Y into X + b.
-//
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
@@ -115,6 +112,7 @@ public:
     AU.addPreserved<DominatorTreeWrapperPass>();
     AU.addPreserved<ScalarEvolution>();
     AU.addPreserved<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<ScalarEvolution>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
@@ -163,12 +161,18 @@ private:
   // GEP's pointer size, i.e., whether Index needs to be sign-extended in order
   // to be an index of GEP.
   bool requiresSignExtension(Value *Index, GetElementPtrInst *GEP);
+  // Returns whether V is known to be non-negative at context \c Ctxt.
+  bool isKnownNonNegative(Value *V, Instruction *Ctxt);
+  // Returns whether AO may sign overflow at context \c Ctxt. It computes a
+  // conservative result -- it answers true when not sure.
+  bool maySignOverflow(AddOperator *AO, Instruction *Ctxt);
 
+  AssumptionCache *AC;
+  const DataLayout *DL;
   DominatorTree *DT;
   ScalarEvolution *SE;
   TargetLibraryInfo *TLI;
   TargetTransformInfo *TTI;
-  const DataLayout *DL;
   // A lookup table quickly telling which instructions compute the given SCEV.
   // Note that there can be multiple instructions at different locations
   // computing to the same SCEV, so we map a SCEV to an instruction list.  For
@@ -185,6 +189,7 @@ private:
 char NaryReassociate::ID = 0;
 INITIALIZE_PASS_BEGIN(NaryReassociate, "nary-reassociate", "Nary reassociation",
                       false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
@@ -200,6 +205,7 @@ bool NaryReassociate::runOnFunction(Function &F) {
   if (skipOptnoneFunction(F))
     return false;
 
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   SE = &getAnalysis<ScalarEvolution>();
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
@@ -317,8 +323,10 @@ static bool isGEPFoldable(GetElementPtrInst *GEP,
       BaseOffset += DL->getStructLayout(STy)->getElementOffset(Field);
     }
   }
+
+  unsigned AddrSpace = GEP->getPointerAddressSpace();
   return TTI->isLegalAddressingMode(GEP->getType()->getElementType(), BaseGV,
-                                    BaseOffset, HasBaseReg, Scale);
+                                    BaseOffset, HasBaseReg, Scale, AddrSpace);
 }
 
 Instruction *NaryReassociate::tryReassociateGEP(GetElementPtrInst *GEP) {
@@ -344,18 +352,44 @@ bool NaryReassociate::requiresSignExtension(Value *Index,
   return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
 }
 
+bool NaryReassociate::isKnownNonNegative(Value *V, Instruction *Ctxt) {
+  bool NonNegative, Negative;
+  // TODO: ComputeSignBits is expensive. Consider caching the results.
+  ComputeSignBit(V, NonNegative, Negative, *DL, 0, AC, Ctxt, DT);
+  return NonNegative;
+}
+
+bool NaryReassociate::maySignOverflow(AddOperator *AO, Instruction *Ctxt) {
+  if (AO->hasNoSignedWrap())
+    return false;
+
+  Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
+  // If LHS or RHS has the same sign as the sum, AO doesn't sign overflow.
+  // TODO: handle the negative case as well.
+  if (isKnownNonNegative(AO, Ctxt) &&
+      (isKnownNonNegative(LHS, Ctxt) || isKnownNonNegative(RHS, Ctxt)))
+    return false;
+
+  return true;
+}
+
 GetElementPtrInst *
 NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
                                           Type *IndexedType) {
   Value *IndexToSplit = GEP->getOperand(I + 1);
-  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit))
+  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
     IndexToSplit = SExt->getOperand(0);
+  } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
+    // zext can be treated as sext if the source is non-negative.
+    if (isKnownNonNegative(ZExt->getOperand(0), GEP))
+      IndexToSplit = ZExt->getOperand(0);
+  }
 
   if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
     // If the I-th index needs sext and the underlying add is not equipped with
     // nsw, we cannot split the add because
     //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
-    if (requiresSignExtension(IndexToSplit, GEP) && !AO->hasNoSignedWrap())
+    if (requiresSignExtension(IndexToSplit, GEP) && maySignOverflow(AO, GEP))
       return nullptr;
     Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
     // IndexToSplit = LHS + RHS.
@@ -371,10 +405,9 @@ NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
   return nullptr;
 }
 
-GetElementPtrInst *
-NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
-                                          Value *LHS, Value *RHS,
-                                          Type *IndexedType) {
+GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
+    GetElementPtrInst *GEP, unsigned I, Value *LHS, Value *RHS,
+    Type *IndexedType) {
   // Look for GEP's closest dominator that has the same SCEV as GEP except that
   // the I-th index is replaced with LHS.
   SmallVector<const SCEV *, 4> IndexExprs;
@@ -382,6 +415,16 @@ NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I,
     IndexExprs.push_back(SE->getSCEV(*Index));
   // Replace the I-th index with LHS.
   IndexExprs[I] = SE->getSCEV(LHS);
+  if (isKnownNonNegative(LHS, GEP) &&
+      DL->getTypeSizeInBits(LHS->getType()) <
+          DL->getTypeSizeInBits(GEP->getOperand(I)->getType())) {
+    // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
+    // zext if the source operand is proved non-negative. We should do that
+    // consistently so that CandidateExpr more likely appears before. See
+    // @reassociate_gep_assume for an example of this canonicalization.
+    IndexExprs[I] =
+        SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType());
+  }
   const SCEV *CandidateExpr = SE->getGEPExpr(
       GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()),
       IndexExprs, GEP->isInBounds());
@@ -465,11 +508,6 @@ Instruction *NaryReassociate::tryReassociateAdd(Value *LHS, Value *RHS,
 
 Instruction *NaryReassociate::tryReassociatedAdd(const SCEV *LHSExpr,
                                                  Value *RHS, Instruction *I) {
-  auto Pos = SeenExprs.find(LHSExpr);
-  // Bail out if LHSExpr is not previously seen.
-  if (Pos == SeenExprs.end())
-    return nullptr;
-
   // Look for the closest dominator LHS of I that computes LHSExpr, and replace
   // I with LHS + RHS.
   auto *LHS = findClosestMatchingDominator(LHSExpr, I);