+
+/// A potential constituent of a bitreverse or bswap expression. See
+/// collectBitParts for a fuller explanation.
+struct BitPart {
+ BitPart(Value *P, unsigned BW) : Provider(P) {
+ Provenance.resize(BW);
+ }
+
+ /// The Value that this is a bitreverse/bswap of.
+ Value *Provider;
+ /// The "provenance" of each bit. Provenance[A] = B means that bit A
+ /// in Provider becomes bit B in the result of this expression.
+ SmallVector<int8_t, 32> Provenance; // int8_t means max size is i128.
+
+ enum { Unset = -1 };
+};
+
+/// Analyze the specified subexpression and see if it is capable of providing
+/// pieces of a bswap or bitreverse. The subexpression provides a potential
+/// piece of a bswap or bitreverse if it can be proven that each non-zero bit in
+/// the output of the expression came from a corresponding bit in some other
+/// value. This function is recursive, and the end result is a mapping of
+/// bitnumber to bitnumber. It is the caller's responsibility to validate that
+/// the bitnumber to bitnumber mapping is correct for a bswap or bitreverse.
+///
+/// For example, if the current subexpression if "(shl i32 %X, 24)" then we know
+/// that the expression deposits the low byte of %X into the high byte of the
+/// result and that all other bits are zero. This expression is accepted and a
+/// BitPart is returned with Provider set to %X and Provenance[24-31] set to
+/// [0-7].
+///
+/// To avoid revisiting values, the BitPart results are memoized into the
+/// provided map. To avoid unnecessary copying of BitParts, BitParts are
+/// constructed in-place in the \c BPS map. Because of this \c BPS needs to
+/// store BitParts objects, not pointers. As we need the concept of a nullptr
+/// BitParts (Value has been analyzed and the analysis failed), we an Optional
+/// type instead to provide the same functionality.
+///
+/// Because we pass around references into \c BPS, we must use a container that
+/// does not invalidate internal references (std::map instead of DenseMap).
+///
+static const Optional<BitPart> &
+collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals,
+ std::map<Value *, Optional<BitPart>> &BPS) {
+ auto I = BPS.find(V);
+ if (I != BPS.end())
+ return I->second;
+
+ auto &Result = BPS[V] = None;
+ auto BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
+
+ if (Instruction *I = dyn_cast<Instruction>(V)) {
+ // If this is an or instruction, it may be an inner node of the bswap.
+ if (I->getOpcode() == Instruction::Or) {
+ auto &A = collectBitParts(I->getOperand(0), MatchBSwaps,
+ MatchBitReversals, BPS);
+ auto &B = collectBitParts(I->getOperand(1), MatchBSwaps,
+ MatchBitReversals, BPS);
+ if (!A || !B)
+ return Result;
+
+ // Try and merge the two together.
+ if (!A->Provider || A->Provider != B->Provider)
+ return Result;
+
+ Result = BitPart(A->Provider, BitWidth);
+ for (unsigned i = 0; i < A->Provenance.size(); ++i) {
+ if (A->Provenance[i] != BitPart::Unset &&
+ B->Provenance[i] != BitPart::Unset &&
+ A->Provenance[i] != B->Provenance[i])
+ return Result = None;
+
+ if (A->Provenance[i] == BitPart::Unset)
+ Result->Provenance[i] = B->Provenance[i];
+ else
+ Result->Provenance[i] = A->Provenance[i];
+ }
+
+ return Result;
+ }
+
+ // If this is a logical shift by a constant, recurse then shift the result.
+ if (I->isLogicalShift() && isa<ConstantInt>(I->getOperand(1))) {
+ unsigned BitShift =
+ cast<ConstantInt>(I->getOperand(1))->getLimitedValue(~0U);
+ // Ensure the shift amount is defined.
+ if (BitShift > BitWidth)
+ return Result;
+
+ auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
+ MatchBitReversals, BPS);
+ if (!Res)
+ return Result;
+ Result = Res;
+
+ // Perform the "shift" on BitProvenance.
+ auto &P = Result->Provenance;
+ if (I->getOpcode() == Instruction::Shl) {
+ P.erase(std::prev(P.end(), BitShift), P.end());
+ P.insert(P.begin(), BitShift, BitPart::Unset);
+ } else {
+ P.erase(P.begin(), std::next(P.begin(), BitShift));
+ P.insert(P.end(), BitShift, BitPart::Unset);
+ }
+
+ return Result;
+ }
+
+ // If this is a logical 'and' with a mask that clears bits, recurse then
+ // unset the appropriate bits.
+ if (I->getOpcode() == Instruction::And &&
+ isa<ConstantInt>(I->getOperand(1))) {
+ APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1);
+ const APInt &AndMask = cast<ConstantInt>(I->getOperand(1))->getValue();
+
+ // Check that the mask allows a multiple of 8 bits for a bswap, for an
+ // early exit.
+ unsigned NumMaskedBits = AndMask.countPopulation();
+ if (!MatchBitReversals && NumMaskedBits % 8 != 0)
+ return Result;
+
+ auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps,
+ MatchBitReversals, BPS);
+ if (!Res)
+ return Result;
+ Result = Res;
+
+ for (unsigned i = 0; i < BitWidth; ++i, Bit <<= 1)
+ // If the AndMask is zero for this bit, clear the bit.
+ if ((AndMask & Bit) == 0)
+ Result->Provenance[i] = BitPart::Unset;
+
+ return Result;
+ }
+ }
+
+ // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be
+ // the input value to the bswap/bitreverse.
+ Result = BitPart(V, BitWidth);
+ for (unsigned i = 0; i < BitWidth; ++i)
+ Result->Provenance[i] = i;
+ return Result;
+}
+
+static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To,
+ unsigned BitWidth) {
+ if (From % 8 != To % 8)
+ return false;
+ // Convert from bit indices to byte indices and check for a byte reversal.
+ From >>= 3;
+ To >>= 3;
+ BitWidth >>= 3;
+ return From == BitWidth - To - 1;
+}
+
+static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To,
+ unsigned BitWidth) {
+ return From == BitWidth - To - 1;
+}
+
+/// Given an OR instruction, check to see if this is a bitreverse
+/// idiom. If so, insert the new intrinsic and return true.
+bool llvm::recognizeBitReverseOrBSwapIdiom(
+ Instruction *I, bool MatchBSwaps, bool MatchBitReversals,
+ SmallVectorImpl<Instruction *> &InsertedInsts) {
+ if (Operator::getOpcode(I) != Instruction::Or)
+ return false;
+ if (!MatchBSwaps && !MatchBitReversals)
+ return false;
+ IntegerType *ITy = dyn_cast<IntegerType>(I->getType());
+ if (!ITy || ITy->getBitWidth() > 128)
+ return false; // Can't do vectors or integers > 128 bits.
+ unsigned BW = ITy->getBitWidth();
+
+ // Try to find all the pieces corresponding to the bswap.
+ std::map<Value *, Optional<BitPart>> BPS;
+ auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS);
+ if (!Res)
+ return false;
+ auto &BitProvenance = Res->Provenance;
+
+ // Now, is the bit permutation correct for a bswap or a bitreverse? We can
+ // only byteswap values with an even number of bytes.
+ bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;
+ for (unsigned i = 0; i < BW; ++i) {
+ OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW);
+ OKForBitReverse &=
+ bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW);
+ }
+
+ Intrinsic::ID Intrin;
+ if (OKForBSwap && MatchBSwaps)
+ Intrin = Intrinsic::bswap;
+ else if (OKForBitReverse && MatchBitReversals)
+ Intrin = Intrinsic::bitreverse;
+ else
+ return false;
+
+ Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, ITy);
+ InsertedInsts.push_back(CallInst::Create(F, Res->Provider, "rev", I));
+ return true;
+}