inline the FoldICmpLogical functor.
authorChris Lattner <sabre@nondot.org>
Tue, 5 Jan 2010 06:59:49 +0000 (06:59 +0000)
committerChris Lattner <sabre@nondot.org>
Tue, 5 Jan 2010 06:59:49 +0000 (06:59 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92695 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstructionCombining.cpp

index 236ab9c206a2f9a9ec4e4abe97b5df6ee24adb0b..f9acd20b55526b033fcfdcfc4dbe37561654fc5e 100644 (file)
@@ -1254,33 +1254,33 @@ static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) {
 /// opcode and two operands into either a constant true or false, or a brand 
 /// new ICmp instruction. The sign is passed in to determine which kind
 /// of predicate to use in the new icmp instruction.
-static Value *getICmpValue(bool sign, unsigned code, Value *LHS, Value *RHS) {
-  switch (code) {
-  default: llvm_unreachable("Illegal ICmp code!");
-  case  0: return ConstantInt::getFalse(LHS->getContext());
-  case  1: 
-    if (sign)
+static Value *getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS) {
+  switch (Code) {
+  default: assert(0 && "Illegal ICmp code!");
+  case 0:
+    return ConstantInt::getFalse(LHS->getContext());
+  case 1: 
+    if (Sign)
       return new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS);
-    else
-      return new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS);
-  case  2: return new ICmpInst(ICmpInst::ICMP_EQ,  LHS, RHS);
-  case  3: 
-    if (sign)
+    return new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS);
+  case 2:
+    return new ICmpInst(ICmpInst::ICMP_EQ,  LHS, RHS);
+  case 3: 
+    if (Sign)
       return new ICmpInst(ICmpInst::ICMP_SGE, LHS, RHS);
-    else
-      return new ICmpInst(ICmpInst::ICMP_UGE, LHS, RHS);
-  case  4: 
-    if (sign)
+    return new ICmpInst(ICmpInst::ICMP_UGE, LHS, RHS);
+  case 4: 
+    if (Sign)
       return new ICmpInst(ICmpInst::ICMP_SLT, LHS, RHS);
-    else
-      return new ICmpInst(ICmpInst::ICMP_ULT, LHS, RHS);
-  case  5: return new ICmpInst(ICmpInst::ICMP_NE,  LHS, RHS);
-  case  6: 
-    if (sign)
+    return new ICmpInst(ICmpInst::ICMP_ULT, LHS, RHS);
+  case 5:
+    return new ICmpInst(ICmpInst::ICMP_NE,  LHS, RHS);
+  case 6: 
+    if (Sign)
       return new ICmpInst(ICmpInst::ICMP_SLE, LHS, RHS);
-    else
-      return new ICmpInst(ICmpInst::ICMP_ULE, LHS, RHS);
-  case  7: return ConstantInt::getTrue(LHS->getContext());
+    return new ICmpInst(ICmpInst::ICMP_ULE, LHS, RHS);
+  case 7:
+    return ConstantInt::getTrue(LHS->getContext());
   }
 }
 
@@ -1338,50 +1338,6 @@ static bool PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) {
          (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1));
 }
 
-namespace { 
-// FoldICmpLogical - Implements (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
-struct FoldICmpLogical {
-  InstCombiner &IC;
-  Value *LHS, *RHS;
-  ICmpInst::Predicate pred;
-  FoldICmpLogical(InstCombiner &ic, ICmpInst *ICI)
-    : IC(ic), LHS(ICI->getOperand(0)), RHS(ICI->getOperand(1)),
-      pred(ICI->getPredicate()) {}
-  bool shouldApply(Value *V) const {
-    if (ICmpInst *ICI = dyn_cast<ICmpInst>(V))
-      if (PredicatesFoldable(pred, ICI->getPredicate()))
-        return ((ICI->getOperand(0) == LHS && ICI->getOperand(1) == RHS) ||
-                (ICI->getOperand(0) == RHS && ICI->getOperand(1) == LHS));
-    return false;
-  }
-  Instruction *apply(Instruction &Log) const {
-    ICmpInst *ICI = cast<ICmpInst>(Log.getOperand(0));
-    if (ICI->getOperand(0) != LHS) {
-      assert(ICI->getOperand(1) == LHS);
-      ICI->swapOperands();  // Swap the LHS and RHS of the ICmp
-    }
-
-    ICmpInst *RHSICI = cast<ICmpInst>(Log.getOperand(1));
-    unsigned LHSCode = getICmpCode(ICI);
-    unsigned RHSCode = getICmpCode(RHSICI);
-    unsigned Code;
-    switch (Log.getOpcode()) {
-    case Instruction::And: Code = LHSCode & RHSCode; break;
-    case Instruction::Or:  Code = LHSCode | RHSCode; break;
-    case Instruction::Xor: Code = LHSCode ^ RHSCode; break;
-    default: llvm_unreachable("Illegal logical opcode!"); return 0;
-    }
-
-    bool isSigned = RHSICI->isSigned() || ICI->isSigned();
-    Value *RV = getICmpValue(isSigned, Code, LHS, RHS);
-    if (Instruction *I = dyn_cast<Instruction>(RV))
-      return I;
-    // Otherwise, it's a constant boolean value...
-    return IC.ReplaceInstUsesWith(Log, RV);
-  }
-};
-} // end anonymous namespace
-
 // OptAndOp - This handles expressions of the form ((val OP C1) & C2).  Where
 // the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'.  Op is
 // guaranteed to be a binary operator.
@@ -1635,16 +1591,31 @@ Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
 /// FoldAndOfICmps - Fold (icmp)&(icmp) if possible.
 Instruction *InstCombiner::FoldAndOfICmps(Instruction &I,
                                           ICmpInst *LHS, ICmpInst *RHS) {
-  Value *Val, *Val2;
-  ConstantInt *LHSCst, *RHSCst;
-  ICmpInst::Predicate LHSCC, RHSCC;
+  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+
+  // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
+  if (PredicatesFoldable(LHSCC, RHSCC)) {
+    if (LHS->getOperand(0) == RHS->getOperand(1) &&
+        LHS->getOperand(1) == RHS->getOperand(0))
+      LHS->swapOperands();
+    if (LHS->getOperand(0) == RHS->getOperand(0) &&
+        LHS->getOperand(1) == RHS->getOperand(1)) {
+      Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1);
+      unsigned Code = getICmpCode(LHS) & getICmpCode(RHS);
+      bool isSigned = LHS->isSigned() || RHS->isSigned();
+      Value *RV = getICmpValue(isSigned, Code, Op0, Op1);
+      if (Instruction *I = dyn_cast<Instruction>(RV))
+        return I;
+      // Otherwise, it's a constant boolean value.
+      return ReplaceInstUsesWith(I, RV);
+    }
+  }
   
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
-  if (!match(LHS, m_ICmp(LHSCC, m_Value(Val),
-                         m_ConstantInt(LHSCst))) ||
-      !match(RHS, m_ICmp(RHSCC, m_Value(Val2),
-                         m_ConstantInt(RHSCst))))
-    return 0;
+  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+  ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (LHSCst == 0 || RHSCst == 0) return 0;
   
   if (LHSCst == RHSCst && LHSCC == RHSCC) {
     // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)
@@ -1696,7 +1667,7 @@ Instruction *InstCombiner::FoldAndOfICmps(Instruction &I,
   // comparing a value against two constants and and'ing the result
   // together.  Because of the above check, we know that we only have
   // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know 
-  // (from the FoldICmpLogical check above), that the two constants 
+  // (from the icmp folding check above), that the two constants 
   // are not equal and that the larger constant is on the RHS
   assert(LHSCst != RHSCst && "Compares not folded above?");
 
@@ -2074,15 +2045,10 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
       return BinaryOperator::CreateAnd(A, Op0);
   }
   
-  if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1)) {
-    // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
-    if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS)))
-      return R;
-
+  if (ICmpInst *RHS = dyn_cast<ICmpInst>(Op1))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(Op0))
       if (Instruction *Res = FoldAndOfICmps(I, LHS, RHS))
         return Res;
-  }
 
   // fold (and (cast A), (cast B)) -> (cast (and A, B))
   if (CastInst *Op0C = dyn_cast<CastInst>(Op0))
@@ -2312,16 +2278,32 @@ static Instruction *MatchSelectFromAndOr(Value *A, Value *B,
 /// FoldOrOfICmps - Fold (icmp)|(icmp) if possible.
 Instruction *InstCombiner::FoldOrOfICmps(Instruction &I,
                                          ICmpInst *LHS, ICmpInst *RHS) {
-  Value *Val, *Val2;
-  ConstantInt *LHSCst, *RHSCst;
-  ICmpInst::Predicate LHSCC, RHSCC;
+  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+
+  // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
+  if (PredicatesFoldable(LHSCC, RHSCC)) {
+    if (LHS->getOperand(0) == RHS->getOperand(1) &&
+        LHS->getOperand(1) == RHS->getOperand(0))
+      LHS->swapOperands();
+    if (LHS->getOperand(0) == RHS->getOperand(0) &&
+        LHS->getOperand(1) == RHS->getOperand(1)) {
+      Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1);
+      unsigned Code = getICmpCode(LHS) | getICmpCode(RHS);
+      bool isSigned = LHS->isSigned() || RHS->isSigned();
+      Value *RV = getICmpValue(isSigned, Code, Op0, Op1);
+      if (Instruction *I = dyn_cast<Instruction>(RV))
+        return I;
+      // Otherwise, it's a constant boolean value.
+      return ReplaceInstUsesWith(I, RV);
+    }
+  }
   
   // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
-  if (!match(LHS, m_ICmp(LHSCC, m_Value(Val), m_ConstantInt(LHSCst))) ||
-      !match(RHS, m_ICmp(RHSCC, m_Value(Val2), m_ConstantInt(RHSCst))))
-    return 0;
+  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+  ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (LHSCst == 0 || RHSCst == 0) return 0;
 
-  
   // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
   if (LHSCst == RHSCst && LHSCC == RHSCC &&
       LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) {
@@ -2363,7 +2345,7 @@ Instruction *InstCombiner::FoldOrOfICmps(Instruction &I,
   // comparing a value against two constants and or'ing the result
   // together.  Because of the above check, we know that we only have
   // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the
-  // FoldICmpLogical check above), that the two constants are not
+  // icmp folding check above), that the two constants are not
   // equal.
   assert(LHSCst != RHSCst && "Compares not folded above?");
 
@@ -2780,15 +2762,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
         return BinaryOperator::CreateNot(And);
       }
 
-  // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
-  if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) {
-    if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS)))
-      return R;
-
+  if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
       if (Instruction *Res = FoldOrOfICmps(I, LHS, RHS))
         return Res;
-  }
     
   // fold (or (cast A), (cast B)) -> (cast (or A, B))
   if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
@@ -3093,8 +3070,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
     
   // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
-    if (Instruction *R = AssociativeOpt(I, FoldICmpLogical(*this, RHS)))
-      return R;
+    if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
+      if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) {
+        if (LHS->getOperand(0) == RHS->getOperand(1) &&
+            LHS->getOperand(1) == RHS->getOperand(0))
+          LHS->swapOperands();
+        if (LHS->getOperand(0) == RHS->getOperand(0) &&
+            LHS->getOperand(1) == RHS->getOperand(1)) {
+          Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1);
+          unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS);
+          bool isSigned = LHS->isSigned() || RHS->isSigned();
+          Value *RV = getICmpValue(isSigned, Code, Op0, Op1);
+          if (Instruction *I = dyn_cast<Instruction>(RV))
+            return I;
+          // Otherwise, it's a constant boolean value.
+          return ReplaceInstUsesWith(I, RV);
+        }
+      }
 
   // fold (xor (cast A), (cast B)) -> (cast (xor A, B))
   if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {