InstCombine: Squash an icmp+select into bitwise arithmetic
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index 83a2f2c563f33249a7fd979853379a339a22c2ba..df55e86f0590ee7243a46089c67284fccd20ebf6 100644 (file)
@@ -617,26 +617,44 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  {
+  if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) {
+    APInt MinSignedValue = APInt::getSignBit(BitWidth);
     Value *X;
     const APInt *Y, *C;
-    if (match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) &&
+    bool TrueWhenUnset;
+    bool IsBitTest = false;
+    if (ICmpInst::isEquality(Pred) &&
+        match(CmpLHS, m_And(m_Value(X), m_Power2(Y))) &&
         match(CmpRHS, m_Zero())) {
+      IsBitTest = true;
+      TrueWhenUnset = Pred == ICmpInst::ICMP_EQ;
+    } else if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, m_Zero())) {
+      X = CmpLHS;
+      Y = &MinSignedValue;
+      IsBitTest = true;
+      TrueWhenUnset = false;
+    } else if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, m_AllOnes())) {
+      X = CmpLHS;
+      Y = &MinSignedValue;
+      IsBitTest = true;
+      TrueWhenUnset = true;
+    }
+    if (IsBitTest) {
       Value *V = nullptr;
       // (X & Y) == 0 ? X : X ^ Y  --> X & ~Y
-      if (Pred == ICmpInst::ICMP_EQ && TrueVal == X &&
+      if (TrueWhenUnset && TrueVal == X &&
           match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
         V = Builder->CreateAnd(X, ~(*Y));
       // (X & Y) != 0 ? X ^ Y : X  --> X & ~Y
-      else if (Pred == ICmpInst::ICMP_NE && FalseVal == X &&
+      else if (!TrueWhenUnset && FalseVal == X &&
                match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
         V = Builder->CreateAnd(X, ~(*Y));
       // (X & Y) == 0 ? X ^ Y : X  --> X | Y
-      else if (Pred == ICmpInst::ICMP_EQ && FalseVal == X &&
+      else if (TrueWhenUnset && FalseVal == X &&
                match(TrueVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
         V = Builder->CreateOr(X, *Y);
       // (X & Y) != 0 ? X : X ^ Y  --> X | Y
-      else if (Pred == ICmpInst::ICMP_NE && TrueVal == X &&
+      else if (!TrueWhenUnset && TrueVal == X &&
                match(FalseVal, m_Xor(m_Specific(X), m_APInt(C))) && *Y == *C)
         V = Builder->CreateOr(X, *Y);