Merging r257875:
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineAndOrXor.cpp
index 2bf6faa47b93d108870e793d97d4a3c417a8b58a..76cefd97cd8f1a53479aa94ee9ab3e50185c3a18 100644 (file)
@@ -17,6 +17,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/CmpInstAnalysis.h"
+#include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -150,8 +151,7 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) {
   else //if (Op == Instruction::Xor)
     BinOp = Builder->CreateXor(NewLHS, NewRHS);
 
-  Module *M = I.getParent()->getParent()->getParent();
-  Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy);
+  Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, ITy);
   return Builder->CreateCall(F, BinOp);
 }
 
@@ -1528,7 +1528,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
     ConstantInt *CI;
     if (isa<BitCastInst>(Op0C) && SrcTy->isFloatingPointTy() &&
         match(Op1, m_ConstantInt(CI)) && CI->isMaxValue(true)) {
-      Module *M = I.getParent()->getParent()->getParent();
+      Module *M = I.getModule();
       Function *Fabs = Intrinsic::getDeclaration(M, Intrinsic::fabs, SrcTy);
       Value *Call = Builder->CreateCall(Fabs, Op0COp, "fabs");
       return CastInst::CreateBitOrPointerCast(Call, I.getType());
@@ -1566,158 +1566,18 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
   return Changed ? &I : nullptr;
 }
 
-/// Analyze the specified subexpression and see if it is capable of providing
-/// pieces of a bswap.  The subexpression provides pieces of a bswap if it is
-/// proven that each of the non-zero bytes in the output of the expression came
-/// from the corresponding "byte swapped" byte in some other value.
-/// For example, if the current subexpression is "(shl i32 %X, 24)" then
-/// we know that the expression deposits the low byte of %X into the high byte
-/// of the bswap result and that all other bytes are zero.  This expression is
-/// accepted, the high byte of ByteValues is set to X to indicate a correct
-/// match.
-///
-/// This function returns true if the match was unsuccessful and false if so.
-/// On entry to the function the "OverallLeftShift" is a signed integer value
-/// indicating the number of bytes that the subexpression is later shifted.  For
-/// example, if the expression is later right shifted by 16 bits, the
-/// OverallLeftShift value would be -2 on entry.  This is used to specify which
-/// byte of ByteValues is actually being set.
-///
-/// Similarly, ByteMask is a bitmask where a bit is clear if its corresponding
-/// byte is masked to zero by a user.  For example, in (X & 255), X will be
-/// processed with a bytemask of 1.  Because bytemask is 32-bits, this limits
-/// this function to working on up to 32-byte (256 bit) values.  ByteMask is
-/// always in the local (OverallLeftShift) coordinate space.
-///
-static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask,
-                              SmallVectorImpl<Value *> &ByteValues) {
-  if (Instruction *I = dyn_cast<Instruction>(V)) {
-    // If this is an or instruction, it may be an inner node of the bswap.
-    if (I->getOpcode() == Instruction::Or) {
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues) ||
-             CollectBSwapParts(I->getOperand(1), OverallLeftShift, ByteMask,
-                               ByteValues);
-    }
-
-    // If this is a logical shift by a constant multiple of 8, recurse with
-    // OverallLeftShift and ByteMask adjusted.
-    if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
-      unsigned ShAmt =
-        cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
-      // Ensure the shift amount is defined and of a byte value.
-      if ((ShAmt & 7) || (ShAmt > 8*ByteValues.size()))
-        return true;
-
-      unsigned ByteShift = ShAmt >> 3;
-      if (I->getOpcode() == Instruction::Shl) {
-        // X << 2 -> collect(X, +2)
-        OverallLeftShift += ByteShift;
-        ByteMask >>= ByteShift;
-      } else {
-        // X >>u 2 -> collect(X, -2)
-        OverallLeftShift -= ByteShift;
-        ByteMask <<= ByteShift;
-        ByteMask &= (~0U >> (32-ByteValues.size()));
-      }
-
-      if (OverallLeftShift >= (int)ByteValues.size()) return true;
-      if (OverallLeftShift <= -(int)ByteValues.size()) return true;
-
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues);
-    }
-
-    // If this is a logical 'and' with a mask that clears bytes, clear the
-    // corresponding bytes in ByteMask.
-    if (I->getOpcode() == Instruction::And &&
-        isa<ConstantInt>(I->getOperand(1))) {
-      // Scan every byte of the and mask, seeing if the byte is either 0 or 255.
-      unsigned NumBytes = ByteValues.size();
-      APInt Byte(I->getType()->getPrimitiveSizeInBits(), 255);
-      const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
-
-      for (unsigned i = 0; i != NumBytes; ++i, Byte <<= 8) {
-        // If this byte is masked out by a later operation, we don't care what
-        // the and mask is.
-        if ((ByteMask & (1 << i)) == 0)
-          continue;
-
-        // If the AndMask is all zeros for this byte, clear the bit.
-        APInt MaskB = AndMask & Byte;
-        if (MaskB == 0) {
-          ByteMask &= ~(1U << i);
-          continue;
-        }
-
-        // If the AndMask is not all ones for this byte, it's not a bytezap.
-        if (MaskB != Byte)
-          return true;
-
-        // Otherwise, this byte is kept.
-      }
-
-      return CollectBSwapParts(I->getOperand(0), OverallLeftShift, ByteMask,
-                               ByteValues);
-    }
-  }
-
-  // Okay, we got to something that isn't a shift, 'or' or 'and'.  This must be
-  // the input value to the bswap.  Some observations: 1) if more than one byte
-  // is demanded from this input, then it could not be successfully assembled
-  // into a byteswap.  At least one of the two bytes would not be aligned with
-  // their ultimate destination.
-  if (!isPowerOf2_32(ByteMask)) return true;
-  unsigned InputByteNo = countTrailingZeros(ByteMask);
-
-  // 2) The input and ultimate destinations must line up: if byte 3 of an i32
-  // is demanded, it needs to go into byte 0 of the result.  This means that the
-  // byte needs to be shifted until it lands in the right byte bucket.  The
-  // shift amount depends on the position: if the byte is coming from the high
-  // part of the value (e.g. byte 3) then it must be shifted right.  If from the
-  // low part, it must be shifted left.
-  unsigned DestByteNo = InputByteNo + OverallLeftShift;
-  if (ByteValues.size()-1-DestByteNo != InputByteNo)
-    return true;
-
-  // If the destination byte value is already defined, the values are or'd
-  // together, which isn't a bswap (unless it's an or of the same bits).
-  if (ByteValues[DestByteNo] && ByteValues[DestByteNo] != V)
-    return true;
-  ByteValues[DestByteNo] = V;
-  return false;
-}
-
-/// Given an OR instruction, check to see if this is a bswap idiom.
-/// If so, insert the new bswap intrinsic and return it.
-Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) {
-  IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
-  if (!ITy || ITy->getBitWidth() % 16 ||
-      // ByteMask only allows up to 32-byte values.
-      ITy->getBitWidth() > 32*8)
-    return nullptr;   // Can only bswap pairs of bytes.  Can't do vectors.
-
-  /// ByteValues - For each byte of the result, we keep track of which value
-  /// defines each byte.
-  SmallVector<Value*, 8> ByteValues;
-  ByteValues.resize(ITy->getBitWidth()/8);
-
-  // Try to find all the pieces corresponding to the bswap.
-  uint32_t ByteMask = ~0U >> (32-ByteValues.size());
-  if (CollectBSwapParts(&I, 0, ByteMask, ByteValues))
+/// Given an OR instruction, check to see if this is a bswap or bitreverse
+/// idiom. If so, insert the new intrinsic and return it.
+Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) {
+  SmallVector<Instruction*, 4> Insts;
+  if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts))
     return nullptr;
+  Instruction *LastInst = Insts.pop_back_val();
+  LastInst->removeFromParent();
 
-  // Check to see if all of the bytes come from the same value.
-  Value *V = ByteValues[0];
-  if (!V) return nullptr;  // Didn't find a byte?  Must be zero.
-
-  // Check to make sure that all of the bytes come from the same value.
-  for (unsigned i = 1, e = ByteValues.size(); i != e; ++i)
-    if (ByteValues[i] != V)
-      return nullptr;
-  Module *M = I.getParent()->getParent()->getParent();
-  Function *F = Intrinsic::getDeclaration(M, Intrinsic::bswap, ITy);
-  return CallInst::Create(F, V);
+  for (auto *Inst : Insts)
+    Worklist.Add(Inst);
+  return LastInst;
 }
 
 /// We have an expression of the form (A&C)|(B&D).  Check if A is (cond?-1:0)
@@ -2265,7 +2125,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
                   match(Op1, m_And(m_Value(), m_Value()));
 
   if (OrOfOrs || OrOfShifts || OrOfAnds)
-    if (Instruction *BSwap = MatchBSwap(I))
+    if (Instruction *BSwap = MatchBSwapOrBitReverse(I))
       return BSwap;
 
   // (X^C)|Y -> (X|Y)^C iff Y&C == 0