[InstCombine] fold bitcasts around an extractelement
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAndOrXor.cpp
index 5459c8955a1ec7a56a5bebb40939b4a27c621e4d..2bf6faa47b93d108870e793d97d4a3c417a8b58a 100644 (file)
@@ -1208,6 +1208,11 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   auto Opcode = I.getOpcode();
   assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
          "Trying to match De Morgan's Laws with something other than and/or");
+  // Flip the logic operation.
+  if (Opcode == Instruction::And)
+    Opcode = Instruction::Or;
+  else
+    Opcode = Instruction::And;
 
   Value *Op0 = I.getOperand(0);
   Value *Op1 = I.getOperand(1);
@@ -1215,16 +1220,31 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   if (Value *Op0NotVal = dyn_castNotVal(Op0))
     if (Value *Op1NotVal = dyn_castNotVal(Op1))
       if (Op0->hasOneUse() && Op1->hasOneUse()) {
-        // Flip the logic operation.
-        if (Opcode == Instruction::And)
-          Opcode = Instruction::Or;
-        else
-          Opcode = Instruction::And;
         Value *LogicOp = Builder->CreateBinOp(Opcode, Op0NotVal, Op1NotVal,
                                               I.getName() + ".demorgan");
         return BinaryOperator::CreateNot(LogicOp);
       }
 
+  // De Morgan's Law in disguise:
+  // (zext(bool A) ^ 1) & (zext(bool B) ^ 1) -> zext(~(A | B))
+  // (zext(bool A) ^ 1) | (zext(bool B) ^ 1) -> zext(~(A & B))
+  Value *A = nullptr;
+  Value *B = nullptr;
+  ConstantInt *C1 = nullptr;
+  if (match(Op0, m_OneUse(m_Xor(m_ZExt(m_Value(A)), m_ConstantInt(C1)))) &&
+      match(Op1, m_OneUse(m_Xor(m_ZExt(m_Value(B)), m_Specific(C1))))) {
+    // TODO: This check could be loosened to handle different type sizes.
+    // Alternatively, we could fix the definition of m_Not to recognize a not
+    // operation hidden by a zext?
+    if (A->getType()->isIntegerTy(1) && B->getType()->isIntegerTy(1) &&
+        C1->isOne()) {
+      Value *LogicOp = Builder->CreateBinOp(Opcode, A, B,
+                                            I.getName() + ".demorgan");
+      Value *Not = Builder->CreateNot(LogicOp);
+      return CastInst::CreateZExtOrBitCast(Not, I.getType());
+    }
+  }
+
   return nullptr;
 }
 
@@ -1468,14 +1488,15 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
         return ReplaceInstUsesWith(I, Res);
 
 
-  // fold (and (cast A), (cast B)) -> (cast (and A, B))
-  if (CastInst *Op0C = dyn_cast<CastInst>(Op0))
+  if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
+    Value *Op0COp = Op0C->getOperand(0);
+    Type *SrcTy = Op0COp->getType();
+    // fold (and (cast A), (cast B)) -> (cast (and A, B))
     if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) {
-      Type *SrcTy = Op0C->getOperand(0)->getType();
       if (Op0C->getOpcode() == Op1C->getOpcode() && // same cast kind ?
           SrcTy == Op1C->getOperand(0)->getType() &&
           SrcTy->isIntOrIntVectorTy()) {
-        Value *Op0COp = Op0C->getOperand(0), *Op1COp = Op1C->getOperand(0);
+        Value *Op1COp = Op1C->getOperand(0);
 
         // Only do this if the casts both really cause code to be generated.
         if (ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) &&
@@ -1500,6 +1521,20 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
       }
     }
 
+    // If we are masking off the sign bit of a floating-point value, convert
+    // this to the canonical fabs intrinsic call and cast back to integer.
+    // The backend should know how to optimize fabs().
+    // TODO: This transform should also apply to vectors.
+    ConstantInt *CI;
+    if (isa<BitCastInst>(Op0C) && SrcTy->isFloatingPointTy() &&
+        match(Op1, m_ConstantInt(CI)) && CI->isMaxValue(true)) {
+      Module *M = I.getParent()->getParent()->getParent();
+      Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, SrcTy);
+      Value *Call = Builder->CreateCall(Fabs, Op0COp, "fabs");
+      return CastInst::CreateBitOrPointerCast(Call, I.getType());
+    }
+  }
+
   {
     Value *X = nullptr;
     bool OpsSwapped = false;
@@ -1927,14 +1962,14 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     case ICmpInst::ICMP_EQ:
       if (LHS->getOperand(0) == RHS->getOperand(0)) {
         // if LHSCst and RHSCst differ only by one bit:
-        // (A == C1 || A == C2) -> (A & ~(C1 ^ C2)) == C1
+        // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2
         assert(LHSCst->getValue().ule(LHSCst->getValue()));
 
         APInt Xor = LHSCst->getValue() ^ RHSCst->getValue();
         if (Xor.isPowerOf2()) {
-          Value *NegCst = Builder->getInt(~Xor);
-          Value *And = Builder->CreateAnd(LHS->getOperand(0), NegCst);
-          return Builder->CreateICmp(ICmpInst::ICMP_EQ, And, LHSCst);
+          Value *Cst = Builder->getInt(Xor);
+          Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst);
+          return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst);
         }
       }