#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/CmpInstAnalysis.h"
+#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
using namespace PatternMatch;
return Changed ? &I : nullptr;
}
-
-/// Analyze the specified subexpression and see if it is capable of providing
-/// pieces of a bswap or bitreverse. The subexpression provides a potential
-/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
-/// the output of the expression came from a corresponding bit in some other
-/// value. This function is recursive, and the end result is a mapping of
-/// (value, bitnumber) to bitnumber. It is the caller's responsibility to
-/// validate that all `value`s are identical and that the bitnumber to bitnumber
-/// mapping is correct for a bswap or bitreverse.
-///
-/// For example, if the current subexpression if "(shl i32 %X, 24)" then we know
-/// that the expression deposits the low byte of %X into the high byte of the
-/// result and that all other bits are zero. This expression is accepted,
-/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7].
-///
-/// This function returns true if the match was unsuccessful and false if so.
-/// On entry to the function the "OverallLeftShift" is a signed integer value
-/// indicating the number of bits that the subexpression is later shifted. For
-/// example, if the expression is later right shifted by 16 bits, the
-/// OverallLeftShift value would be -16 on entry. This is used to specify which
-/// bits of BitValues are actually being set.
-///
-/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding
-/// bit is masked to zero by a user. For example, in (X & 255), X will be
-/// processed with a bytemask of 255. BitMask is always in the local
-/// (OverallLeftShift) coordinate space.
-///
-static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask,
- SmallVectorImpl<Value *> &BitValues,
- SmallVectorImpl<int> &BitProvenance) {
- if (Instruction *I = dyn_cast<Instruction>(V)) {
- // If this is an or instruction, it may be an inner node of the bswap.
- if (I->getOpcode() == Instruction::Or)
- return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
- BitValues, BitProvenance) ||
- CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask,
- BitValues, BitProvenance);
-
- // If this is a logical shift by a constant, recurse with OverallLeftShift
- // and BitMask adjusted.
- if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
- unsigned ShAmt =
- cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
- // Ensure the shift amount is defined.
- if (ShAmt > BitValues.size())
- return true;
-
- unsigned BitShift = ShAmt;
- if (I->getOpcode() == Instruction::Shl) {
- // X << C -> collect(X, +C)
- OverallLeftShift += BitShift;
- BitMask = BitMask.lshr(BitShift);
- } else {
- // X >>u C -> collect(X, -C)
- OverallLeftShift -= BitShift;
- BitMask = BitMask.shl(BitShift);
- }
-
- if (OverallLeftShift >= (int)BitValues.size())
- return true;
- if (OverallLeftShift <= -(int)BitValues.size())
- return true;
-
- return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
- BitValues, BitProvenance);
- }
-
- // If this is a logical 'and' with a mask that clears bits, clear the
- // corresponding bits in BitMask.
- if (I->getOpcode() == Instruction::And &&
- isa<ConstantInt>(I->getOperand(1))) {
- unsigned NumBits = BitValues.size();
- APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
- const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
-
- for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) {
- // If this bit is masked out by a later operation, we don't care what
- // the and mask is.
- if (BitMask[i] == 0)
- continue;
-
- // If the AndMask is zero for this bit, clear the bit.
- APInt MaskB = AndMask & Bit;
- if (MaskB == 0) {
- BitMask.clearBit(i);
- continue;
- }
-
- // Otherwise, this bit is kept.
- }
-
- return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask,
- BitValues, BitProvenance);
- }
- }
-
- // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
- // the input value to the bswap/bitreverse. To be part of a bswap or
- // bitreverse we must be demanding a contiguous range of bits from it.
- unsigned InputBitLen = BitMask.countPopulation();
- unsigned InputBitNo = BitMask.countTrailingZeros();
- if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo !=
- InputBitLen)
- // Not a contiguous set range of bits!
- return true;
-
- // We know we're moving a contiguous range of bits from the input to the
- // output. Record which bits in the output came from which bits in the input.
- unsigned DestBitNo = InputBitNo + OverallLeftShift;
- for (unsigned I = 0; I < InputBitLen; ++I)
- BitProvenance[DestBitNo + I] = InputBitNo + I;
-
- // If the destination bit value is already defined, the values are or'd
- // together, which isn't a bswap/bitreverse (unless it's an or of the same
- // bits).
- if (BitValues[DestBitNo] && BitValues[DestBitNo] != V)
- return true;
- for (unsigned I = 0; I < InputBitLen; ++I)
- BitValues[DestBitNo + I] = V;
-
- return false;
-}
-
-static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
- unsigned BitWidth) {
- if (From % 8 != To % 8)
- return false;
- // Convert from bit indices to byte indices and check for a byte reversal.
- From >>= 3;
- To >>= 3;
- BitWidth >>= 3;
- return From == BitWidth - To - 1;
-}
-
-static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
- unsigned BitWidth) {
- return From == BitWidth - To - 1;
-}
-
/// Given an OR instruction, check to see if this is a bswap or bitreverse
/// idiom. If so, insert the new intrinsic and return it.
Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) {
- IntegerType *ITy = dyn_cast<IntegerType>(I.getType());
- if (!ITy)
- return nullptr; // Can't do vectors.
- unsigned BW = ITy->getBitWidth();
-
- /// We keep track of which bit (BitProvenance) inside which value (BitValues)
- /// defines each bit in the result.
- SmallVector<Value *, 8> BitValues(BW, nullptr);
- SmallVector<int, 8> BitProvenance(BW, -1);
-
- // Try to find all the pieces corresponding to the bswap.
- APInt BitMask = APInt::getAllOnesValue(BitValues.size());
- if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance))
- return nullptr;
-
- // Check to see if all of the bits come from the same value.
- Value *V = BitValues[0];
- if (!V) return nullptr; // Didn't find a bit? Must be zero.
-
- if (!std::all_of(BitValues.begin(), BitValues.end(),
- [&](const Value *X) { return X == V; }))
- return nullptr;
-
- // Now, is the bit permutation correct for a bswap or a bitreverse? We can
- // only byteswap values with an even number of bytes.
- bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;;
- for (unsigned i = 0, e = BitValues.size(); i != e; ++i) {
- OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
- OKForBitReverse &=
- bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
- }
-
- Intrinsic::ID Intrin;
- if (OKForBSwap)
- Intrin = Intrinsic::bswap;
- else if (OKForBitReverse)
- Intrin = Intrinsic::bitreverse;
- else
+ SmallVector<Instruction*, 4> Insts;
+ if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts))
return nullptr;
+ Instruction *LastInst = Insts.pop_back_val();
+ LastInst->removeFromParent();
- Function *F = Intrinsic::getDeclaration(I.getModule(), Intrin, ITy);
- return CallInst::Create(F, V);
+ for (auto *Inst : Insts)
+ Worklist.Add(Inst);
+ return LastInst;
}
/// We have an expression of the form (A&C)|(B&D). Check if A is (cond?-1:0)