Remove a bunch of duplicated code. Among other things, this fixes
authorChris Lattner <sabre@nondot.org>
Fri, 12 Jan 2007 18:42:52 +0000 (18:42 +0000)
committerChris Lattner <sabre@nondot.org>
Fri, 12 Jan 2007 18:42:52 +0000 (18:42 +0000)
constant folding of signed comparisons of bool.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@33134 91177308-0d34-0410-b5e6-96231b3b80d8

lib/VMCore/ConstantFold.cpp

index 8aeca7b47a51b747351c4cce1f3b70ec2489032f..dcd8657bd44fa297d041e61a0f3d43fd9cac0678 100644 (file)
@@ -554,71 +554,55 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
   // so look at directly computing the value.
   if (const ConstantInt *CI1 = dyn_cast<ConstantInt>(C1)) {
     if (const ConstantInt *CI2 = dyn_cast<ConstantInt>(C2)) {
-      if (CI1->getType() == Type::Int1Ty && CI2->getType() == Type::Int1Ty) {
-        switch (Opcode) {
-          default:
-            break;
-          case Instruction::And:
-            return ConstantInt::get(Type::Int1Ty, 
-                                    CI1->getZExtValue() & CI2->getZExtValue());
-          case Instruction::Or:
-            return ConstantInt::get(Type::Int1Ty, 
-                                    CI1->getZExtValue() | CI2->getZExtValue());
-          case Instruction::Xor:
-            return ConstantInt::get(Type::Int1Ty, 
-                                    CI1->getZExtValue() ^ CI2->getZExtValue());
-        }
-      } else {
-        uint64_t C1Val = CI1->getZExtValue();
-        uint64_t C2Val = CI2->getZExtValue();
-        switch (Opcode) {
-        default:
-          break;
-        case Instruction::Add:     
-          return ConstantInt::get(C1->getType(), C1Val + C2Val);
-        case Instruction::Sub:     
-          return ConstantInt::get(C1->getType(), C1Val - C2Val);
-        case Instruction::Mul:     
-          return ConstantInt::get(C1->getType(), C1Val * C2Val);
-        case Instruction::UDiv:
-          if (CI2->isNullValue())                  // X / 0 -> can't fold
-            return 0;
-          return ConstantInt::get(C1->getType(), C1Val / C2Val);
-        case Instruction::SDiv:
-          if (CI2->isNullValue()) return 0;        // X / 0 -> can't fold
-          if (CI2->isAllOnesValue() &&
-              (((CI1->getType()->getPrimitiveSizeInBits() == 64) && 
-                (CI1->getSExtValue() == INT64_MIN)) ||
-               (CI1->getSExtValue() == -CI1->getSExtValue())))
-            return 0;                              // MIN_INT / -1 -> overflow
-          return ConstantInt::get(C1->getType(), 
-                                  CI1->getSExtValue() / CI2->getSExtValue());
-        case Instruction::URem:    
-          if (C2->isNullValue()) return 0;         // X / 0 -> can't fold
-          return ConstantInt::get(C1->getType(), C1Val % C2Val);
-        case Instruction::SRem:    
-          if (CI2->isNullValue()) return 0;        // X % 0 -> can't fold
-          if (CI2->isAllOnesValue() &&              
-              (((CI1->getType()->getPrimitiveSizeInBits() == 64) && 
-                (CI1->getSExtValue() == INT64_MIN)) ||
-               (CI1->getSExtValue() == -CI1->getSExtValue())))
-            return 0;                              // MIN_INT % -1 -> overflow
-          return ConstantInt::get(C1->getType(), 
-                                  CI1->getSExtValue() % CI2->getSExtValue());
-        case Instruction::And:
-          return ConstantInt::get(C1->getType(), C1Val & C2Val);
-        case Instruction::Or:
-          return ConstantInt::get(C1->getType(), C1Val | C2Val);
-        case Instruction::Xor:
-          return ConstantInt::get(C1->getType(), C1Val ^ C2Val);
-        case Instruction::Shl:
-          return ConstantInt::get(C1->getType(), C1Val << C2Val);
-        case Instruction::LShr:
-          return ConstantInt::get(C1->getType(), C1Val >> C2Val);
-        case Instruction::AShr:
-          return ConstantInt::get(C1->getType(), 
-                                  CI1->getSExtValue() >> C2Val);
-        }
+      uint64_t C1Val = CI1->getZExtValue();
+      uint64_t C2Val = CI2->getZExtValue();
+      switch (Opcode) {
+      default:
+        break;
+      case Instruction::Add:     
+        return ConstantInt::get(C1->getType(), C1Val + C2Val);
+      case Instruction::Sub:     
+        return ConstantInt::get(C1->getType(), C1Val - C2Val);
+      case Instruction::Mul:     
+        return ConstantInt::get(C1->getType(), C1Val * C2Val);
+      case Instruction::UDiv:
+        if (CI2->isNullValue())                  // X / 0 -> can't fold
+          return 0;
+        return ConstantInt::get(C1->getType(), C1Val / C2Val);
+      case Instruction::SDiv:
+        if (CI2->isNullValue()) return 0;        // X / 0 -> can't fold
+        if (CI2->isAllOnesValue() &&
+            (((CI1->getType()->getPrimitiveSizeInBits() == 64) && 
+              (CI1->getSExtValue() == INT64_MIN)) ||
+             (CI1->getSExtValue() == -CI1->getSExtValue())))
+          return 0;                              // MIN_INT / -1 -> overflow
+        return ConstantInt::get(C1->getType(), 
+                                CI1->getSExtValue() / CI2->getSExtValue());
+      case Instruction::URem:    
+        if (C2->isNullValue()) return 0;         // X / 0 -> can't fold
+        return ConstantInt::get(C1->getType(), C1Val % C2Val);
+      case Instruction::SRem:    
+        if (CI2->isNullValue()) return 0;        // X % 0 -> can't fold
+        if (CI2->isAllOnesValue() &&              
+            (((CI1->getType()->getPrimitiveSizeInBits() == 64) && 
+              (CI1->getSExtValue() == INT64_MIN)) ||
+             (CI1->getSExtValue() == -CI1->getSExtValue())))
+          return 0;                              // MIN_INT % -1 -> overflow
+        return ConstantInt::get(C1->getType(), 
+                                CI1->getSExtValue() % CI2->getSExtValue());
+      case Instruction::And:
+        return ConstantInt::get(C1->getType(), C1Val & C2Val);
+      case Instruction::Or:
+        return ConstantInt::get(C1->getType(), C1Val | C2Val);
+      case Instruction::Xor:
+        return ConstantInt::get(C1->getType(), C1Val ^ C2Val);
+      case Instruction::Shl:
+        return ConstantInt::get(C1->getType(), C1Val << C2Val);
+      case Instruction::LShr:
+        return ConstantInt::get(C1->getType(), C1Val >> C2Val);
+      case Instruction::AShr:
+        return ConstantInt::get(C1->getType(), 
+                                CI1->getSExtValue() >> C2Val);
       }
     }
   } else if (const ConstantFP *CFP1 = dyn_cast<ConstantFP>(C1)) {
@@ -1059,34 +1043,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
           return ConstantInt::getTrue();
   }
 
-  if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2) &&
-      C1->getType() == Type::Int1Ty && C2->getType() == Type::Int1Ty) {
-    bool C1Val = cast<ConstantInt>(C1)->getZExtValue();
-    bool C2Val = cast<ConstantInt>(C2)->getZExtValue();
-    switch (pred) {
-    default: assert(0 && "Invalid ICmp Predicate"); return 0;
-    case ICmpInst::ICMP_EQ: 
-      return ConstantInt::get(Type::Int1Ty, C1Val == C2Val);
-    case ICmpInst::ICMP_NE: 
-      return ConstantInt::get(Type::Int1Ty, C1Val != C2Val);
-    case ICmpInst::ICMP_ULT:
-      return ConstantInt::get(Type::Int1Ty, C1Val <  C2Val);
-    case ICmpInst::ICMP_UGT:
-      return ConstantInt::get(Type::Int1Ty, C1Val >  C2Val);
-    case ICmpInst::ICMP_ULE:
-      return ConstantInt::get(Type::Int1Ty, C1Val <= C2Val);
-    case ICmpInst::ICMP_UGE:
-      return ConstantInt::get(Type::Int1Ty, C1Val >= C2Val);
-    case ICmpInst::ICMP_SLT:
-      return ConstantInt::get(Type::Int1Ty, C1Val <  C2Val);
-    case ICmpInst::ICMP_SGT:
-      return ConstantInt::get(Type::Int1Ty, C1Val >  C2Val);
-    case ICmpInst::ICMP_SLE:
-      return ConstantInt::get(Type::Int1Ty, C1Val <= C2Val);
-    case ICmpInst::ICMP_SGE:
-      return ConstantInt::get(Type::Int1Ty, C1Val >= C2Val);
-    }
-  } else if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
+  if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
     if (ICmpInst::isSignedPredicate(ICmpInst::Predicate(pred))) {
       int64_t V1 = cast<ConstantInt>(C1)->getSExtValue();
       int64_t V2 = cast<ConstantInt>(C2)->getSExtValue();