TargetLowering.h #includes SelectionDAGNodes.h, so it doesn't need its
[oota-llvm.git] / lib / CodeGen / SelectionDAG / TargetLowering.cpp
index 01977908be6229bcf3ea9e468dca63ef95be5d59..bbd222285e4cf1fe2a1550cddb07e799d4c6b6ad 100644 (file)
@@ -398,8 +398,6 @@ static void InitCmpLibcallCCs(ISD::CondCode *CCs) {
 
 TargetLowering::TargetLowering(TargetMachine &tm)
   : TM(tm), TD(TM.getTargetData()) {
-  assert(ISD::BUILTIN_OP_END <= OpActionsCapacity &&
-         "Fixed size array in TargetLowering is not large enough!");
   // All operations default to being supported.
   memset(OpActions, 0, sizeof(OpActions));
   memset(LoadExtActions, 0, sizeof(LoadExtActions));
@@ -572,8 +570,32 @@ void TargetLowering::computeRegisterProperties() {
                                IntermediateVT, NumIntermediates,
                                RegisterVT);
       RegisterTypeForVT[i] = RegisterVT;
-      TransformToType[i] = MVT::Other; // this isn't actually used
-      ValueTypeActions.setTypeAction(VT, Promote);
+      
+      // Determine if there is a legal wider type.
+      bool IsLegalWiderType = false;
+      MVT EltVT = VT.getVectorElementType();
+      unsigned NElts = VT.getVectorNumElements();
+      for (unsigned nVT = i+1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
+        MVT SVT = (MVT::SimpleValueType)nVT;
+        if (isTypeLegal(SVT) && SVT.getVectorElementType() == EltVT &&
+            SVT.getVectorNumElements() > NElts) {
+          TransformToType[i] = SVT;
+          ValueTypeActions.setTypeAction(VT, Promote);
+          IsLegalWiderType = true;
+          break;
+        }
+      }
+      if (!IsLegalWiderType) {
+        MVT NVT = VT.getPow2VectorType();
+        if (NVT == VT) {
+          // Type is already a power of 2.  The default action is to split.
+          TransformToType[i] = MVT::Other;
+          ValueTypeActions.setTypeAction(VT, Expand);
+        } else {
+          TransformToType[i] = NVT;
+          ValueTypeActions.setTypeAction(VT, Promote);
+        }
+      }
     }
   }
 }
@@ -583,7 +605,7 @@ const char *TargetLowering::getTargetNodeName(unsigned Opcode) const {
 }
 
 
-MVT TargetLowering::getSetCCResultType(const SDValue &) const {
+MVT TargetLowering::getSetCCResultType(MVT VT) const {
   return getValueType(TD->getIntPtrType());
 }
 
@@ -1583,12 +1605,19 @@ TargetLowering::SimplifySetCC(MVT VT, SDValue N0, SDValue N1,
       // by changing cc.
 
       // SETUGT X, SINTMAX  -> SETLT X, 0
-      if (Cond == ISD::SETUGT && OperandBitSize != 1 &&
-          C1 == (~0ULL >> (65-OperandBitSize)))
+      if (Cond == ISD::SETUGT && 
+          C1 == APInt::getSignedMaxValue(OperandBitSize))
         return DAG.getSetCC(VT, N0, DAG.getConstant(0, N1.getValueType()),
                             ISD::SETLT);
 
-      // FIXME: Implement the rest of these.
+      // SETULT X, SINTMIN  -> SETGT X, -1
+      if (Cond == ISD::SETULT &&
+          C1 == APInt::getSignedMinValue(OperandBitSize)) {
+        SDValue ConstMinusOne =
+            DAG.getConstant(APInt::getAllOnesValue(OperandBitSize),
+                            N1.getValueType());
+        return DAG.getSetCC(VT, N0, ConstMinusOne, ISD::SETGT);
+      }
 
       // Fold bit comparisons when we can.
       if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
@@ -2214,180 +2243,105 @@ bool TargetLowering::isLegalAddressingMode(const AddrMode &AM,
   return true;
 }
 
-// Magic for divide replacement
-
-struct ms {
-  int64_t m;  // magic number
-  int64_t s;  // shift amount
-};
-
 struct mu {
-  uint64_t m; // magic number
-  int64_t a;  // add indicator
-  int64_t s;  // shift amount
+  APInt m;     // magic number
+  bool a;      // add indicator
+  unsigned s;  // shift amount
 };
 
-/// magic - calculate the magic numbers required to codegen an integer sdiv as
-/// a sequence of multiply and shifts.  Requires that the divisor not be 0, 1,
-/// or -1.
-static ms magic32(int32_t d) {
-  int32_t p;
-  uint32_t ad, anc, delta, q1, r1, q2, r2, t;
-  const uint32_t two31 = 0x80000000U;
-  struct ms mag;
-  
-  ad = abs(d);
-  t = two31 + ((uint32_t)d >> 31);
-  anc = t - 1 - t%ad;   // absolute value of nc
-  p = 31;               // initialize p
-  q1 = two31/anc;       // initialize q1 = 2p/abs(nc)
-  r1 = two31 - q1*anc;  // initialize r1 = rem(2p,abs(nc))
-  q2 = two31/ad;        // initialize q2 = 2p/abs(d)
-  r2 = two31 - q2*ad;   // initialize r2 = rem(2p,abs(d))
-  do {
-    p = p + 1;
-    q1 = 2*q1;        // update q1 = 2p/abs(nc)
-    r1 = 2*r1;        // update r1 = rem(2p/abs(nc))
-    if (r1 >= anc) {  // must be unsigned comparison
-      q1 = q1 + 1;
-      r1 = r1 - anc;
-    }
-    q2 = 2*q2;        // update q2 = 2p/abs(d)
-    r2 = 2*r2;        // update r2 = rem(2p/abs(d))
-    if (r2 >= ad) {   // must be unsigned comparison
-      q2 = q2 + 1;
-      r2 = r2 - ad;
-    }
-    delta = ad - r2;
-  } while (q1 < delta || (q1 == delta && r1 == 0));
-  
-  mag.m = (int32_t)(q2 + 1); // make sure to sign extend
-  if (d < 0) mag.m = -mag.m; // resulting magic number
-  mag.s = p - 32;            // resulting shift
-  return mag;
-}
-
 /// magicu - calculate the magic numbers required to codegen an integer udiv as
 /// a sequence of multiply, add and shifts.  Requires that the divisor not be 0.
-static mu magicu32(uint32_t d) {
-  int32_t p;
-  uint32_t nc, delta, q1, r1, q2, r2;
+static mu magicu(const APInt& d) {
+  unsigned p;
+  APInt nc, delta, q1, r1, q2, r2;
   struct mu magu;
   magu.a = 0;               // initialize "add" indicator
-  nc = - 1 - (-d)%d;
-  p = 31;                   // initialize p
-  q1 = 0x80000000/nc;       // initialize q1 = 2p/nc
-  r1 = 0x80000000 - q1*nc;  // initialize r1 = rem(2p,nc)
-  q2 = 0x7FFFFFFF/d;        // initialize q2 = (2p-1)/d
-  r2 = 0x7FFFFFFF - q2*d;   // initialize r2 = rem((2p-1),d)
+  APInt allOnes = APInt::getAllOnesValue(d.getBitWidth());
+  APInt signedMin = APInt::getSignedMinValue(d.getBitWidth());
+  APInt signedMax = APInt::getSignedMaxValue(d.getBitWidth());
+
+  nc = allOnes - (-d).urem(d);
+  p = d.getBitWidth() - 1;  // initialize p
+  q1 = signedMin.udiv(nc);  // initialize q1 = 2p/nc
+  r1 = signedMin - q1*nc;   // initialize r1 = rem(2p,nc)
+  q2 = signedMax.udiv(d);   // initialize q2 = (2p-1)/d
+  r2 = signedMax - q2*d;    // initialize r2 = rem((2p-1),d)
   do {
     p = p + 1;
-    if (r1 >= nc - r1 ) {
-      q1 = 2*q1 + 1;  // update q1
-      r1 = 2*r1 - nc; // update r1
+    if (r1.uge(nc - r1)) {
+      q1 = q1 + q1 + 1;  // update q1
+      r1 = r1 + r1 - nc; // update r1
     }
     else {
-      q1 = 2*q1; // update q1
-      r1 = 2*r1; // update r1
+      q1 = q1+q1; // update q1
+      r1 = r1+r1; // update r1
     }
-    if (r2 + 1 >= d - r2) {
-      if (q2 >= 0x7FFFFFFF) magu.a = 1;
-      q2 = 2*q2 + 1;     // update q2
-      r2 = 2*r2 + 1 - d; // update r2
+    if ((r2 + 1).uge(d - r2)) {
+      if (q2.uge(signedMax)) magu.a = 1;
+      q2 = q2+q2 + 1;     // update q2
+      r2 = r2+r2 + 1 - d; // update r2
     }
     else {
-      if (q2 >= 0x80000000) magu.a = 1;
-      q2 = 2*q2;     // update q2
-      r2 = 2*r2 + 1; // update r2
+      if (q2.uge(signedMin)) magu.a = 1;
+      q2 = q2+q2;     // update q2
+      r2 = r2+r2 + 1; // update r2
     }
     delta = d - 1 - r2;
-  } while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0)));
+  } while (p < d.getBitWidth()*2 &&
+           (q1.ult(delta) || (q1 == delta && r1 == 0)));
   magu.m = q2 + 1; // resulting magic number
-  magu.s = p - 32;  // resulting shift
+  magu.s = p - d.getBitWidth();  // resulting shift
   return magu;
 }
 
+// Magic for divide replacement
+struct ms {
+  APInt m;  // magic number
+  unsigned s;  // shift amount
+};
+
 /// magic - calculate the magic numbers required to codegen an integer sdiv as
 /// a sequence of multiply and shifts.  Requires that the divisor not be 0, 1,
 /// or -1.
-static ms magic64(int64_t d) {
-  int64_t p;
-  uint64_t ad, anc, delta, q1, r1, q2, r2, t;
-  const uint64_t two63 = 9223372036854775808ULL; // 2^63
+static ms magic(const APInt& d) {
+  unsigned p;
+  APInt ad, anc, delta, q1, r1, q2, r2, t;
+  APInt allOnes = APInt::getAllOnesValue(d.getBitWidth());
+  APInt signedMin = APInt::getSignedMinValue(d.getBitWidth());
+  APInt signedMax = APInt::getSignedMaxValue(d.getBitWidth());
   struct ms mag;
   
-  ad = d >= 0 ? d : -d;
-  t = two63 + ((uint64_t)d >> 63);
-  anc = t - 1 - t%ad;   // absolute value of nc
-  p = 63;               // initialize p
-  q1 = two63/anc;       // initialize q1 = 2p/abs(nc)
-  r1 = two63 - q1*anc;  // initialize r1 = rem(2p,abs(nc))
-  q2 = two63/ad;        // initialize q2 = 2p/abs(d)
-  r2 = two63 - q2*ad;   // initialize r2 = rem(2p,abs(d))
+  ad = d.abs();
+  t = signedMin + (d.lshr(d.getBitWidth() - 1));
+  anc = t - 1 - t.urem(ad);   // absolute value of nc
+  p = d.getBitWidth() - 1;    // initialize p
+  q1 = signedMin.udiv(anc);   // initialize q1 = 2p/abs(nc)
+  r1 = signedMin - q1*anc;    // initialize r1 = rem(2p,abs(nc))
+  q2 = signedMin.udiv(ad);    // initialize q2 = 2p/abs(d)
+  r2 = signedMin - q2*ad;     // initialize r2 = rem(2p,abs(d))
   do {
     p = p + 1;
-    q1 = 2*q1;        // update q1 = 2p/abs(nc)
-    r1 = 2*r1;        // update r1 = rem(2p/abs(nc))
-    if (r1 >= anc) {  // must be unsigned comparison
+    q1 = q1<<1;          // update q1 = 2p/abs(nc)
+    r1 = r1<<1;          // update r1 = rem(2p/abs(nc))
+    if (r1.uge(anc)) {  // must be unsigned comparison
       q1 = q1 + 1;
       r1 = r1 - anc;
     }
-    q2 = 2*q2;        // update q2 = 2p/abs(d)
-    r2 = 2*r2;        // update r2 = rem(2p/abs(d))
-    if (r2 >= ad) {   // must be unsigned comparison
+    q2 = q2<<1;          // update q2 = 2p/abs(d)
+    r2 = r2<<1;          // update r2 = rem(2p/abs(d))
+    if (r2.uge(ad)) {   // must be unsigned comparison
       q2 = q2 + 1;
       r2 = r2 - ad;
     }
     delta = ad - r2;
-  } while (q1 < delta || (q1 == delta && r1 == 0));
+  } while (q1.ule(delta) || (q1 == delta && r1 == 0));
   
   mag.m = q2 + 1;
-  if (d < 0) mag.m = -mag.m; // resulting magic number
-  mag.s = p - 64;            // resulting shift
+  if (d.isNegative()) mag.m = -mag.m;   // resulting magic number
+  mag.s = p - d.getBitWidth();          // resulting shift
   return mag;
 }
 
-/// magicu - calculate the magic numbers required to codegen an integer udiv as
-/// a sequence of multiply, add and shifts.  Requires that the divisor not be 0.
-static mu magicu64(uint64_t d)
-{
-  int64_t p;
-  uint64_t nc, delta, q1, r1, q2, r2;
-  struct mu magu;
-  magu.a = 0;               // initialize "add" indicator
-  nc = - 1 - (-d)%d;
-  p = 63;                   // initialize p
-  q1 = 0x8000000000000000ull/nc;       // initialize q1 = 2p/nc
-  r1 = 0x8000000000000000ull - q1*nc;  // initialize r1 = rem(2p,nc)
-  q2 = 0x7FFFFFFFFFFFFFFFull/d;        // initialize q2 = (2p-1)/d
-  r2 = 0x7FFFFFFFFFFFFFFFull - q2*d;   // initialize r2 = rem((2p-1),d)
-  do {
-    p = p + 1;
-    if (r1 >= nc - r1 ) {
-      q1 = 2*q1 + 1;  // update q1
-      r1 = 2*r1 - nc; // update r1
-    }
-    else {
-      q1 = 2*q1; // update q1
-      r1 = 2*r1; // update r1
-    }
-    if (r2 + 1 >= d - r2) {
-      if (q2 >= 0x7FFFFFFFFFFFFFFFull) magu.a = 1;
-      q2 = 2*q2 + 1;     // update q2
-      r2 = 2*r2 + 1 - d; // update r2
-    }
-    else {
-      if (q2 >= 0x8000000000000000ull) magu.a = 1;
-      q2 = 2*q2;     // update q2
-      r2 = 2*r2 + 1; // update r2
-    }
-    delta = d - 1 - r2;
-  } while (p < 128 && (q1 < delta || (q1 == delta && r1 == 0)));
-  magu.m = q2 + 1; // resulting magic number
-  magu.s = p - 64;  // resulting shift
-  return magu;
-}
-
 /// BuildSDIVSequence - Given an ISD::SDIV node expressing a divide by constant,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.  See:
@@ -2397,13 +2351,15 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
   MVT VT = N->getValueType(0);
   
   // Check to see if we can do this.
-  if (!isTypeLegal(VT) || (VT != MVT::i32 && VT != MVT::i64))
-    return SDValue();       // BuildSDIV only operates on i32 or i64
+  // FIXME: We should be more aggressive here.
+  if (!isTypeLegal(VT))
+    return SDValue();
   
-  int64_t d = cast<ConstantSDNode>(N->getOperand(1))->getSExtValue();
-  ms magics = (VT == MVT::i32) ? magic32(d) : magic64(d);
+  APInt d = cast<ConstantSDNode>(N->getOperand(1))->getAPIntValue();
+  ms magics = magic(d);
   
   // Multiply the numerator (operand 0) by the magic value
+  // FIXME: We should support doing a MUL in a wider type
   SDValue Q;
   if (isOperationLegal(ISD::MULHS, VT))
     Q = DAG.getNode(ISD::MULHS, VT, N->getOperand(0),
@@ -2415,13 +2371,13 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
   else
     return SDValue();       // No mulhs or equvialent
   // If d > 0 and m < 0, add the numerator
-  if (d > 0 && magics.m < 0) { 
+  if (d.isStrictlyPositive() && magics.m.isNegative()) { 
     Q = DAG.getNode(ISD::ADD, VT, Q, N->getOperand(0));
     if (Created)
       Created->push_back(Q.getNode());
   }
   // If d < 0 and m > 0, subtract the numerator.
-  if (d < 0 && magics.m > 0) {
+  if (d.isNegative() && magics.m.isStrictlyPositive()) {
     Q = DAG.getNode(ISD::SUB, VT, Q, N->getOperand(0));
     if (Created)
       Created->push_back(Q.getNode());
@@ -2449,15 +2405,19 @@ SDValue TargetLowering::BuildSDIV(SDNode *N, SelectionDAG &DAG,
 SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
                                   std::vector<SDNode*>* Created) const {
   MVT VT = N->getValueType(0);
-  
+
   // Check to see if we can do this.
-  if (!isTypeLegal(VT) || (VT != MVT::i32 && VT != MVT::i64))
-    return SDValue();       // BuildUDIV only operates on i32 or i64
-  
-  uint64_t d = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
-  mu magics = (VT == MVT::i32) ? magicu32(d) : magicu64(d);
-  
+  // FIXME: We should be more aggressive here.
+  if (!isTypeLegal(VT))
+    return SDValue();
+
+  // FIXME: We should use a narrower constant when the upper
+  // bits are known to be zero.
+  ConstantSDNode *N1C = cast<ConstantSDNode>(N->getOperand(1));
+  mu magics = magicu(N1C->getAPIntValue());
+
   // Multiply the numerator (operand 0) by the magic value
+  // FIXME: We should support doing a MUL in a wider type
   SDValue Q;
   if (isOperationLegal(ISD::MULHU, VT))
     Q = DAG.getNode(ISD::MULHU, VT, N->getOperand(0),
@@ -2472,6 +2432,8 @@ SDValue TargetLowering::BuildUDIV(SDNode *N, SelectionDAG &DAG,
     Created->push_back(Q.getNode());
 
   if (magics.a == 0) {
+    assert(magics.s < N1C->getAPIntValue().getBitWidth() &&
+           "We shouldn't generate an undefined shift!");
     return DAG.getNode(ISD::SRL, VT, Q, 
                        DAG.getConstant(magics.s, getShiftAmountTy()));
   } else {