Add more comments and update to new asm syntax.
authorNick Lewycky <nicholas@mxc.ca>
Fri, 16 Mar 2007 02:37:39 +0000 (02:37 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Fri, 16 Mar 2007 02:37:39 +0000 (02:37 +0000)
Add new micro-optimizations.

Add icmp predicate snuggling. Given %x ULT 4, "icmp ugt %x, 2" becomes
"icmp eq %x, 3". This doesn't apply in any non-trivial cases yet due to missing
support for NE values in ValueRanges.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35119 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/PredicateSimplifier.cpp

index 836720115a0c16904d1819d0e24e4eca33c382e5..9c88e2832848e6efe5b6ec1aa27ac057d25ef005 100644 (file)
@@ -22,9 +22,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass focusses on four properties; equals, not equals, less-than
-// and less-than-or-equals-to. The greater-than forms are also held just
-// to allow walking from a lesser node to a greater one. These properties
+// The InequalityGraph focusses on four properties; equals, not equals,
+// less-than and less-than-or-equals-to. The greater-than forms are also held
+// just to allow walking from a lesser node to a greater one. These properties
 // are stored in a lattice; LE can become LT or EQ, NE can become LT or GT.
 //
 // These relationships define a graph between values of the same type. Each
 // that the dividend is not equal to zero.
 //
 //===----------------------------------------------------------------------===//
+//
+// The ValueRanges class stores the known integer bounds of a Value. When we
+// encounter i8 %a u< %b, the ValueRanges stores that %a = [1, 255] and
+// %b = [0, 254]. Because we store these by Value*, you should always
+// canonicalize through the InequalityGraph first.
+//
+// It never stores an empty range, because that means that the code is
+// unreachable. It never stores a single-element range since that's an equality
+// relationship and better stored in the InequalityGraph.
+//
+//===----------------------------------------------------------------------===//
 
 #define DEBUG_TYPE "predsimplify"
 #include "llvm/Transforms/Scalar.h"
@@ -96,6 +107,7 @@ STATISTIC(NumVarsReplaced, "Number of argument substitutions");
 STATISTIC(NumInstruction , "Number of instructions removed");
 STATISTIC(NumSimple      , "Number of simple replacements");
 STATISTIC(NumBlocks      , "Number of blocks marked unreachable");
+STATISTIC(NumSnuggle     , "Number of comparisons snuggled");
 
 namespace {
   // SLT SGT ULT UGT EQ
@@ -155,6 +167,7 @@ namespace {
   /// reversePredicate - reverse the direction of the inequality
   static LatticeVal reversePredicate(LatticeVal LV) {
     unsigned reverse = LV ^ (SLT_BIT|SGT_BIT|ULT_BIT|UGT_BIT); //preserve EQ_BIT
+
     if ((reverse & (SLT_BIT|SGT_BIT)) == 0)
       reverse |= (SLT_BIT|SGT_BIT);
 
@@ -167,10 +180,10 @@ namespace {
   }
 
   /// This is a StrictWeakOrdering predicate that sorts ETNodes by how many
-  /// children they have. With this, you can iterate through a list sorted by
-  /// this operation and the first matching entry is the most specific match
-  /// for your basic block. The order provided is total; ETNodes with the
-  /// same number of children are sorted by pointer address.
+  /// descendants they have. With this, you can iterate through a list sorted
+  /// by this operation and the first matching entry is the most specific
+  /// match for your basic block. The order provided is stable; ETNodes with
+  /// the same number of children are sorted by pointer address.
   struct VISIBILITY_HIDDEN OrderByDominance {
     bool operator()(const ETNode *LHS, const ETNode *RHS) const {
       unsigned LHS_spread = LHS->getDFSNumOut() - LHS->getDFSNumIn();
@@ -830,7 +843,7 @@ namespace {
       ConstantRange CR2 = rangeFromValue(V2, Subtree, W);
 
       // True iff all values in CR1 are LV to all values in CR2.
-      switch(LV) {
+      switch (LV) {
       default: assert(!"Impossible lattice value!");
       case NE:
         return CR1.intersectWith(CR2).isEmptySet();
@@ -919,7 +932,6 @@ namespace {
     }
   };
 
-
   /// UnreachableBlocks keeps tracks of blocks that are for one reason or
   /// another discovered to be unreachable. This is used to cull the graph when
   /// analyzing instructions, and to mark blocks with the "unreachable"
@@ -1366,7 +1378,7 @@ namespace {
 
         switch (BO->getOpcode()) {
           case Instruction::And: {
-            // "and i32 %a, %b"  EQ -1 then %a EQ -1 and %b EQ -1
+            // "and i32 %a, %b" EQ -1 then %a EQ -1 and %b EQ -1
             ConstantInt *CI = ConstantInt::getAllOnesValue(Ty);
             if (Canonical == CI) {
               add(CI, Op0, ICmpInst::ICMP_EQ, NewContext);
@@ -1459,8 +1471,8 @@ namespace {
             return;
           }
 
-        // "%y = and bool true, %x" then %x EQ %y.
-        // "%y = or bool false, %x" then %x EQ %y.
+        // "%y = and i1 true, %x" then %x EQ %y.
+        // "%y = or i1 false, %x" then %x EQ %y.
         if (BO->getOpcode() == Instruction::Or) {
           Constant *Zero = Constant::getNullValue(BO->getType());
           if (Op0 == Zero) {
@@ -1486,16 +1498,16 @@ namespace {
         // 1. Repeat all of the above, with order of operands reversed.
         // "%x = udiv i32 %y, %z" and %x EQ %y then %z EQ 1
 
+        Instruction::BinaryOps Opcode = BO->getOpcode();
+        const Type *Ty = BO->getType();
+        assert(!Ty->isFPOrFPVector() && "Float in work queue!");
+
         Value *Known = Op0, *Unknown = Op1;
         if (Known != BO) std::swap(Known, Unknown);
         if (Known == BO) {
-          const Type *Ty = BO->getType();
-          assert(!Ty->isFPOrFPVector() && "Float in work queue!");
-
-          switch (BO->getOpcode()) {
+          switch (Opcode) {
             default: break;
             case Instruction::Xor:
-            case Instruction::Or:
             case Instruction::Add:
             case Instruction::Sub:
               add(Unknown, Constant::getNullValue(Ty), ICmpInst::ICMP_EQ,
@@ -1504,7 +1516,6 @@ namespace {
             case Instruction::UDiv:
             case Instruction::SDiv:
               if (Unknown == Op0) break; // otherwise, fallthrough
-            case Instruction::And:
             case Instruction::Mul:
               if (isa<ConstantInt>(Unknown)) {
                 Constant *One = ConstantInt::get(Ty, 1);
@@ -1517,7 +1528,7 @@ namespace {
         // TODO: "%a = add i32 %b, 1" and %b > %z then %a >= %z.
 
       } else if (ICmpInst *IC = dyn_cast<ICmpInst>(I)) {
-        // "%a = icmp ult i32 %b, %c" and %b u< %c  then %a EQ true
+        // "%a = icmp ult i32 %b, %c" and %b u<  %c then %a EQ true
         // "%a = icmp ult i32 %b, %c" and %b u>= %c then %a EQ false
         // etc.
 
@@ -1531,12 +1542,8 @@ namespace {
           add(IC, ConstantInt::getFalse(), ICmpInst::ICMP_EQ, NewContext);
         }
 
-        // TODO: "i1 %x s<u> %y" implies %x = true and %y = false.
-
-        // TODO: make the predicate more strict, if possible.
-
       } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
-        // Given: "%a = select bool %x, int %b, int %c"
+        // Given: "%a = select i1 %x, i32 %b, i32 %c"
         // %x EQ true  then %a EQ %b
         // %x EQ false then %a EQ %c
         // %b EQ %c then %a EQ %b
@@ -1752,6 +1759,7 @@ namespace {
       void visitZExtInst(ZExtInst &ZI);
 
       void visitBinaryOperator(BinaryOperator &BO);
+      void visitICmpInst(ICmpInst &IC);
     };
   
     // Used by terminator instructions to proceed from the current basic
@@ -1823,10 +1831,11 @@ namespace {
       }
 #endif
 
-      DOUT << "push (%" << I->getParent()->getName() << ")\n";
+      std::string name = I->getParent()->getName();
+      DOUT << "push (%" << name << ")\n";
       Forwards visit(this, DT);
       visit.visit(*I);
-      DOUT << "pop (%" << I->getParent()->getName() << ")\n";
+      DOUT << "pop (%" << name << ")\n";
     }
   };
 
@@ -1979,6 +1988,7 @@ namespace {
     Instruction::BinaryOps ops = BO.getOpcode();
 
     switch (ops) {
+    default: break;
       case Instruction::URem:
       case Instruction::SRem:
       case Instruction::UDiv:
@@ -1990,8 +2000,100 @@ namespace {
         VRP.solve();
         break;
       }
-      default:
-        break;
+    }
+
+    switch (ops) {
+      default: break;
+      case Instruction::Shl: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE);
+        VRP.solve();
+      } break;
+      case Instruction::AShr: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_SLE);
+        VRP.solve();
+      } break;
+      case Instruction::LShr:
+      case Instruction::UDiv: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_ULE);
+        VRP.solve();
+      } break;
+      case Instruction::URem: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE);
+        VRP.solve();
+      } break;
+      case Instruction::And: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_ULE);
+        VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_ULE);
+        VRP.solve();
+      } break;
+      case Instruction::Or: {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &BO);
+        VRP.add(&BO, BO.getOperand(0), ICmpInst::ICMP_UGE);
+        VRP.add(&BO, BO.getOperand(1), ICmpInst::ICMP_UGE);
+        VRP.solve();
+      } break;
+    }
+  }
+
+  void PredicateSimplifier::Forwards::visitICmpInst(ICmpInst &IC) {
+    // If possible, squeeze the ICmp predicate into something simpler.
+    // Eg., if x = [0, 4) and we're being asked icmp uge %x, 3 then change
+    // the predicate to eq.
+
+    ICmpInst::Predicate Pred = IC.getPredicate();
+
+    if (ConstantInt *Op1 = dyn_cast<ConstantInt>(IC.getOperand(1))) {
+      ConstantInt *NextVal = 0;
+      switch(Pred) {
+        default: break;
+        case ICmpInst::ICMP_SLT:
+        case ICmpInst::ICMP_ULT:
+          if (Op1->getValue() != 0)
+            NextVal = cast<ConstantInt>(ConstantExpr::getAdd(
+                          Op1, ConstantInt::get(Op1->getType(), -1)));
+         break;
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_UGT:
+          if (!Op1->getValue().isAllOnesValue())
+            NextVal = cast<ConstantInt>(ConstantExpr::getAdd(
+                          Op1, ConstantInt::get(Op1->getType(), 1)));
+         break;
+
+      }
+      if (NextVal) {
+        VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &IC);
+        if (VRP.isRelatedBy(IC.getOperand(0), NextVal,
+                            ICmpInst::getInversePredicate(Pred))) {
+          ICmpInst *NewIC = new ICmpInst(ICmpInst::ICMP_EQ, IC.getOperand(0),
+                                         NextVal, "", &IC);
+          NewIC->takeName(&IC);
+          IC.replaceAllUsesWith(NewIC);
+          IG.remove(&IC); // XXX: prove this isn't necessary
+          IC.eraseFromParent();
+          ++NumSnuggle;
+          PS->modified = true;
+          return;
+        }
+      }
+    }
+
+    switch(Pred) {
+      default: return;
+      case ICmpInst::ICMP_ULE: Pred = ICmpInst::ICMP_ULT; break;
+      case ICmpInst::ICMP_UGE: Pred = ICmpInst::ICMP_UGT; break;
+      case ICmpInst::ICMP_SLE: Pred = ICmpInst::ICMP_SLT; break;
+      case ICmpInst::ICMP_SGE: Pred = ICmpInst::ICMP_SGT; break;
+    }
+    VRPSolver VRP(IG, UB, VR, PS->Forest, PS->modified, &IC);
+    if (VRP.isRelatedBy(IC.getOperand(1), IC.getOperand(0), Pred)) {
+      ++NumSnuggle;
+      PS->modified = true;
+      IC.setPredicate(Pred);
     }
   }