Teach SCEVExpander to expand arithmetic involving pointers into GEP
[oota-llvm.git] / lib / Analysis / ScalarEvolutionExpander.cpp
index fd132746ad121829d423e2973a8c399b66053ed5..36b6206a9bc1902077f2c835dc38bac566f9d06d 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Target/TargetData.h"
 using namespace llvm;
 
 /// InsertCastOfTo - Insert a cast of V to the specified type, doing what
@@ -130,10 +131,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS,
     BasicBlock::iterator IP = InsertPt;
     --IP;
     for (; ScanLimit; --IP, --ScanLimit) {
-      if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(IP))
-        if (BinOp->getOpcode() == Opcode && BinOp->getOperand(0) == LHS &&
-            BinOp->getOperand(1) == RHS)
-          return BinOp;
+      if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS &&
+          IP->getOperand(1) == RHS)
+        return IP;
       if (IP == BlockBegin) break;
     }
   }
@@ -144,9 +144,156 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS,
   return BO;
 }
 
+/// expandAddToGEP - Expand a SCEVAddExpr with a pointer type into a GEP
+/// instead of using ptrtoint+arithmetic+inttoptr.
+Value *SCEVExpander::expandAddToGEP(const SCEVAddExpr *S,
+                                    const PointerType *PTy,
+                                    const Type *Ty,
+                                    Value *V) {
+  const Type *ElTy = PTy->getElementType();
+  SmallVector<Value *, 4> GepIndices;
+  std::vector<SCEVHandle> Ops = S->getOperands();
+  bool AnyNonZeroIndices = false;
+  Ops.pop_back();
+
+  // Decend down the pointer's type and attempt to convert the other
+  // operands into GEP indices, at each level. The first index in a GEP
+  // indexes into the array implied by the pointer operand; the rest of
+  // the indices index into the element or field type selected by the
+  // preceding index.
+  for (;;) {
+    APInt ElSize = APInt(SE.getTypeSizeInBits(Ty),
+                         ElTy->isSized() ?  SE.TD->getTypeAllocSize(ElTy) : 0);
+    std::vector<SCEVHandle> NewOps;
+    std::vector<SCEVHandle> ScaledOps;
+    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+      if (ElSize != 0) {
+        if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i]))
+          if (!C->getValue()->getValue().srem(ElSize)) {
+            ConstantInt *CI =
+              ConstantInt::get(C->getValue()->getValue().sdiv(ElSize));
+            SCEVHandle Div = SE.getConstant(CI);
+            ScaledOps.push_back(Div);
+            continue;
+          }
+        if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i]))
+          if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0)))
+            if (C->getValue()->getValue() == ElSize) {
+              for (unsigned j = 1, f = M->getNumOperands(); j != f; ++j)
+                ScaledOps.push_back(M->getOperand(j));
+              continue;
+            }
+        if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Ops[i]))
+          if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getValue()))
+            if (BO->getOpcode() == Instruction::Mul)
+              if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1)))
+                if (CI->getValue() == ElSize) {
+                  ScaledOps.push_back(SE.getUnknown(BO->getOperand(0)));
+                  continue;
+                }
+        if (ElSize == 1) {
+          ScaledOps.push_back(Ops[i]);
+          continue;
+        }
+      }
+      NewOps.push_back(Ops[i]);
+    }
+    Ops = NewOps;
+    AnyNonZeroIndices |= !ScaledOps.empty();
+    Value *Scaled = ScaledOps.empty() ?
+                    Constant::getNullValue(Ty) :
+                    expandCodeFor(SE.getAddExpr(ScaledOps), Ty);
+    GepIndices.push_back(Scaled);
+
+    // Collect struct field index operands.
+    if (!Ops.empty())
+      while (const StructType *STy = dyn_cast<StructType>(ElTy)) {
+        if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0]))
+          if (SE.getTypeSizeInBits(C->getType()) <= 64) {
+            const StructLayout &SL = *SE.TD->getStructLayout(STy);
+            uint64_t FullOffset = C->getValue()->getZExtValue();
+            if (FullOffset < SL.getSizeInBytes()) {
+              unsigned ElIdx = SL.getElementContainingOffset(FullOffset);
+              GepIndices.push_back(ConstantInt::get(Type::Int32Ty, ElIdx));
+              ElTy = STy->getTypeAtIndex(ElIdx);
+              Ops[0] =
+                SE.getConstant(ConstantInt::get(Ty,
+                                                FullOffset -
+                                                  SL.getElementOffset(ElIdx)));
+              AnyNonZeroIndices = true;
+              continue;
+            }
+          }
+        break;
+      }
+
+    if (const ArrayType *ATy = dyn_cast<ArrayType>(ElTy)) {
+      ElTy = ATy->getElementType();
+      continue;
+    }
+    break;
+  }
+
+  // If none of the operands were convertable to proper GEP indices, cast
+  // the base to i8* and do an ugly getelementptr with that. It's still
+  // better than ptrtoint+arithmetic+inttoptr at least.
+  if (!AnyNonZeroIndices) {
+    V = InsertNoopCastOfTo(V,
+                           Type::Int8Ty->getPointerTo(PTy->getAddressSpace()));
+    Value *Idx = expand(SE.getAddExpr(Ops));
+    Idx = InsertNoopCastOfTo(Idx, Ty);
+
+    // Fold a GEP with constant operands.
+    if (Constant *CLHS = dyn_cast<Constant>(V))
+      if (Constant *CRHS = dyn_cast<Constant>(Idx))
+        return ConstantExpr::get(Instruction::GetElementPtr, CLHS, CRHS);
+
+    // Do a quick scan to see if we have this GEP nearby.  If so, reuse it.
+    unsigned ScanLimit = 6;
+    BasicBlock::iterator BlockBegin = InsertPt->getParent()->begin();
+    if (InsertPt != BlockBegin) {
+      // Scanning starts from the last instruction before InsertPt.
+      BasicBlock::iterator IP = InsertPt;
+      --IP;
+      for (; ScanLimit; --IP, --ScanLimit) {
+        if (IP->getOpcode() == Instruction::GetElementPtr &&
+            IP->getOperand(0) == V && IP->getOperand(1) == Idx)
+          return IP;
+        if (IP == BlockBegin) break;
+      }
+    }
+
+    Value *GEP = GetElementPtrInst::Create(V, Idx, "scevgep", InsertPt);
+    InsertedValues.insert(GEP);
+    return GEP;
+  }
+
+  // Insert a pretty getelementptr.
+  Value *GEP = GetElementPtrInst::Create(V,
+                                         GepIndices.begin(),
+                                         GepIndices.end(),
+                                         "scevgep", InsertPt);
+  Ops.push_back(SE.getUnknown(GEP));
+  InsertedValues.insert(GEP);
+  return expand(SE.getAddExpr(Ops));
+}
+
 Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   const Type *Ty = SE.getEffectiveSCEVType(S->getType());
   Value *V = expand(S->getOperand(S->getNumOperands()-1));
+
+  // Turn things like ptrtoint+arithmetic+inttoptr into GEP. This helps
+  // BasicAliasAnalysis analyze the result. However, it suffers from the
+  // underlying bug described in PR2831. Addition in LLVM currently always
+  // has two's complement wrapping guaranteed. However, the semantics for
+  // getelementptr overflow are ambiguous. In the common case though, this
+  // expansion gets used when a GEP in the original code has been converted
+  // into integer arithmetic, in which case the resulting code will be no
+  // more undefined than it was originally.
+  if (SE.TD)
+    if (const PointerType *PTy = dyn_cast<PointerType>(V->getType()))
+      return expandAddToGEP(S, PTy, Ty, V);
+
   V = InsertNoopCastOfTo(V, Ty);
 
   // Emit a bunch of add instructions
@@ -157,7 +304,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   }
   return V;
 }
-    
+
 Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
   const Type *Ty = SE.getEffectiveSCEVType(S->getType());
   int FirstOp = 0;  // Set if we should emit a subtract.
@@ -206,15 +353,10 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
 
   // {X,+,F} --> X + {0,+,F}
   if (!S->getStart()->isZero()) {
-    Value *Start = expand(S->getStart());
-    Start = InsertNoopCastOfTo(Start, Ty);
-    std::vector<SCEVHandle> NewOps(S->op_begin(), S->op_end());
+    std::vector<SCEVHandle> NewOps(S->getOperands());
     NewOps[0] = SE.getIntegerSCEV(0, Ty);
     Value *Rest = expand(SE.getAddRecExpr(NewOps, L));
-    Rest = InsertNoopCastOfTo(Rest, Ty);
-
-    // FIXME: look for an existing add to use.
-    return InsertBinop(Instruction::Add, Rest, Start, InsertPt);
+    return expand(SE.getAddExpr(S->getStart(), SE.getUnknown(Rest)));
   }
 
   // {0,+,1} --> Insert a canonical induction variable into the loop!
@@ -265,7 +407,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
     // point loop.  If we can, move the multiply to the outer most loop that it
     // is safe to be in.
     BasicBlock::iterator MulInsertPt = getInsertionPoint();
-    Loop *InsertPtLoop = LI.getLoopFor(MulInsertPt->getParent());
+    Loop *InsertPtLoop = SE.LI->getLoopFor(MulInsertPt->getParent());
     if (InsertPtLoop != L && InsertPtLoop &&
         L->contains(InsertPtLoop->getHeader())) {
       do {
@@ -363,10 +505,13 @@ Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
 
 Value *SCEVExpander::expandCodeFor(SCEVHandle SH, const Type *Ty) {
   // Expand the code for this SCEV.
-  assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) &&
-         "non-trivial casts should be done with the SCEVs directly!");
   Value *V = expand(SH);
-  return InsertNoopCastOfTo(V, Ty);
+  if (Ty) {
+    assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) &&
+           "non-trivial casts should be done with the SCEVs directly!");
+    V = InsertNoopCastOfTo(V, Ty);
+  }
+  return V;
 }
 
 Value *SCEVExpander::expand(const SCEV *S) {