Improve handling of SelectInst.
authorNick Lewycky <nicholas@mxc.ca>
Sat, 2 Sep 2006 19:40:38 +0000 (19:40 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Sat, 2 Sep 2006 19:40:38 +0000 (19:40 +0000)
Reorder operations to remove duplicated work.
Fix to leave floating-point types out of the optimization.
Add tests to predsimplify.ll for SwitchInst and SelectInst handling.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@30055 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/PredicateSimplifier.cpp
test/Transforms/PredicateSimplifier/2006-08-02-Switch.ll [new file with mode: 0644]
test/Transforms/PredicateSimplifier/predsimplify.ll

index ba28040ed1a4ccd7744aa31d154835df87b61511..0efff64aa1cbd7443ee28418db6814da46d30969 100644 (file)
@@ -28,9 +28,6 @@
 //
 //===------------------------------------------------------------------===//
 
-// TODO:
-// * Check handling of NAN in floating point types
-
 #define DEBUG_TYPE "predsimplify"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Constants.h"
@@ -112,14 +109,22 @@ namespace {
     }
 
     void addEqual(Value *V1, Value *V2) {
+      // If %x = 0. and %y = -0., seteq %x, %y is true, but
+      // copysign(%x) is not the same as copysign(%y).
+      if (V2->getType()->isFloatingPoint()) return;
+
       order(V1, V2);
       if (isa<Constant>(V2)) return; // refuse to set false == true.
 
+      DEBUG(std::cerr << "equal: " << *V1 << " and " << *V2 << "\n");
       union_find.unionSets(V1, V2);
       addImpliedProperties(EQ, V1, V2);
     }
 
     void addNotEqual(Value *V1, Value *V2) {
+      // If %x = NAN then seteq %x, %x is false.
+      if (V2->getType()->isFloatingPoint()) return;
+
       DEBUG(std::cerr << "not equal: " << *V1 << " and " << *V2 << "\n");
       if (findProperty(NE, V1, V2) != Properties.end())
         return; // found.
@@ -180,15 +185,9 @@ namespace {
     struct Property {
       Property(Ops opcode, Value *v1, Value *v2)
         : Opcode(opcode), V1(v1), V2(v2)
-      { assert(opcode != EQ && "Equality belongs in the synonym set,"
+      { assert(opcode != EQ && "Equality belongs in the synonym set, "
                "not a property."); }
 
-      bool operator<(const Property &rhs) const {
-        if (Opcode != rhs.Opcode) return Opcode < rhs.Opcode;
-        if (V1 != rhs.V1) return V1 < rhs.V1;
-        return V2 < rhs.V2;
-      }
-
       Ops Opcode;
       Value *V1, *V2;
     };
@@ -208,7 +207,7 @@ namespace {
       }
     }
 
-    // Finds the properties implied by a synonym and adds them too.
+    // Finds the properties implied by a equivalence and adds them too.
     // Example: ("seteq %a, %b", true,  EQ) --> (%a, %b, EQ)
     //          ("seteq %a, %b", false, EQ) --> (%a, %b, NE)
     void addImpliedProperties(Ops Opcode, Value *V1, Value *V2) {
@@ -267,13 +266,25 @@ namespace {
         default:
           break;
         }
+      } else if (SelectInst *SI = dyn_cast<SelectInst>(V2)) {
+        if (Opcode != EQ && Opcode != NE) return;
+
+        ConstantBool *True  = (Opcode==EQ) ? ConstantBool::True
+                                           : ConstantBool::False,
+                     *False = (Opcode==EQ) ? ConstantBool::False
+                                           : ConstantBool::True;
+
+        if (V1 == SI->getTrueValue())
+          addEqual(SI->getCondition(), True);
+        else if (V1 == SI->getFalseValue())
+          addEqual(SI->getCondition(), False);
+        else if (Opcode == EQ)
+          assert("Result of select not equal to either value.");
       }
     }
 
-    std::map<Value *, unsigned> SynonymMap;
-    std::vector<Value *> Synonyms;
-
   public:
+#ifdef DEBUG
     void debug(std::ostream &os) const {
       for (EquivalenceClasses<Value*>::iterator I = union_find.begin(),
            E = union_find.end(); I != E; ++I) {
@@ -284,6 +295,7 @@ namespace {
         std::cerr << "\n--\n";
       }
     }
+#endif
 
     std::vector<Property> Properties;
   };
@@ -351,13 +363,13 @@ void PredicateSimplifier::getAnalysisUsage(AnalysisUsage &AU) const {
 
 // resolve catches cases addProperty won't because it wasn't used as a
 // condition in the branch, and that visit won't, because the instruction
-// was defined outside of the range that the properties apply to.
+// was defined outside of the scope that the properties apply to.
 Value *PredicateSimplifier::resolve(SetCondInst *SCI,
                                     const PropertySet &KP) {
   // Attempt to resolve the SetCondInst to a boolean.
 
-  Value *SCI0 = SCI->getOperand(0),
-        *SCI1 = SCI->getOperand(1);
+  Value *SCI0 = resolve(SCI->getOperand(0), KP),
+        *SCI1 = resolve(SCI->getOperand(1), KP);
   PropertySet::ConstPropertyIterator NE =
                    KP.findProperty(PropertySet::NE, SCI0, SCI1);
 
@@ -378,9 +390,6 @@ Value *PredicateSimplifier::resolve(SetCondInst *SCI,
     }
   }
 
-  SCI0 = KP.canonicalize(SCI0);
-  SCI1 = KP.canonicalize(SCI1);
-
   ConstantIntegral *CI1 = dyn_cast<ConstantIntegral>(SCI0),
                    *CI2 = dyn_cast<ConstantIntegral>(SCI1);
 
@@ -445,6 +454,8 @@ Value *PredicateSimplifier::resolve(Value *V, const PropertySet &KP) {
 
   V = KP.canonicalize(V);
 
+  DEBUG(std::cerr << "peering into " << *V << "\n");
+
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V))
     return resolve(BO, KP);
   else if (SelectInst *SI = dyn_cast<SelectInst>(V))
@@ -466,8 +477,19 @@ void PredicateSimplifier::visit(Instruction *I, DominatorTree::Node *DTNode,
   DEBUG(std::cerr << "Considering instruction " << *I << "\n");
   DEBUG(KnownProperties.debug(std::cerr));
 
-  // Substitute values known to be equal.
-  for (unsigned i = 0, E = I->getNumOperands(); i != E; ++i) {
+  // Try to replace whole instruction.
+  Value *V = resolve(I, KnownProperties);
+  assert(V && "resolve not supposed to return NULL.");
+  if (V != I) {
+    modified = true;
+    ++NumInstruction;
+    I->replaceAllUsesWith(V);
+    I->eraseFromParent();
+    return;
+  }
+
+  // Try to substitute operands.
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
     Value *Oper = I->getOperand(i);
     Value *V = resolve(Oper, KnownProperties);
     assert(V && "resolve not supposed to return NULL.");
@@ -480,15 +502,6 @@ void PredicateSimplifier::visit(Instruction *I, DominatorTree::Node *DTNode,
     }
   }
 
-  Value *V = resolve(I, KnownProperties);
-  assert(V && "resolve not supposed to return NULL.");
-  if (V != I) {
-    modified = true;
-    ++NumInstruction;
-    I->replaceAllUsesWith(V);
-    I->eraseFromParent();
-  }
-
   if (TerminatorInst *TI = dyn_cast<TerminatorInst>(I))
     visit(TI, DTNode, KnownProperties);
   else if (LoadInst *LI = dyn_cast<LoadInst>(I))
diff --git a/test/Transforms/PredicateSimplifier/2006-08-02-Switch.ll b/test/Transforms/PredicateSimplifier/2006-08-02-Switch.ll
new file mode 100644 (file)
index 0000000..97f2a5a
--- /dev/null
@@ -0,0 +1,22 @@
+; RUN: llvm-as < %s | opt -predsimplify -disable-output
+
+fastcc void %_ov_splice(int %n1, int %n2, int %ch2) {
+entry:
+       %tmp = setgt int %n1, %n2               ; <bool> [#uses=1]
+       %n.0 = select bool %tmp, int %n2, int %n1               ; <int> [#uses=1]
+       %tmp104 = setlt int 0, %ch2             ; <bool> [#uses=1]
+       br bool %tmp104, label %cond_true105, label %return
+
+cond_true95:           ; preds = %cond_true105
+       ret void
+
+bb98:          ; preds = %cond_true105
+       ret void
+
+cond_true105:          ; preds = %entry
+       %tmp94 = setgt int %n.0, 0              ; <bool> [#uses=1]
+       br bool %tmp94, label %cond_true95, label %bb98
+
+return:                ; preds = %entry
+       ret void
+}
index 056d8c97dc05704cd85c253a1039f31991f31ae6..89d5d4ce4982b4c00b00367826f7512e3c5763d8 100644 (file)
@@ -1,4 +1,5 @@
-; RUN: llvm-as < %s | opt -predsimplify -instcombine -simplifycfg | llvm-dis | grep -v declare | not grep fail
+; RUN: llvm-as < %s | opt -predsimplify -instcombine -simplifycfg | llvm-dis | grep -v declare | not grep fail &&
+; RUN: llvm-as < %s | opt -predsimplify -instcombine -simplifycfg | llvm-dis | grep -v declare | grep pass | wc -l | grep 3
 
 void %test1(int %x) {
 entry:
@@ -124,6 +125,167 @@ else.2:
   ret void
 }
 
+void %test9(int %y, int %z) {
+entry:
+  %x = add int %y, %z
+  %A = seteq int %y, 3
+  %B = seteq int %z, 5
+  %C = and bool %A, %B
+  br bool %C, label %cond_true, label %return
+
+cond_true:
+  %D = seteq int %x, 8
+  br bool %D, label %then, label %oops
+
+then:
+  call void (...)* %pass( )
+  ret void
+
+oops:
+  call void (...)* %fail( )
+  ret void
+
+return:
+  ret void
+}
+
+void %switch1(int %x) {
+entry:
+  %A = seteq int %x, 10
+  br bool %A, label %return, label %cond_false
+
+cond_false:
+  switch int %x, label %return [
+    int 9, label %then1
+    int 10, label %then2
+  ]
+
+then1:
+  call void (...)* %pass( )
+  ret void
+
+then2:
+  call void (...)* %fail( )
+  ret void
+
+return:
+  ret void
+}
+
+void %switch2(int %x) {
+entry:
+  %A = seteq int %x, 10
+  br bool %A, label %return, label %cond_false
+
+cond_false:
+  switch int %x, label %return [
+    int 8, label %then1
+    int 9, label %then1
+    int 10, label %then1
+  ]
+
+then1:
+  %B = setne int %x, 8
+  br bool %B, label %then2, label %return
+
+then2:
+  call void (...)* %pass( )
+  ret void
+
+return:
+  ret void
+}
+
+void %switch3(int %x) {
+entry:
+  %A = seteq int %x, 10
+  br bool %A, label %return, label %cond_false
+
+cond_false:
+  switch int %x, label %return [
+    int 9, label %then1
+    int 10, label %then1
+  ]
+
+then1:
+  %B = seteq int %x, 9
+  br bool %B, label %return, label %oops
+
+oops:
+  call void (...)* %fail( )
+  ret void
+
+return:
+  ret void
+}
+
+void %switch4(int %x) {
+entry:
+  %A = seteq int %x, 10
+  br bool %A, label %then1, label %cond_false
+
+cond_false:
+  switch int %x, label %default [
+    int 9, label %then1
+    int 10, label %then2
+  ]
+
+then1:
+  ret void
+
+then2:
+  ret void
+
+default:
+  %B = seteq int %x, 9
+  br bool %B, label %oops, label %then1
+
+oops:
+  call void (...)* %fail( )
+  ret void
+}
+
+void %select1(int %x) {
+entry:
+  %A = seteq int %x, 10
+  %B = select bool %A, int 1, int 2
+  %C = seteq int %B, 1
+  br bool %C, label %then, label %else
+
+then:
+  br bool %A, label %return, label %oops
+
+else:
+  br bool %A, label %oops, label %return
+
+oops:
+  call void (...)* %fail( )
+  ret void
+
+return:
+  ret void
+}
+
+void %select2(int %x) {
+entry:
+  %A = seteq int %x, 10
+  %B = select bool %A, int 1, int 2
+  %C = seteq int %B, 1
+  br bool %A, label %then, label %else
+
+then:
+  br bool %C, label %return, label %oops
+
+else:
+  br bool %C, label %oops, label %return
+
+oops:
+  call void (...)* %fail( )
+  ret void
+
+return:
+  ret void
+}
 
 declare void %fail(...)