Clean up DemandedBitsAreZero interface
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index e40be3e3ca2c6ff39101337b26c34b4d472d69ff..f1d0a9b5065efbddc5c664084fe13bbba310574d 100644 (file)
@@ -99,13 +99,12 @@ namespace {
       return SDOperand(N, 0);
     }
     
-    bool DemandedBitsAreZero(SDOperand Op, uint64_t DemandedMask,
-                             SDOperand &Old, SDOperand &New) const {
+    bool DemandedBitsAreZero(SDOperand Op, uint64_t DemandedMask) {
       TargetLowering::TargetLoweringOpt TLO(DAG);
       uint64_t KnownZero, KnownOne;
       if (TLI.SimplifyDemandedBits(Op, DemandedMask, KnownZero, KnownOne, TLO)){
-        Old = TLO.Old;
-        New = TLO.New;
+        WorkList.push_back(Op.Val);
+        CombineTo(TLO.Old.Val, TLO.New);
         return true;
       }
       return false;
@@ -157,14 +156,11 @@ namespace {
     SDOperand visitSELECT(SDNode *N);
     SDOperand visitSELECT_CC(SDNode *N);
     SDOperand visitSETCC(SDNode *N);
-    SDOperand visitADD_PARTS(SDNode *N);
-    SDOperand visitSUB_PARTS(SDNode *N);
     SDOperand visitSIGN_EXTEND(SDNode *N);
     SDOperand visitZERO_EXTEND(SDNode *N);
     SDOperand visitSIGN_EXTEND_INREG(SDNode *N);
     SDOperand visitTRUNCATE(SDNode *N);
     SDOperand visitBIT_CONVERT(SDNode *N);
-    
     SDOperand visitFADD(SDNode *N);
     SDOperand visitFSUB(SDNode *N);
     SDOperand visitFMUL(SDNode *N);
@@ -183,7 +179,6 @@ namespace {
     SDOperand visitBRCONDTWOWAY(SDNode *N);
     SDOperand visitBR_CC(SDNode *N);
     SDOperand visitBRTWOWAY_CC(SDNode *N);
-
     SDOperand visitLOAD(SDNode *N);
     SDOperand visitSTORE(SDNode *N);
 
@@ -550,8 +545,6 @@ SDOperand DAGCombiner::visit(SDNode *N) {
   case ISD::SELECT:             return visitSELECT(N);
   case ISD::SELECT_CC:          return visitSELECT_CC(N);
   case ISD::SETCC:              return visitSETCC(N);
-  case ISD::ADD_PARTS:          return visitADD_PARTS(N);
-  case ISD::SUB_PARTS:          return visitSUB_PARTS(N);
   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
@@ -737,8 +730,8 @@ SDOperand DAGCombiner::visitSDIV(SDNode *N) {
   if (TLI.MaskedValueIsZero(N1, SignBit) &&
       TLI.MaskedValueIsZero(N0, SignBit))
     return DAG.getNode(ISD::UDIV, N1.getValueType(), N0, N1);
-  // fold (sdiv X, pow2) -> simple ops.
-  if (N1C && N1C->getValue() && !TLI.isIntDivCheap() && 
+  // fold (sdiv X, pow2) -> simple ops after legalize
+  if (N1C && N1C->getValue() && !TLI.isIntDivCheap() &&
       (isPowerOf2_64(N1C->getSignExtended()) || 
        isPowerOf2_64(-N1C->getSignExtended()))) {
     // If dividing by powers of two is cheap, then don't perform the following
@@ -893,7 +886,7 @@ SDOperand DAGCombiner::visitMULHU(SDNode *N) {
 SDOperand DAGCombiner::visitAND(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   SDOperand N1 = N->getOperand(1);
-  SDOperand LL, LR, RL, RR, CC0, CC1, Old, New;
+  SDOperand LL, LR, RL, RR, CC0, CC1;
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N1.getValueType();
@@ -992,12 +985,8 @@ SDOperand DAGCombiner::visitAND(SDNode *N) {
   }
   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
   // fold (and (sra)) -> (and (srl)) when possible.
-  if (DemandedBitsAreZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT), Old, 
-                          New)) {
-    WorkList.push_back(N);
-    CombineTo(Old.Val, New);
+  if (DemandedBitsAreZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT)))
     return SDOperand();
-  }
   // fold (zext_inreg (extload x)) -> (zextload x)
   if (N0.getOpcode() == ISD::EXTLOAD) {
     MVT::ValueType EVT = cast<VTSDNode>(N0.getOperand(3))->getVT();
@@ -1258,8 +1247,6 @@ SDOperand DAGCombiner::visitXOR(SDNode *N) {
 SDOperand DAGCombiner::visitSHL(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   SDOperand N1 = N->getOperand(1);
-  SDOperand Old = SDOperand();
-  SDOperand New = SDOperand();
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N0.getValueType();
@@ -1278,14 +1265,10 @@ SDOperand DAGCombiner::visitSHL(SDNode *N) {
   if (N1C && N1C->isNullValue())
     return N0;
   // if (shl x, c) is known to be zero, return 0
-  if (N1C && TLI.MaskedValueIsZero(SDOperand(N, 0), ~0ULL >> (64-OpSizeInBits)))
+  if (TLI.MaskedValueIsZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT)))
     return DAG.getConstant(0, VT);
-  if (N1C && DemandedBitsAreZero(SDOperand(N,0), ~0ULL >> (64-OpSizeInBits),
-                                 Old, New)) {
-    WorkList.push_back(N);
-    CombineTo(Old.Val, New);
+  if (DemandedBitsAreZero(SDOperand(N,0), MVT::getIntVTBitMask(VT)))
     return SDOperand();
-  }
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, c1+c2)
   if (N1C && N0.getOpcode() == ISD::SHL && 
       N0.getOperand(1).getOpcode() == ISD::Constant) {
@@ -1324,7 +1307,6 @@ SDOperand DAGCombiner::visitSRA(SDNode *N) {
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N0.getValueType();
-  unsigned OpSizeInBits = MVT::getSizeInBits(VT);
   
   // fold (sra c1, c2) -> c1>>c2
   if (N0C && N1C)
@@ -1336,13 +1318,29 @@ SDOperand DAGCombiner::visitSRA(SDNode *N) {
   if (N0C && N0C->isAllOnesValue())
     return N0;
   // fold (sra x, c >= size(x)) -> undef
-  if (N1C && N1C->getValue() >= OpSizeInBits)
+  if (N1C && N1C->getValue() >= MVT::getSizeInBits(VT))
     return DAG.getNode(ISD::UNDEF, VT);
   // fold (sra x, 0) -> x
   if (N1C && N1C->isNullValue())
     return N0;
+  // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
+  // sext_inreg.
+  if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
+    unsigned LowBits = MVT::getSizeInBits(VT) - (unsigned)N1C->getValue();
+    MVT::ValueType EVT;
+    switch (LowBits) {
+    default: EVT = MVT::Other; break;
+    case  1: EVT = MVT::i1;    break;
+    case  8: EVT = MVT::i8;    break;
+    case 16: EVT = MVT::i16;   break;
+    case 32: EVT = MVT::i32;   break;
+    }
+    if (EVT > MVT::Other && TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, EVT))
+      return DAG.getNode(ISD::SIGN_EXTEND_INREG, VT, N0.getOperand(0),
+                         DAG.getValueType(EVT));
+  }
   // If the sign bit is known to be zero, switch this to a SRL.
-  if (TLI.MaskedValueIsZero(N0, (1ULL << (OpSizeInBits-1))))
+  if (TLI.MaskedValueIsZero(N0, MVT::getIntVTSignBit(VT)))
     return DAG.getNode(ISD::SRL, VT, N0, N1);
   return SDOperand();
 }
@@ -1509,46 +1507,6 @@ SDOperand DAGCombiner::visitSETCC(SDNode *N) {
                        cast<CondCodeSDNode>(N->getOperand(2))->get());
 }
 
-SDOperand DAGCombiner::visitADD_PARTS(SDNode *N) {
-  SDOperand LHSLo = N->getOperand(0);
-  SDOperand RHSLo = N->getOperand(2);
-  MVT::ValueType VT = LHSLo.getValueType();
-  
-  // fold (a_Hi, 0) + (b_Hi, b_Lo) -> (b_Hi + a_Hi, b_Lo)
-  if (TLI.MaskedValueIsZero(LHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::ADD, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, RHSLo, Hi);
-    return SDOperand();
-  }
-  // fold (a_Hi, a_Lo) + (b_Hi, 0) -> (a_Hi + b_Hi, a_Lo)
-  if (TLI.MaskedValueIsZero(RHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::ADD, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, LHSLo, Hi);
-    return SDOperand();
-  }
-  return SDOperand();
-}
-
-SDOperand DAGCombiner::visitSUB_PARTS(SDNode *N) {
-  SDOperand LHSLo = N->getOperand(0);
-  SDOperand RHSLo = N->getOperand(2);
-  MVT::ValueType VT = LHSLo.getValueType();
-  
-  // fold (a_Hi, a_Lo) - (b_Hi, 0) -> (a_Hi - b_Hi, a_Lo)
-  if (TLI.MaskedValueIsZero(RHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::SUB, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, LHSLo, Hi);
-    return SDOperand();
-  }
-  return SDOperand();
-}
-
 SDOperand DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);