For PR950:
[oota-llvm.git] / lib / Transforms / Scalar / PredicateSimplifier.cpp
index 8a677c4461cded940dd7f7a6ac0e26b407957327..6578bccff12e1c32deefe8c1183f03bbbd30efef 100644 (file)
@@ -559,7 +559,8 @@ namespace {
     void addToWorklist(Instruction *I) {
       //DOUT << "addToWorklist: " << *I << "\n";
 
-      if (!isa<BinaryOperator>(I) && !isa<SelectInst>(I)) return;
+      if (!isa<BinaryOperator>(I) && !isa<SelectInst>(I) && !isa<CmpInst>(I)) 
+        return;
 
       const Type *Ty = I->getType();
       if (Ty == Type::VoidTy || Ty->isFPOrFPVector()) return;
@@ -855,102 +856,6 @@ namespace {
             addEqual(BO, ConstantExpr::get(BO->getOpcode(), CI1, CI2));
 
           switch (BO->getOpcode()) {
-            case Instruction::SetEQ:
-              // "seteq int %a, %b" EQ true  then %a EQ %b
-              // "seteq int %a, %b" EQ false then %a NE %b
-              if (Canonical == ConstantBool::getTrue())
-                addEqual(Op0, Op1);
-              else if (Canonical == ConstantBool::getFalse())
-                addNotEqual(Op0, Op1);
-
-              // %a EQ %b then "seteq int %a, %b" EQ true
-              // %a NE %b then "seteq int %a, %b" EQ false
-              if (isEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-              else if (isNotEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-
-              break;
-            case Instruction::SetNE:
-              // "setne int %a, %b" EQ true  then %a NE %b
-              // "setne int %a, %b" EQ false then %a EQ %b
-              if (Canonical == ConstantBool::getTrue())
-                addNotEqual(Op0, Op1);
-              else if (Canonical == ConstantBool::getFalse())
-                addEqual(Op0, Op1);
-
-              // %a EQ %b then "setne int %a, %b" EQ false
-              // %a NE %b then "setne int %a, %b" EQ true
-              if (isEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-              else if (isNotEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-
-              break;
-            case Instruction::SetLT:
-              // "setlt int %a, %b" EQ true  then %a LT %b
-              // "setlt int %a, %b" EQ false then %b LE %a
-              if (Canonical == ConstantBool::getTrue())
-                addLess(Op0, Op1);
-              else if (Canonical == ConstantBool::getFalse())
-                addLessEqual(Op1, Op0);
-
-              // %a LT %b then "setlt int %a, %b" EQ true
-              // %a GE %b then "setlt int %a, %b" EQ false
-              if (isLess(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-              else if (isGreaterEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-
-              break;
-            case Instruction::SetLE:
-              // "setle int %a, %b" EQ true  then %a LE %b
-              // "setle int %a, %b" EQ false then %b LT %a
-              if (Canonical == ConstantBool::getTrue())
-                addLessEqual(Op0, Op1);
-              else if (Canonical == ConstantBool::getFalse())
-                addLess(Op1, Op0);
-
-              // %a LE %b then "setle int %a, %b" EQ true
-              // %a GT %b then "setle int %a, %b" EQ false
-              if (isLessEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-              else if (isGreater(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-
-              break;
-            case Instruction::SetGT:
-              // "setgt int %a, %b" EQ true  then %b LT %a
-              // "setgt int %a, %b" EQ false then %a LE %b
-              if (Canonical == ConstantBool::getTrue())
-                addLess(Op1, Op0);
-              else if (Canonical == ConstantBool::getFalse())
-                addLessEqual(Op0, Op1);
-
-              // %a GT %b then "setgt int %a, %b" EQ true
-              // %a LE %b then "setgt int %a, %b" EQ false
-              if (isGreater(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-              else if (isLessEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-
-              break;
-            case Instruction::SetGE:
-              // "setge int %a, %b" EQ true  then %b LE %a
-              // "setge int %a, %b" EQ false then %a LT %b
-              if (Canonical == ConstantBool::getTrue())
-                addLessEqual(Op1, Op0);
-              else if (Canonical == ConstantBool::getFalse())
-                addLess(Op0, Op1);
-
-              // %a GE %b then "setge int %a, %b" EQ true
-              // %a LT %b then "setlt int %a, %b" EQ false
-              if (isGreaterEqual(Op0, Op1))
-                addEqual(BO, ConstantBool::getTrue());
-              else if (isLess(Op0, Op1))
-                addEqual(BO, ConstantBool::getFalse());
-
-              break;
             case Instruction::And: {
               // "and int %a, %b"  EQ -1   then %a EQ -1   and %b EQ -1
               // "and bool %a, %b" EQ true then %a EQ true and %b EQ true
@@ -1030,6 +935,250 @@ namespace {
                 break;
             }
           }
+        } else if (FCmpInst *CI = dyn_cast<FCmpInst>(I)) {
+          Value *Op0 = cIG.canonicalize(CI->getOperand(0)),
+                *Op1 = cIG.canonicalize(CI->getOperand(1));
+
+          ConstantFP *CI1 = dyn_cast<ConstantFP>(Op0),
+                     *CI2 = dyn_cast<ConstantFP>(Op1);
+
+          if (CI1 && CI2)
+            addEqual(CI, ConstantExpr::getFCmp(CI->getPredicate(), CI1, CI2));
+
+          switch (CI->getPredicate()) {
+            case FCmpInst::FCMP_OEQ:
+            case FCmpInst::FCMP_UEQ:
+              // "eq int %a, %b" EQ true  then %a EQ %b
+              // "eq int %a, %b" EQ false then %a NE %b
+              if (Canonical == ConstantBool::getTrue())
+                addEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addNotEqual(Op0, Op1);
+
+              // %a EQ %b then "eq int %a, %b" EQ true
+              // %a NE %b then "eq int %a, %b" EQ false
+              if (isEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isNotEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case FCmpInst::FCMP_ONE:
+            case FCmpInst::FCMP_UNE:
+              // "ne int %a, %b" EQ true  then %a NE %b
+              // "ne int %a, %b" EQ false then %a EQ %b
+              if (Canonical == ConstantBool::getTrue())
+                addNotEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addEqual(Op0, Op1);
+
+              // %a EQ %b then "ne int %a, %b" EQ false
+              // %a NE %b then "ne int %a, %b" EQ true
+              if (isEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+              else if (isNotEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+
+              break;
+            case FCmpInst::FCMP_ULT:
+            case FCmpInst::FCMP_OLT:
+              // "lt int %a, %b" EQ true  then %a LT %b
+              // "lt int %a, %b" EQ false then %b LE %a
+              if (Canonical == ConstantBool::getTrue())
+                addLess(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addLessEqual(Op1, Op0);
+
+              // %a LT %b then "lt int %a, %b" EQ true
+              // %a GE %b then "lt int %a, %b" EQ false
+              if (isLess(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isGreaterEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case FCmpInst::FCMP_ULE:
+            case FCmpInst::FCMP_OLE:
+              // "le int %a, %b" EQ true  then %a LE %b
+              // "le int %a, %b" EQ false then %b LT %a
+              if (Canonical == ConstantBool::getTrue())
+                addLessEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addLess(Op1, Op0);
+
+              // %a LE %b then "le int %a, %b" EQ true
+              // %a GT %b then "le int %a, %b" EQ false
+              if (isLessEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isGreater(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case FCmpInst::FCMP_UGT:
+            case FCmpInst::FCMP_OGT:
+              // "gt int %a, %b" EQ true  then %b LT %a
+              // "gt int %a, %b" EQ false then %a LE %b
+              if (Canonical == ConstantBool::getTrue())
+                addLess(Op1, Op0);
+              else if (Canonical == ConstantBool::getFalse())
+                addLessEqual(Op0, Op1);
+
+              // %a GT %b then "gt int %a, %b" EQ true
+              // %a LE %b then "gt int %a, %b" EQ false
+              if (isGreater(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isLessEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case FCmpInst::FCMP_UGE:
+            case FCmpInst::FCMP_OGE:
+              // "ge int %a, %b" EQ true  then %b LE %a
+              // "ge int %a, %b" EQ false then %a LT %b
+              if (Canonical == ConstantBool::getTrue())
+                addLessEqual(Op1, Op0);
+              else if (Canonical == ConstantBool::getFalse())
+                addLess(Op0, Op1);
+
+              // %a GE %b then "ge int %a, %b" EQ true
+              // %a LT %b then "lt int %a, %b" EQ false
+              if (isGreaterEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isLess(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            default:
+              break;
+          }
+
+          // "%x = add int %y, %z" and %x EQ %y then %z EQ 0
+          // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1
+          // 1. Repeat all of the above, with order of operands reversed.
+          // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1
+          Value *Known = Op0, *Unknown = Op1;
+          if (Known != BO) std::swap(Known, Unknown);
+        } else if (ICmpInst *CI = dyn_cast<ICmpInst>(I)) {
+          Value *Op0 = cIG.canonicalize(CI->getOperand(0)),
+                *Op1 = cIG.canonicalize(CI->getOperand(1));
+
+          ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(Op0),
+                           *CI2 = dyn_cast<ConstantIntegral>(Op1);
+
+          if (CI1 && CI2)
+            addEqual(CI, ConstantExpr::getICmp(CI->getPredicate(), CI1, CI2));
+
+          switch (CI->getPredicate()) {
+            case ICmpInst::ICMP_EQ:
+              // "eq int %a, %b" EQ true  then %a EQ %b
+              // "eq int %a, %b" EQ false then %a NE %b
+              if (Canonical == ConstantBool::getTrue())
+                addEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addNotEqual(Op0, Op1);
+
+              // %a EQ %b then "eq int %a, %b" EQ true
+              // %a NE %b then "eq int %a, %b" EQ false
+              if (isEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isNotEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case ICmpInst::ICMP_NE:
+              // "ne int %a, %b" EQ true  then %a NE %b
+              // "ne int %a, %b" EQ false then %a EQ %b
+              if (Canonical == ConstantBool::getTrue())
+                addNotEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addEqual(Op0, Op1);
+
+              // %a EQ %b then "ne int %a, %b" EQ false
+              // %a NE %b then "ne int %a, %b" EQ true
+              if (isEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+              else if (isNotEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+
+              break;
+            case ICmpInst::ICMP_ULT:
+            case ICmpInst::ICMP_SLT:
+              // "lt int %a, %b" EQ true  then %a LT %b
+              // "lt int %a, %b" EQ false then %b LE %a
+              if (Canonical == ConstantBool::getTrue())
+                addLess(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addLessEqual(Op1, Op0);
+
+              // %a LT %b then "lt int %a, %b" EQ true
+              // %a GE %b then "lt int %a, %b" EQ false
+              if (isLess(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isGreaterEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case ICmpInst::ICMP_ULE:
+            case ICmpInst::ICMP_SLE:
+              // "le int %a, %b" EQ true  then %a LE %b
+              // "le int %a, %b" EQ false then %b LT %a
+              if (Canonical == ConstantBool::getTrue())
+                addLessEqual(Op0, Op1);
+              else if (Canonical == ConstantBool::getFalse())
+                addLess(Op1, Op0);
+
+              // %a LE %b then "le int %a, %b" EQ true
+              // %a GT %b then "le int %a, %b" EQ false
+              if (isLessEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isGreater(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case ICmpInst::ICMP_UGT:
+            case ICmpInst::ICMP_SGT:
+              // "gt int %a, %b" EQ true  then %b LT %a
+              // "gt int %a, %b" EQ false then %a LE %b
+              if (Canonical == ConstantBool::getTrue())
+                addLess(Op1, Op0);
+              else if (Canonical == ConstantBool::getFalse())
+                addLessEqual(Op0, Op1);
+
+              // %a GT %b then "gt int %a, %b" EQ true
+              // %a LE %b then "gt int %a, %b" EQ false
+              if (isGreater(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isLessEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            case ICmpInst::ICMP_UGE:
+            case ICmpInst::ICMP_SGE:
+              // "ge int %a, %b" EQ true  then %b LE %a
+              // "ge int %a, %b" EQ false then %a LT %b
+              if (Canonical == ConstantBool::getTrue())
+                addLessEqual(Op1, Op0);
+              else if (Canonical == ConstantBool::getFalse())
+                addLess(Op0, Op1);
+
+              // %a GE %b then "ge int %a, %b" EQ true
+              // %a LT %b then "lt int %a, %b" EQ false
+              if (isGreaterEqual(Op0, Op1))
+                addEqual(CI, ConstantBool::getTrue());
+              else if (isLess(Op0, Op1))
+                addEqual(CI, ConstantBool::getFalse());
+
+              break;
+            default:
+              break;
+          }
+
+          // "%x = add int %y, %z" and %x EQ %y then %z EQ 0
+          // "%x = mul int %y, %z" and %x EQ %y then %z EQ 1
+          // 1. Repeat all of the above, with order of operands reversed.
+          // "%x = fdiv float %y, %z" and %x EQ %y then %z EQ 1
+          Value *Known = Op0, *Unknown = Op1;
+          if (Known != BO) std::swap(Known, Unknown);
         } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
           // Given: "%a = select bool %x, int %b, int %c"
           // %a EQ %b then %x EQ true
@@ -1108,6 +1257,7 @@ namespace {
       void visitStoreInst(StoreInst &SI);
 
       void visitBinaryOperator(BinaryOperator &BO);
+      void visitCmpInst(CmpInst &CI) {}
     };
 
     // Used by terminator instructions to proceed from the current basic