getParent() ^ 3 == getModule() ; NFCI
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAndOrXor.cpp
index 0a603c030d951525cb0c64882da4cdda59d45026..95c50d32c8207b12f5ec0abfd78301ce0e4da6a4 100644 (file)
@@ -150,8 +150,7 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) {
   else //if (Op == Instruction::Xor)
     BinOp = Builder->CreateXor(NewLHS, NewRHS);
 
-  Module *M = I.getParent()->getParent()->getParent();
-  Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy);
+  Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, ITy);
   return Builder->CreateCall(F, BinOp);
 }
 
@@ -1208,6 +1207,11 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   auto Opcode = I.getOpcode();
   assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
          "Trying to match De Morgan's Laws with something other than and/or");
+  // Flip the logic operation.
+  if (Opcode == Instruction::And)
+    Opcode = Instruction::Or;
+  else
+    Opcode = Instruction::And;
 
   Value *Op0 = I.getOperand(0);
   Value *Op1 = I.getOperand(1);
@@ -1215,16 +1219,31 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   if (Value *Op0NotVal = dyn_castNotVal(Op0))
     if (Value *Op1NotVal = dyn_castNotVal(Op1))
       if (Op0->hasOneUse() && Op1->hasOneUse()) {
-        // Flip the logic operation.
-        if (Opcode == Instruction::And)
-          Opcode = Instruction::Or;
-        else
-          Opcode = Instruction::And;
         Value *LogicOp = Builder->CreateBinOp(Opcode, Op0NotVal, Op1NotVal,
                                               I.getName() + ".demorgan");
         return BinaryOperator::CreateNot(LogicOp);
       }
 
+  // De Morgan's Law in disguise:
+  // (zext(bool A) ^ 1) & (zext(bool B) ^ 1) -> zext(~(A | B))
+  // (zext(bool A) ^ 1) | (zext(bool B) ^ 1) -> zext(~(A & B))
+  Value *A = nullptr;
+  Value *B = nullptr;
+  ConstantInt *C1 = nullptr;
+  if (match(Op0, m_OneUse(m_Xor(m_ZExt(m_Value(A)), m_ConstantInt(C1)))) &&
+      match(Op1, m_OneUse(m_Xor(m_ZExt(m_Value(B)), m_Specific(C1))))) {
+    // TODO: This check could be loosened to handle different type sizes.
+    // Alternatively, we could fix the definition of m_Not to recognize a not
+    // operation hidden by a zext?
+    if (A->getType()->isIntegerTy(1) && B->getType()->isIntegerTy(1) &&
+        C1->isOne()) {
+      Value *LogicOp = Builder->CreateBinOp(Opcode, A, B,
+                                            I.getName() + ".demorgan");
+      Value *Not = Builder->CreateNot(LogicOp);
+      return CastInst::CreateZExtOrBitCast(Not, I.getType());
+    }
+  }
+
   return nullptr;
 }
 
@@ -1468,14 +1487,15 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
         return ReplaceInstUsesWith(I, Res);
 
 
-  // fold (and (cast A), (cast B)) -> (cast (and A, B))
-  if (CastInst *Op0C = dyn_cast<CastInst>(Op0))
+  if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
+    Value *Op0COp = Op0C->getOperand(0);
+    Type *SrcTy = Op0COp->getType();
+    // fold (and (cast A), (cast B)) -> (cast (and A, B))
     if (CastInst *Op1C = dyn_cast<CastInst>(Op1)) {
-      Type *SrcTy = Op0C->getOperand(0)->getType();
       if (Op0C->getOpcode() == Op1C->getOpcode() && // same cast kind ?
           SrcTy == Op1C->getOperand(0)->getType() &&
           SrcTy->isIntOrIntVectorTy()) {
-        Value *Op0COp = Op0C->getOperand(0), *Op1COp = Op1C->getOperand(0);
+        Value *Op1COp = Op1C->getOperand(0);
 
         // Only do this if the casts both really cause code to be generated.
         if (ShouldOptimizeCast(Op0C->getOpcode(), Op0COp, I.getType()) &&
@@ -1500,6 +1520,20 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
       }
     }
 
+    // If we are masking off the sign bit of a floating-point value, convert
+    // this to the canonical fabs intrinsic call and cast back to integer.
+    // The backend should know how to optimize fabs().
+    // TODO: This transform should also apply to vectors.
+    ConstantInt *CI;
+    if (isa<BitCastInst>(Op0C) && SrcTy->isFloatingPointTy() &&
+        match(Op1, m_ConstantInt(CI)) && CI->isMaxValue(true)) {
+      Module *M = I.getModule();
+      Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, SrcTy);
+      Value *Call = Builder->CreateCall(Fabs, Op0COp, "fabs");
+      return CastInst::CreateBitOrPointerCast(Call, I.getType());
+    }
+  }
+
   {
     Value *X = nullptr;
     bool OpsSwapped = false;
@@ -1531,157 +1565,189 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
   return Changed ? &I : nullptr;
 }
 
+
 /// Analyze the specified subexpression and see if it is capable of providing
-/// pieces of a bswap.  The subexpression provides pieces of a bswap if it is
-/// proven that each of the non-zero bytes in the output of the expression came
-/// from the corresponding "byte swapped" byte in some other value.
-/// For example, if the current subexpression is "(shl i32 %X, 24)" then
-/// we know that the expression deposits the low byte of %X into the high byte
-/// of the bswap result and that all other bytes are zero.  This expression is
-/// accepted, the high byte of ByteValues is set to X to indicate a correct
-/// match.
+/// pieces of a bswap or bitreverse. The subexpression provides a potential
+/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
+/// the output of the expression came from a corresponding bit in some other
+/// value. This function is recursive, and the end result is a mapping of
+/// (value, bitnumber) to bitnumber. It is the caller's responsibility to
+/// validate that all `value`s are identical and that the bitnumber to bitnumber
+/// mapping is correct for a bswap or bitreverse.
+///
+/// For example, if the current subexpression if "(shl i32 %X, 24)" then we know
+/// that the expression deposits the low byte of %X into the high byte of the
+/// result and that all other bits are zero. This expression is accepted,
+/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7].
 ///
 /// This function returns true if the match was unsuccessful and false if so.
 /// On entry to the function the "OverallLeftShift" is a signed integer value
-/// indicating the number of bytes that the subexpression is later shifted.  For
+/// indicating the number of bits that the subexpression is later shifted.  For
 /// example, if the expression is later right shifted by 16 bits, the
-/// OverallLeftShift value would be -2 on entry.  This is used to specify which
-/// byte of ByteValues is actually being set.
+/// OverallLeftShift value would be -16 on entry.  This is used to specify which
+/// bits of BitValues are actually being set.
 ///
-/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding
-/// byte is masked to zero by a user.  For example, in (X & 255), X will be
-/// processed with a bytemask of 1.  Because bytemask is 32-bits, this limits
-/// this function to working on up to 32-byte (256 bit) values.  ByteMask is
-/// always in the local (OverallLeftShift) coordinate space.
+/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding
+/// bit is masked to zero by a user.  For example, in (X & 255), X will be
+/// processed with a bytemask of 255. BitMask is always in the local
+/// (OverallLeftShift) coordinate space.
 ///
-static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask,
-                              SmallVectorImpl<Value *> &ByteValues) {
+static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask,
+                            SmallVectorImpl<Value *> &BitValues,
+                            SmallVectorImpl<int> &BitProvenance) {
   if (Instruction *I = dyn_cast<Instruction>(V)) {
     // If this is an or instruction, it may be an inner node of the bswap.
-    if (I->getOpcode() == Instruction::Or) {
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues) ||
-             CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask,
-                               ByteValues);
-    }
-
-    // If this is a logical shift by a constant multiple of 8, recurse with
-    // OverallLeftShift and ByteMask adjusted.
+    if (I->getOpcode() == Instruction::Or)
+      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
+                             BitValues, BitProvenance) ||
+             CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask,
+                             BitValues, BitProvenance);
+
+    // If this is a logical shift by a constant, recurse with OverallLeftShift
+    // and BitMask adjusted.
     if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
       unsigned ShAmt =
-        cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
-      // Ensure the shift amount is defined and of a byte value.
-      if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size()))
+          cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
+      // Ensure the shift amount is defined.
+      if (ShAmt > BitValues.size())
         return true;
 
-      unsigned ByteShift = ShAmt >> 3;
+      unsigned BitShift = ShAmt;
       if (I->getOpcode() == Instruction::Shl) {
-        // X << 2 -> collect(X, +2)
-        OverallLeftShift += ByteShift;
-        ByteMask >>= ByteShift;
+        // X << C -> collect(X, +C)
+        OverallLeftShift += BitShift;
+        BitMask = BitMask.lshr(BitShift);
       } else {
-        // X >>u 2 -> collect(X, -2)
-        OverallLeftShift -= ByteShift;
-        ByteMask <<= ByteShift;
-        ByteMask &= (~0U >> (32-ByteValues.size()));
+        // X >>u C -> collect(X, -C)
+        OverallLeftShift -= BitShift;
+        BitMask = BitMask.shl(BitShift);
       }
 
-      if (OverallLeftShift >= (int)ByteValues.size()) return true;
-      if (OverallLeftShift <= -(int)ByteValues.size()) return true;
+      if (OverallLeftShift >= (int)BitValues.size())
+        return true;
+      if (OverallLeftShift <= -(int)BitValues.size())
+        return true;
 
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues);
+      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
+                             BitValues, BitProvenance);
     }
 
-    // If this is a logical 'and' with a mask that clears bytes, clear the
-    // corresponding bytes in ByteMask.
+    // If this is a logical 'and' with a mask that clears bits, clear the
+    // corresponding bits in BitMask.
     if (I->getOpcode() == Instruction::And &&
         isa<ConstantInt>(I->getOperand(1))) {
-      // Scan every byte of the and mask, seeing if the byte is either 0 or 255.
-      unsigned NumBytes = ByteValues.size();
-      APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255);
+      unsigned NumBits = BitValues.size();
+      APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
       const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
 
-      for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) {
-        // If this byte is masked out by a later operation, we don't care what
+      for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) {
+        // If this bit is masked out by a later operation, we don't care what
         // the and mask is.
-        if ((ByteMask & (1 << i)) == 0)
+        if (BitMask[i] == 0)
           continue;
 
-        // If the AndMask is all zeros for this byte, clear the bit.
-        APInt MaskB = AndMask & Byte;
+        // If the AndMask is zero for this bit, clear the bit.
+        APInt MaskB = AndMask & Bit;
         if (MaskB == 0) {
-          ByteMask &= ~(1U << i);
+          BitMask.clearBit(i);
           continue;
         }
 
-        // If the AndMask is not all ones for this byte, it's not a bytezap.
-        if (MaskB != Byte)
-          return true;
-
-        // Otherwise, this byte is kept.
+        // Otherwise, this bit is kept.
       }
 
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues);
+      return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
+                             BitValues, BitProvenance);
     }
   }
 
   // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
-  // the input value to the bswap.  Some observations: 1) if more than one byte
-  // is demanded from this input, then it could not be successfully assembled
-  // into a byteswap.  At least one of the two bytes would not be aligned with
-  // their ultimate destination.
-  if (!isPowerOf2_32(ByteMask)) return true;
-  unsigned InputByteNo = countTrailingZeros(ByteMask);
-
-  // 2) The input and ultimate destinations must line up: if byte 3 of an i32
-  // is demanded, it needs to go into byte 0 of the result.  This means that the
-  // byte needs to be shifted until it lands in the right byte bucket.  The
-  // shift amount depends on the position: if the byte is coming from the high
-  // part of the value (e.g. byte 3) then it must be shifted right.  If from the
-  // low part, it must be shifted left.
-  unsigned DestByteNo = InputByteNo + OverallLeftShift;
-  if (ByteValues.size()-1-DestByteNo != InputByteNo)
+  // the input value to the bswap/bitreverse. To be part of a bswap or
+  // bitreverse we must be demanding a contiguous range of bits from it.
+  unsigned InputBitLen = BitMask.countPopulation();
+  unsigned InputBitNo = BitMask.countTrailingZeros();
+  if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo !=
+      InputBitLen)
+    // Not a contiguous set range of bits!
     return true;
 
-  // If the destination byte value is already defined, the values are or'd
-  // together, which isn't a bswap (unless it's an or of the same bits).
-  if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V)
+  // We know we're moving a contiguous range of bits from the input to the
+  // output. Record which bits in the output came from which bits in the input.
+  unsigned DestBitNo = InputBitNo + OverallLeftShift;
+  for (unsigned I = 0; I < InputBitLen; ++I)
+    BitProvenance[DestBitNo + I] = InputBitNo + I;
+
+  // If the destination bit value is already defined, the values are or'd
+  // together, which isn't a bswap/bitreverse (unless it's an or of the same
+  // bits).
+  if (BitValues[DestBitNo] && BitValues[DestBitNo] != V)
     return true;
-  ByteValues[DestByteNo] = V;
+  for (unsigned I = 0; I < InputBitLen; ++I)
+    BitValues[DestBitNo + I] = V;
+
   return false;
 }
 
-/// Given an OR instruction, check to see if this is a bswap idiom.
-/// If so, insert the new bswap intrinsic and return it.
-Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) {
-  IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
-  if (!ITy || ITy->getBitWidth() % 16 ||
-      // ByteMask only allows up to 32-byte values.
-      ITy->getBitWidth() > 32*8)
-    return nullptr;   // Can only bswap pairs of bytes.  Can't do vectors.
+static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
+                                          unsigned BitWidth) {
+  if (From % 8 != To % 8)
+    return false;
+  // Convert from bit indices to byte indices and check for a byte reversal.
+  From >>= 3;
+  To >>= 3;
+  BitWidth >>= 3;
+  return From == BitWidth - To - 1;
+}
 
-  /// ByteValues - For each byte of the result, we keep track of which value
-  /// defines each byte.
-  SmallVector<Value*, 8> ByteValues;
-  ByteValues.resize(ITy->getBitWidth()/8);
+static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
+                                               unsigned BitWidth) {
+  return From == BitWidth - To - 1;
+}
 
+/// Given an OR instruction, check to see if this is a bswap or bitreverse
+/// idiom. If so, insert the new intrinsic and return it.
+Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) {
+  IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
+  if (!ITy)
+    return nullptr;   // Can't do vectors.
+  unsigned BW = ITy->getBitWidth();
+  
+  /// We keep track of which bit (BitProvenance) inside which value (BitValues)
+  /// defines each bit in the result.
+  SmallVector<Value *, 8> BitValues(BW, nullptr);
+  SmallVector<int, 8> BitProvenance(BW, -1);
+  
   // Try to find all the pieces corresponding to the bswap.
-  uint32_t ByteMask = ~0U >> (32-ByteValues.size());
-  if (CollectBSwapParts(&I, 0, ByteMask, ByteValues))
+  APInt BitMask = APInt::getAllOnesValue(BitValues.size());
+  if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance))
     return nullptr;
 
-  // Check to see if all of the bytes come from the same value.
-  Value *V = ByteValues[0];
-  if (!V) return nullptr;  // Didn't find a byte?  Must be zero.
+  // Check to see if all of the bits come from the same value.
+  Value *V = BitValues[0];
+  if (!V) return nullptr;  // Didn't find a bit?  Must be zero.
 
-  // Check to make sure that all of the bytes come from the same value.
-  for (unsigned i = 1, e = ByteValues.size(); i != e; ++i)
-    if (ByteValues[i] != V)
-      return nullptr;
-  Module *M = I.getParent()->getParent()->getParent();
-  Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy);
+  if (!std::all_of(BitValues.begin(), BitValues.end(),
+                   [&](const Value *X) { return X == V; }))
+    return nullptr;
+
+  // Now, is the bit permutation correct for a bswap or a bitreverse? We can
+  // only byteswap values with an even number of bytes.
+  bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;;
+  for (unsigned i = 0, e = BitValues.size(); i != e; ++i) {
+    OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
+    OKForBitReverse &=
+        bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
+  }
+
+  Intrinsic::ID Intrin;
+  if (OKForBSwap)
+    Intrin = Intrinsic::bswap;
+  else if (OKForBitReverse)
+    Intrin = Intrinsic::bitreverse;
+  else
+    return nullptr;
+
+  Function *F = Intrinsic::getDeclaration(I.getModule(), Intrin, ITy);
   return CallInst::Create(F, V);
 }
 
@@ -1927,14 +1993,14 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     case ICmpInst::ICMP_EQ:
       if (LHS->getOperand(0) == RHS->getOperand(0)) {
         // if LHSCst and RHSCst differ only by one bit:
-        // (A == C1 || A == C2) -> (A & ~(C1 ^ C2)) == C1
+        // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2
         assert(LHSCst->getValue().ule(LHSCst->getValue()));
 
         APInt Xor = LHSCst->getValue() ^ RHSCst->getValue();
         if (Xor.isPowerOf2()) {
-          Value *NegCst = Builder->getInt(~Xor);
-          Value *And = Builder->CreateAnd(LHS->getOperand(0), NegCst);
-          return Builder->CreateICmp(ICmpInst::ICMP_EQ, And, LHSCst);
+          Value *Cst = Builder->getInt(Xor);
+          Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst);
+          return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst);
         }
       }
 
@@ -2220,14 +2286,18 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
   ConstantInt *C1 = nullptr, *C2 = nullptr;
 
   // (A | B) | C  and  A | (B | C)                  -> bswap if possible.
+  bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) ||
+                 match(Op1, m_Or(m_Value(), m_Value()));
   // (A >> B) | (C << D)  and  (A << B) | (B >> C)  -> bswap if possible.
-  if (match(Op0, m_Or(m_Value(), m_Value())) ||
-      match(Op1, m_Or(m_Value(), m_Value())) ||
-      (match(Op0, m_LogicalShift(m_Value(), m_Value())) &&
-       match(Op1, m_LogicalShift(m_Value(), m_Value())))) {
-    if (Instruction *BSwap = MatchBSwap(I))
+  bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) &&
+                    match(Op1, m_LogicalShift(m_Value(), m_Value()));
+  // (A & B) | (C & D)                              -> bswap if possible.
+  bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) &&
+                  match(Op1, m_And(m_Value(), m_Value()));
+
+  if (OrOfOrs || OrOfShifts || OrOfAnds)
+    if (Instruction *BSwap = MatchBSwapOrBitReverse(I))
       return BSwap;
-  }
 
   // (X^C)|Y -> (X|Y)^C iff Y&C == 0
   if (Op0->hasOneUse() &&