[SDAG] Now that we have a way to communicate the exact bit on sdiv use it to simplify...
[oota-llvm.git] / lib / CodeGen / SelectionDAG / TargetLowering.cpp
index 34ddeb7e9c39add544dfc247e90df09393f74964..e7722b392a8183b52a7854eea5a980cb5928cc07 100644 (file)
@@ -254,7 +254,7 @@ const MCExpr *
 TargetLowering::getPICJumpTableRelocBaseExpr(const MachineFunction *MF,
                                              unsigned JTI,MCContext &Ctx) const{
   // The normal PIC reloc base is the label at the start of the jump table.
-  return MCSymbolRefExpr::Create(MF->getJTISymbol(JTI, Ctx), Ctx);
+  return MCSymbolRefExpr::create(MF->getJTISymbol(JTI, Ctx), Ctx);
 }
 
 bool
@@ -700,6 +700,13 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
       if (ShAmt >= BitWidth)
         break;
 
+      APInt InDemandedMask = (NewMask << ShAmt);
+
+      // If the shift is exact, then it does demand the low bits (and knows that
+      // they are zero).
+      if (cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact())
+        InDemandedMask |= APInt::getLowBitsSet(BitWidth, ShAmt);
+
       // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
       // single shift.  We can do this if the top bits (which are shifted out)
       // are never demanded.
@@ -722,7 +729,7 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
       }
 
       // Compute the new bits that are at the top now.
-      if (SimplifyDemandedBits(InOp, (NewMask << ShAmt),
+      if (SimplifyDemandedBits(InOp, InDemandedMask,
                                KnownZero, KnownOne, TLO, Depth+1))
         return true;
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
@@ -753,6 +760,11 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
 
       APInt InDemandedMask = (NewMask << ShAmt);
 
+      // If the shift is exact, then it does demand the low bits (and knows that
+      // they are zero).
+      if (cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact())
+        InDemandedMask |= APInt::getLowBitsSet(BitWidth, ShAmt);
+
       // If any of the demanded bits are produced by the sign extension, we also
       // demand the input sign bit.
       APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
@@ -771,10 +783,13 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
 
       // If the input sign bit is known to be zero, or if none of the top bits
       // are demanded, turn this into an unsigned shift right.
-      if (KnownZero.intersects(SignBit) || (HighBits & ~NewMask) == HighBits)
-        return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT,
-                                                 Op.getOperand(0),
-                                                 Op.getOperand(1)));
+      if (KnownZero.intersects(SignBit) || (HighBits & ~NewMask) == HighBits) {
+        SDNodeFlags Flags;
+        Flags.setExact(cast<BinaryWithFlagsSDNode>(Op)->Flags.hasExact());
+        return TLO.CombineTo(Op,
+                             TLO.DAG.getNode(ISD::SRL, dl, VT, Op.getOperand(0),
+                                             Op.getOperand(1), &Flags));
+      }
 
       int Log2 = NewMask.exactLogBase2();
       if (Log2 >= 0) {
@@ -1086,9 +1101,19 @@ bool TargetLowering::SimplifyDemandedBits(SDValue Op,
 
   // If we know the value of all of the demanded bits, return this as a
   // constant.
-  if ((NewMask & (KnownZero|KnownOne)) == NewMask)
+  if ((NewMask & (KnownZero|KnownOne)) == NewMask) {
+    // Avoid folding to a constant if any OpaqueConstant is involved.
+    const SDNode *N = Op.getNode();
+    for (SDNodeIterator I = SDNodeIterator::begin(N),
+         E = SDNodeIterator::end(N); I != E; ++I) {
+      SDNode *Op = *I;
+      if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op))
+        if (C->isOpaque())
+          return false;
+    }
     return TLO.CombineTo(Op,
                          TLO.DAG.getConstant(KnownOne, dl, Op.getValueType()));
+  }
 
   return false;
 }
@@ -1730,7 +1755,8 @@ TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
           ShiftBits = C1.countTrailingZeros();
         }
         NewC = NewC.lshr(ShiftBits);
-        if (ShiftBits && isLegalICmpImmediate(NewC.getSExtValue())) {
+        if (ShiftBits && NewC.getMinSignedBits() <= 64 &&
+          isLegalICmpImmediate(NewC.getSExtValue())) {
           EVT ShiftTy = DCI.isBeforeLegalize() ?
             getPointerTy() : getShiftAmountTy(N0.getValueType());
           EVT CmpTy = N0.getValueType();
@@ -2648,20 +2674,21 @@ void TargetLowering::ComputeConstraintToUse(AsmOperandInfo &OpInfo,
 
 /// \brief Given an exact SDIV by a constant, create a multiplication
 /// with the multiplicative inverse of the constant.
-SDValue TargetLowering::BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
-                                       SelectionDAG &DAG) const {
-  ConstantSDNode *C = cast<ConstantSDNode>(Op2);
-  APInt d = C->getAPIntValue();
+static SDValue BuildExactSDIV(const TargetLowering &TLI, SDValue Op1, APInt d,
+                              SDLoc dl, SelectionDAG &DAG,
+                              std::vector<SDNode *> &Created) {
   assert(d != 0 && "Division by zero!");
 
   // Shift the value upfront if it is even, so the LSB is one.
   unsigned ShAmt = d.countTrailingZeros();
   if (ShAmt) {
     // TODO: For UDIV use SRL instead of SRA.
-    SDValue Amt = DAG.getConstant(ShAmt, dl,
-                                  getShiftAmountTy(Op1.getValueType()));
-    Op1 = DAG.getNode(ISD::SRA, dl, Op1.getValueType(), Op1, Amt, false, false,
-                      true);
+    SDValue Amt =
+        DAG.getConstant(ShAmt, dl, TLI.getShiftAmountTy(Op1.getValueType()));
+    SDNodeFlags Flags;
+    Flags.setExact(true);
+    Op1 = DAG.getNode(ISD::SRA, dl, Op1.getValueType(), Op1, Amt, &Flags);
+    Created.push_back(Op1.getNode());
     d = d.ashr(ShAmt);
   }
 
@@ -2670,8 +2697,10 @@ SDValue TargetLowering::BuildExactSDIV(SDValue Op1, SDValue Op2, SDLoc dl,
   while ((t = d*xn) != 1)
     xn *= APInt(d.getBitWidth(), 2) - t;
 
-  Op2 = DAG.getConstant(xn, dl, Op1.getValueType());
-  return DAG.getNode(ISD::MUL, dl, Op1.getValueType(), Op1, Op2);
+  SDValue Op2 = DAG.getConstant(xn, dl, Op1.getValueType());
+  SDValue Mul = DAG.getNode(ISD::MUL, dl, Op1.getValueType(), Op1, Op2);
+  Created.push_back(Mul.getNode());
+  return Mul;
 }
 
 /// \brief Given an ISD::SDIV node expressing a divide by constant,
@@ -2691,6 +2720,10 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, const APInt &Divisor,
   if (!isTypeLegal(VT))
     return SDValue();
 
+  // If the sdiv has an 'exact' bit we can use a simpler lowering.
+  if (cast<BinaryWithFlagsSDNode>(N)->Flags.hasExact())
+    return BuildExactSDIV(*this, N->getOperand(0), Divisor, dl, DAG, *Created);
+
   APInt::ms magics = Divisor.magic();
 
   // Multiply the numerator (operand 0) by the magic value