Op1 = Op1.getOperand(0);
SDValue LHS, RHS;
- if (Op1.getOpcode() == ISD::SHL) {
- if (ConstantSDNode *And10C = dyn_cast<ConstantSDNode>(Op1.getOperand(0)))
- if (And10C->getZExtValue() == 1) {
- LHS = Op0;
- RHS = Op1.getOperand(1);
- }
- } else if (Op0.getOpcode() == ISD::SHL) {
+ if (Op1.getOpcode() == ISD::SHL)
+ std::swap(Op0, Op1);
+ if (Op0.getOpcode() == ISD::SHL) {
if (ConstantSDNode *And00C = dyn_cast<ConstantSDNode>(Op0.getOperand(0)))
if (And00C->getZExtValue() == 1) {
+ // If we looked past a truncate, check that it's only truncating away
+ // known zeros.
+ unsigned BitWidth = Op0.getValueSizeInBits();
+ unsigned AndBitWidth = And.getValueSizeInBits();
+ if (BitWidth > AndBitWidth) {
+ APInt Mask = APInt::getAllOnesValue(BitWidth), Zeros, Ones;
+ DAG.ComputeMaskedBits(Op0, Mask, Zeros, Ones);
+ if (Zeros.countLeadingOnes() < BitWidth - AndBitWidth)
+ return SDValue();
+ }
LHS = Op1;
RHS = Op0.getOperand(1);
}