Add a SCEV class and supporting code for sign-extend expressions.
authorDan Gohman <gohman@apple.com>
Fri, 15 Jun 2007 14:38:12 +0000 (14:38 +0000)
committerDan Gohman <gohman@apple.com>
Fri, 15 Jun 2007 14:38:12 +0000 (14:38 +0000)
This created an ambiguity for expandInTy to decide when to use
sign-extension or zero-extension, but it turns out that most of its callers
don't actually need a type conversion, now that LLVM types don't have
explicit signedness. Drop expandInTy in favor of plain expand, and change
the few places that actually need a type conversion to do it themselves.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@37591 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolutionExpander.h
include/llvm/Analysis/ScalarEvolutionExpressions.h
lib/Analysis/ScalarEvolution.cpp
lib/Analysis/ScalarEvolutionExpander.cpp
lib/Transforms/Scalar/IndVarSimplify.cpp
lib/Transforms/Scalar/LoopStrengthReduce.cpp

index 44e8fb0a9a5b5300b3f152fdb0650edcaa019a06..a5cc7138cacba9bd53dfc722a513e9d3ea0a1a0e 100644 (file)
@@ -78,13 +78,10 @@ namespace llvm {
     /// expandCodeFor - Insert code to directly compute the specified SCEV
     /// expression into the program.  The inserted code is inserted into the
     /// specified block.
-    ///
-    /// If a particular value sign is required, a type may be specified for the
-    /// result.
-    Value *expandCodeFor(SCEVHandle SH, Instruction *IP, const Type *Ty = 0) {
+    Value *expandCodeFor(SCEVHandle SH, Instruction *IP) {
       // Expand the code for this SCEV.
       this->InsertPt = IP;
-      return expandInTy(SH, Ty);
+      return expand(SH);
     }
 
     /// InsertCastOfTo - Insert a cast of V to the specified type, doing what
@@ -107,25 +104,6 @@ namespace llvm {
       return V;
     }
 
-    Value *expandInTy(SCEV *S, const Type *Ty) {
-      Value *V = expand(S);
-      if (Ty && V->getType() != Ty) {
-        if (isa<PointerType>(Ty) && V->getType()->isInteger())
-          return InsertCastOfTo(Instruction::IntToPtr, V, Ty);
-        else if (Ty->isInteger() && isa<PointerType>(V->getType()))
-          return InsertCastOfTo(Instruction::PtrToInt, V, Ty);
-        else if (Ty->getPrimitiveSizeInBits() == 
-                 V->getType()->getPrimitiveSizeInBits())
-          return InsertCastOfTo(Instruction::BitCast, V, Ty);
-        else if (Ty->getPrimitiveSizeInBits() > 
-                 V->getType()->getPrimitiveSizeInBits())
-          return InsertCastOfTo(Instruction::ZExt, V, Ty);
-        else
-          return InsertCastOfTo(Instruction::Trunc, V, Ty);
-      }
-      return V;
-    }
-
     Value *visitConstant(SCEVConstant *S) {
       return S->getValue();
     }
@@ -136,17 +114,21 @@ namespace llvm {
     }
 
     Value *visitZeroExtendExpr(SCEVZeroExtendExpr *S) {
-      Value *V = expandInTy(S->getOperand(), S->getType());
+      Value *V = expand(S->getOperand());
       return CastInst::createZExtOrBitCast(V, S->getType(), "tmp.", InsertPt);
     }
 
+    Value *visitSignExtendExpr(SCEVSignExtendExpr *S) {
+      Value *V = expand(S->getOperand());
+      return CastInst::createSExtOrBitCast(V, S->getType(), "tmp.", InsertPt);
+    }
+
     Value *visitAddExpr(SCEVAddExpr *S) {
-      const Type *Ty = S->getType();
-      Value *V = expandInTy(S->getOperand(S->getNumOperands()-1), Ty);
+      Value *V = expand(S->getOperand(S->getNumOperands()-1));
 
       // Emit a bunch of add instructions
       for (int i = S->getNumOperands()-2; i >= 0; --i)
-        V = InsertBinop(Instruction::Add, V, expandInTy(S->getOperand(i), Ty),
+        V = InsertBinop(Instruction::Add, V, expand(S->getOperand(i)),
                         InsertPt);
       return V;
     }
@@ -154,9 +136,8 @@ namespace llvm {
     Value *visitMulExpr(SCEVMulExpr *S);
 
     Value *visitSDivExpr(SCEVSDivExpr *S) {
-      const Type *Ty = S->getType();
-      Value *LHS = expandInTy(S->getLHS(), Ty);
-      Value *RHS = expandInTy(S->getRHS(), Ty);
+      Value *LHS = expand(S->getLHS());
+      Value *RHS = expand(S->getRHS());
       return InsertBinop(Instruction::SDiv, LHS, RHS, InsertPt);
     }
 
index af795377c2b95e39d1557042d829d36ae8b53817..dd6871fdd1833a5dd0414f84fc2b5ac09c96df9a 100644 (file)
@@ -24,8 +24,8 @@ namespace llvm {
   enum SCEVTypes {
     // These should be ordered in terms of increasing complexity to make the
     // folders simpler.
-    scConstant, scTruncate, scZeroExtend, scAddExpr, scMulExpr, scSDivExpr,
-    scAddRecExpr, scUnknown, scCouldNotCompute
+    scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
+    scSDivExpr, scAddRecExpr, scUnknown, scCouldNotCompute
   };
 
   //===--------------------------------------------------------------------===//
@@ -166,6 +166,53 @@ namespace llvm {
     }
   };
 
+  //===--------------------------------------------------------------------===//
+  /// SCEVSignExtendExpr - This class represents a sign extension of a small
+  /// integer value to a larger integer value.
+  ///
+  class SCEVSignExtendExpr : public SCEV {
+    SCEVHandle Op;
+    const Type *Ty;
+    SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty);
+    virtual ~SCEVSignExtendExpr();
+  public:
+    /// get method - This just gets and returns a new SCEVSignExtend object
+    ///
+    static SCEVHandle get(const SCEVHandle &Op, const Type *Ty);
+
+    const SCEVHandle &getOperand() const { return Op; }
+    virtual const Type *getType() const { return Ty; }
+
+    virtual bool isLoopInvariant(const Loop *L) const {
+      return Op->isLoopInvariant(L);
+    }
+
+    virtual bool hasComputableLoopEvolution(const Loop *L) const {
+      return Op->hasComputableLoopEvolution(L);
+    }
+
+    /// getValueRange - Return the tightest constant bounds that this value is
+    /// known to have.  This method is only valid on integer SCEV objects.
+    virtual ConstantRange getValueRange() const;
+
+    SCEVHandle replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
+                                                 const SCEVHandle &Conc) const {
+      SCEVHandle H = Op->replaceSymbolicValuesWithConcrete(Sym, Conc);
+      if (H == Op)
+        return this;
+      return get(H, Ty);
+    }
+
+    virtual void print(std::ostream &OS) const;
+    void print(std::ostream *OS) const { if (OS) print(*OS); }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVSignExtendExpr *S) { return true; }
+    static inline bool classof(const SCEV *S) {
+      return S->getSCEVType() == scSignExtend;
+    }
+  };
+
 
   //===--------------------------------------------------------------------===//
   /// SCEVCommutativeExpr - This node is the base class for n'ary commutative
@@ -503,6 +550,8 @@ namespace llvm {
         return ((SC*)this)->visitTruncateExpr((SCEVTruncateExpr*)S);
       case scZeroExtend:
         return ((SC*)this)->visitZeroExtendExpr((SCEVZeroExtendExpr*)S);
+      case scSignExtend:
+        return ((SC*)this)->visitSignExtendExpr((SCEVSignExtendExpr*)S);
       case scAddExpr:
         return ((SC*)this)->visitAddExpr((SCEVAddExpr*)S);
       case scMulExpr:
index bf67fd3fffca2db3ec2e08ef34051af0a3487d07..3ae65286fa731547faf1862e7b637e21d4a3e6d8 100644 (file)
@@ -245,6 +245,32 @@ void SCEVZeroExtendExpr::print(std::ostream &OS) const {
   OS << "(zeroextend " << *Op << " to " << *Ty << ")";
 }
 
+// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
+// particular input.  Don't use a SCEVHandle here, or else the object will never
+// be deleted!
+static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
+                     SCEVSignExtendExpr*> > SCEVSignExtends;
+
+SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
+  : SCEV(scSignExtend), Op(op), Ty(ty) {
+  assert(Op->getType()->isInteger() && Ty->isInteger() &&
+         "Cannot sign extend non-integer value!");
+  assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
+         && "This is not an extending conversion!");
+}
+
+SCEVSignExtendExpr::~SCEVSignExtendExpr() {
+  SCEVSignExtends->erase(std::make_pair(Op, Ty));
+}
+
+ConstantRange SCEVSignExtendExpr::getValueRange() const {
+  return getOperand()->getValueRange().signExtend(getBitWidth());
+}
+
+void SCEVSignExtendExpr::print(std::ostream &OS) const {
+  OS << "(signextend " << *Op << " to " << *Ty << ")";
+}
+
 // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
 // particular input.  Don't use a SCEVHandle here, or else the object will never
 // be deleted!
@@ -588,6 +614,21 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
   return Result;
 }
 
+SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
+  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+    return SCEVUnknown::get(
+        ConstantExpr::getSExt(SC->getValue(), Ty));
+
+  // FIXME: If the input value is a chrec scev, and we can prove that the value
+  // did not overflow the old, smaller, value, we can sign extend all of the
+  // operands (often constants).  This would allow analysis of something like
+  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
+
+  SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
+  if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
+  return Result;
+}
+
 // get - Get a canonical add expression, or something simpler if possible.
 SCEVHandle SCEVAddExpr::get(std::vector<SCEVHandle> &Ops) {
   assert(!Ops.empty() && "Cannot get empty add!");
@@ -1370,6 +1411,9 @@ static APInt GetConstantFactor(SCEVHandle S) {
   if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
     return GetConstantFactor(E->getOperand()).zext(
                                cast<IntegerType>(E->getType())->getBitWidth());
+  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
+    return GetConstantFactor(E->getOperand()).sext(
+                               cast<IntegerType>(E->getType())->getBitWidth());
   
   if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
     // The result is the min of all operands.
@@ -1470,6 +1514,9 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
     case Instruction::ZExt:
       return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
 
+    case Instruction::SExt:
+      return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
+
     case Instruction::BitCast:
       // BitCasts are no-op casts so we just eliminate the cast.
       if (I->getType()->isInteger() &&
index c88c78119540970ae8b080c2f52ac6c374fee1a0..c8c683cb3f62e7fe68c64791d962af75b045ac9b 100644 (file)
@@ -93,18 +93,17 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS,
 }
 
 Value *SCEVExpander::visitMulExpr(SCEVMulExpr *S) {
-  const Type *Ty = S->getType();
   int FirstOp = 0;  // Set if we should emit a subtract.
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(0)))
     if (SC->getValue()->isAllOnesValue())
       FirstOp = 1;
 
   int i = S->getNumOperands()-2;
-  Value *V = expandInTy(S->getOperand(i+1), Ty);
+  Value *V = expand(S->getOperand(i+1));
 
   // Emit a bunch of multiply instructions
   for (; i >= FirstOp; --i)
-    V = InsertBinop(Instruction::Mul, V, expandInTy(S->getOperand(i), Ty),
+    V = InsertBinop(Instruction::Mul, V, expand(S->getOperand(i)),
                     InsertPt);
   // -1 * ...  --->  0 - ...
   if (FirstOp == 1)
@@ -122,10 +121,10 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) {
   // {X,+,F} --> X + {0,+,F}
   if (!isa<SCEVConstant>(S->getStart()) ||
       !cast<SCEVConstant>(S->getStart())->getValue()->isZero()) {
-    Value *Start = expandInTy(S->getStart(), Ty);
+    Value *Start = expand(S->getStart());
     std::vector<SCEVHandle> NewOps(S->op_begin(), S->op_end());
     NewOps[0] = SCEVUnknown::getIntegerSCEV(0, Ty);
-    Value *Rest = expandInTy(SCEVAddRecExpr::get(NewOps, L), Ty);
+    Value *Rest = expand(SCEVAddRecExpr::get(NewOps, L));
 
     // FIXME: look for an existing add to use.
     return InsertBinop(Instruction::Add, Rest, Start, InsertPt);
@@ -164,7 +163,7 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) {
 
   // If this is a simple linear addrec, emit it now as a special case.
   if (S->getNumOperands() == 2) {   // {0,+,F} --> i*F
-    Value *F = expandInTy(S->getOperand(1), Ty);
+    Value *F = expand(S->getOperand(1));
     
     // IF the step is by one, just return the inserted IV.
     if (ConstantInt *CI = dyn_cast<ConstantInt>(F))
@@ -201,5 +200,5 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) {
   SCEVHandle V = S->evaluateAtIteration(IH);
   //cerr << "Evaluated: " << *this << "\n     to: " << *V << "\n";
 
-  return expandInTy(V, Ty);
+  return expand(V);
 }
index 8042d62581bfc1151e3847e5a2e6254c38bae343..5965d1a88559b70866328815a1d3826b136396ec 100644 (file)
@@ -277,8 +277,7 @@ Instruction *IndVarSimplify::LinearFunctionTestReplace(Loop *L,
 
   // Expand the code for the iteration count into the preheader of the loop.
   BasicBlock *Preheader = L->getLoopPreheader();
-  Value *ExitCnt = RW.expandCodeFor(TripCount, Preheader->getTerminator(),
-                                    IndVar->getType());
+  Value *ExitCnt = RW.expandCodeFor(TripCount, Preheader->getTerminator());
 
   // Insert a new icmp_ne or icmp_eq instruction before the branch.
   ICmpInst::Predicate Opcode;
@@ -383,7 +382,7 @@ void IndVarSimplify::RewriteLoopExitValues(Loop *L) {
         // just reuse it.
         Value *&ExitVal = ExitValues[Inst];
         if (!ExitVal)
-          ExitVal = Rewriter.expandCodeFor(ExitValue, InsertPt,Inst->getType());
+          ExitVal = Rewriter.expandCodeFor(ExitValue, InsertPt);
         
         DOUT << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal
              << "  LoopVal = " << *Inst << "\n";
@@ -519,9 +518,12 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
   Changed = true;
   DOUT << "INDVARS: New CanIV: " << *IndVar;
 
-  if (!isa<SCEVCouldNotCompute>(IterationCount))
+  if (!isa<SCEVCouldNotCompute>(IterationCount)) {
+    if (IterationCount->getType() != LargestType)
+      IterationCount = SCEVZeroExtendExpr::get(IterationCount, LargestType);
     if (Instruction *DI = LinearFunctionTestReplace(L, IterationCount,Rewriter))
       DeadInsts.insert(DI);
+  }
 
   // Now that we have a canonical induction variable, we can rewrite any
   // recurrences in terms of the induction variable.  Start with the auxillary
@@ -555,8 +557,7 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
   std::map<unsigned, Value*> InsertedSizes;
   while (!IndVars.empty()) {
     PHINode *PN = IndVars.back().first;
-    Value *NewVal = Rewriter.expandCodeFor(IndVars.back().second, InsertPt,
-                                           PN->getType());
+    Value *NewVal = Rewriter.expandCodeFor(IndVars.back().second, InsertPt);
     DOUT << "INDVARS: Rewrote IV '" << *IndVars.back().second << "' " << *PN
          << "   into = " << *NewVal << "\n";
     NewVal->takeName(PN);
index bd7d1d9249325e7b1d4c18e9f77e301f66cb6203..0c4807d31aa6163ca0d16a161a8f6e8c047ca495 100644 (file)
@@ -555,8 +555,7 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase,
   // If there is no immediate value, skip the next part.
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Imm))
     if (SC->getValue()->isZero())
-      return Rewriter.expandCodeFor(NewBase, BaseInsertPt,
-                                    OperandValToReplace->getType());
+      return Rewriter.expandCodeFor(NewBase, BaseInsertPt);
 
   Value *Base = Rewriter.expandCodeFor(NewBase, BaseInsertPt);
 
@@ -567,8 +566,7 @@ Value *BasedUser::InsertCodeForBaseAtPosition(const SCEVHandle &NewBase,
   
   // Always emit the immediate (if non-zero) into the same block as the user.
   SCEVHandle NewValSCEV = SCEVAddExpr::get(SCEVUnknown::get(Base), Imm);
-  return Rewriter.expandCodeFor(NewValSCEV, IP,
-                                OperandValToReplace->getType());
+  return Rewriter.expandCodeFor(NewValSCEV, IP);
   
 }
 
@@ -598,6 +596,11 @@ void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
       }
     }
     Value *NewVal = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L);
+    // Adjust the type back to match the Inst.
+    if (isa<PointerType>(OperandValToReplace->getType())) {
+      NewVal = new IntToPtrInst(NewVal, OperandValToReplace->getType(), "cast",
+                                InsertPt);
+    }
     // Replace the use of the operand Value with the new Phi we just created.
     Inst->replaceUsesOfWith(OperandValToReplace, NewVal);
     DOUT << "    CHANGED: IMM =" << *Imm;
@@ -644,6 +647,11 @@ void BasedUser::RewriteInstructionToUseNewBase(const SCEVHandle &NewBase,
         // Insert the code into the end of the predecessor block.
         Instruction *InsertPt = PN->getIncomingBlock(i)->getTerminator();
         Code = InsertCodeForBaseAtPosition(NewBase, Rewriter, InsertPt, L);
+
+        // Adjust the type back to match the PHI.
+        if (isa<PointerType>(PN->getType())) {
+          Code = new IntToPtrInst(Code, PN->getType(), "cast", InsertPt);
+        }
       }
       
       // Replace the use of the operand Value with the new Phi we just created.
@@ -1112,8 +1120,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
 
   // Emit the initial base value into the loop preheader.
   Value *CommonBaseV
-    = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt,
-                                      ReplacedTy);
+    = PreheaderRewriter.expandCodeFor(CommonExprs, PreInsertPt);
 
   if (RewriteFactor == 0) {
     // Create a new Phi for this base, and stick it in the loop header.
@@ -1131,8 +1138,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
       IncAmount = SCEV::getNegativeSCEV(Stride);
     
     // Insert the stride into the preheader.
-    Value *StrideV = PreheaderRewriter.expandCodeFor(IncAmount, PreInsertPt,
-                                                     ReplacedTy);
+    Value *StrideV = PreheaderRewriter.expandCodeFor(IncAmount, PreInsertPt);
     if (!isa<ConstantInt>(StrideV)) ++NumVariable;
 
     // Emit the increment of the base value before the terminator of the loop
@@ -1142,8 +1148,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
       IncExp = SCEV::getNegativeSCEV(IncExp);
     IncExp = SCEVAddExpr::get(SCEVUnknown::get(NewPHI), IncExp);
   
-    IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator(),
-                                  ReplacedTy);
+    IncV = Rewriter.expandCodeFor(IncExp, LatchBlock->getTerminator());
     IncV->setName(NewPHI->getName()+".inc");
     NewPHI->addIncoming(IncV, LatchBlock);
 
@@ -1199,8 +1204,7 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
     SCEVHandle Base = UsersToProcess.back().Base;
 
     // Emit the code for Base into the preheader.
-    Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt,
-                                                   ReplacedTy);
+    Value *BaseV = PreheaderRewriter.expandCodeFor(Base, PreInsertPt);
 
     DOUT << "  INSERTING code for BASE = " << *Base << ":";
     if (BaseV->hasName())