Clean up DemandedBitsAreZero interface
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 8287f2d0965673c18508806d7f105d654970c209..f1d0a9b5065efbddc5c664084fe13bbba310574d 100644 (file)
@@ -98,6 +98,17 @@ namespace {
       DAG.DeleteNode(N);
       return SDOperand(N, 0);
     }
+    
+    bool DemandedBitsAreZero(SDOperand Op, uint64_t DemandedMask) {
+      TargetLowering::TargetLoweringOpt TLO(DAG);
+      uint64_t KnownZero, KnownOne;
+      if (TLI.SimplifyDemandedBits(Op, DemandedMask, KnownZero, KnownOne, TLO)){
+        WorkList.push_back(Op.Val);
+        CombineTo(TLO.Old.Val, TLO.New);
+        return true;
+      }
+      return false;
+    }
 
     SDOperand CombineTo(SDNode *N, SDOperand Res) {
       std::vector<SDOperand> To;
@@ -145,14 +156,11 @@ namespace {
     SDOperand visitSELECT(SDNode *N);
     SDOperand visitSELECT_CC(SDNode *N);
     SDOperand visitSETCC(SDNode *N);
-    SDOperand visitADD_PARTS(SDNode *N);
-    SDOperand visitSUB_PARTS(SDNode *N);
     SDOperand visitSIGN_EXTEND(SDNode *N);
     SDOperand visitZERO_EXTEND(SDNode *N);
     SDOperand visitSIGN_EXTEND_INREG(SDNode *N);
     SDOperand visitTRUNCATE(SDNode *N);
     SDOperand visitBIT_CONVERT(SDNode *N);
-    
     SDOperand visitFADD(SDNode *N);
     SDOperand visitFSUB(SDNode *N);
     SDOperand visitFMUL(SDNode *N);
@@ -171,13 +179,9 @@ namespace {
     SDOperand visitBRCONDTWOWAY(SDNode *N);
     SDOperand visitBR_CC(SDNode *N);
     SDOperand visitBRTWOWAY_CC(SDNode *N);
-
     SDOperand visitLOAD(SDNode *N);
     SDOperand visitSTORE(SDNode *N);
 
-    SDOperand visitLOCATION(SDNode *N);
-    SDOperand visitDEBUGLOC(SDNode *N);
-
     SDOperand ReassociateOps(unsigned Opc, SDOperand LHS, SDOperand RHS);
     
     bool SimplifySelectOps(SDNode *SELECT, SDOperand LHS, SDOperand RHS);
@@ -541,8 +545,6 @@ SDOperand DAGCombiner::visit(SDNode *N) {
   case ISD::SELECT:             return visitSELECT(N);
   case ISD::SELECT_CC:          return visitSELECT_CC(N);
   case ISD::SETCC:              return visitSETCC(N);
-  case ISD::ADD_PARTS:          return visitADD_PARTS(N);
-  case ISD::SUB_PARTS:          return visitSUB_PARTS(N);
   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
   case ISD::SIGN_EXTEND_INREG:  return visitSIGN_EXTEND_INREG(N);
@@ -568,8 +570,6 @@ SDOperand DAGCombiner::visit(SDNode *N) {
   case ISD::BRTWOWAY_CC:        return visitBRTWOWAY_CC(N);
   case ISD::LOAD:               return visitLOAD(N);
   case ISD::STORE:              return visitSTORE(N);
-  case ISD::LOCATION:           return visitLOCATION(N);
-  case ISD::DEBUG_LOC:          return visitDEBUGLOC(N);
   }
   return SDOperand();
 }
@@ -730,8 +730,8 @@ SDOperand DAGCombiner::visitSDIV(SDNode *N) {
   if (TLI.MaskedValueIsZero(N1, SignBit) &&
       TLI.MaskedValueIsZero(N0, SignBit))
     return DAG.getNode(ISD::UDIV, N1.getValueType(), N0, N1);
-  // fold (sdiv X, pow2) -> (add (sra X, log(pow2)), (srl X, sizeof(X)-1))
-  if (N1C && N1C->getValue() && !TLI.isIntDivCheap() && 
+  // fold (sdiv X, pow2) -> simple ops after legalize
+  if (N1C && N1C->getValue() && !TLI.isIntDivCheap() &&
       (isPowerOf2_64(N1C->getSignExtended()) || 
        isPowerOf2_64(-N1C->getSignExtended()))) {
     // If dividing by powers of two is cheap, then don't perform the following
@@ -740,15 +740,21 @@ SDOperand DAGCombiner::visitSDIV(SDNode *N) {
       return SDOperand();
     int64_t pow2 = N1C->getSignExtended();
     int64_t abs2 = pow2 > 0 ? pow2 : -pow2;
-    SDOperand SRL = DAG.getNode(ISD::SRL, VT, N0,
+    unsigned lg2 = Log2_64(abs2);
+    // Splat the sign bit into the register
+    SDOperand SGN = DAG.getNode(ISD::SRA, VT, N0,
                                 DAG.getConstant(MVT::getSizeInBits(VT)-1,
                                                 TLI.getShiftAmountTy()));
-    WorkList.push_back(SRL.Val);
-    SDOperand SGN = DAG.getNode(ISD::ADD, VT, N0, SRL);
     WorkList.push_back(SGN.Val);
-    SDOperand SRA = DAG.getNode(ISD::SRA, VT, SGN, 
-                                DAG.getConstant(Log2_64(abs2),
+    // Add (N0 < 0) ? abs2 - 1 : 0;
+    SDOperand SRL = DAG.getNode(ISD::SRL, VT, SGN,
+                                DAG.getConstant(MVT::getSizeInBits(VT)-lg2,
                                                 TLI.getShiftAmountTy()));
+    SDOperand ADD = DAG.getNode(ISD::ADD, VT, N0, SRL);
+    WorkList.push_back(SRL.Val);
+    WorkList.push_back(ADD.Val);    // Divide by pow2
+    SDOperand SRA = DAG.getNode(ISD::SRA, VT, ADD,
+                                DAG.getConstant(lg2, TLI.getShiftAmountTy()));
     // If we're dividing by a positive value, we're done.  Otherwise, we must
     // negate the result.
     if (pow2 > 0)
@@ -778,15 +784,27 @@ SDOperand DAGCombiner::visitUDIV(SDNode *N) {
     return DAG.getNode(ISD::UDIV, VT, N0, N1);
   // fold (udiv x, (1 << c)) -> x >>u c
   if (N1C && isPowerOf2_64(N1C->getValue()))
-    return DAG.getNode(ISD::SRL, N->getValueType(0), N0,
+    return DAG.getNode(ISD::SRL, VT, N0, 
                        DAG.getConstant(Log2_64(N1C->getValue()),
                                        TLI.getShiftAmountTy()));
+  // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
+  if (N1.getOpcode() == ISD::SHL) {
+    if (ConstantSDNode *SHC = dyn_cast<ConstantSDNode>(N1.getOperand(0))) {
+      if (isPowerOf2_64(SHC->getValue())) {
+        MVT::ValueType ADDVT = N1.getOperand(1).getValueType();
+        SDOperand Add = DAG.getNode(ISD::ADD, ADDVT, N1.getOperand(1),
+                                    DAG.getConstant(Log2_64(SHC->getValue()),
+                                                    ADDVT));
+        WorkList.push_back(Add.Val);
+        return DAG.getNode(ISD::SRL, VT, N0, Add);
+      }
+    }
+  }
   // fold (udiv x, c) -> alternate
   if (N1C && N1C->getValue() && !TLI.isIntDivCheap()) {
     SDOperand Op = BuildUDIV(N);
     if (Op.Val) return Op;
   }
-      
   return SDOperand();
 }
 
@@ -822,6 +840,16 @@ SDOperand DAGCombiner::visitUREM(SDNode *N) {
   // fold (urem x, pow2) -> (and x, pow2-1)
   if (N1C && !N1C->isNullValue() && isPowerOf2_64(N1C->getValue()))
     return DAG.getNode(ISD::AND, VT, N0, DAG.getConstant(N1C->getValue()-1,VT));
+  // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
+  if (N1.getOpcode() == ISD::SHL) {
+    if (ConstantSDNode *SHC = dyn_cast<ConstantSDNode>(N1.getOperand(0))) {
+      if (isPowerOf2_64(SHC->getValue())) {
+        SDOperand Add = DAG.getNode(ISD::ADD, VT, N1,DAG.getConstant(~0ULL,VT));
+        WorkList.push_back(Add.Val);
+        return DAG.getNode(ISD::AND, VT, N0, Add);
+      }
+    }
+  }
   return SDOperand();
 }
 
@@ -858,7 +886,7 @@ SDOperand DAGCombiner::visitMULHU(SDNode *N) {
 SDOperand DAGCombiner::visitAND(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   SDOperand N1 = N->getOperand(1);
-  SDOperand LL, LR, RL, RR, CC0, CC1, Old, New;
+  SDOperand LL, LR, RL, RR, CC0, CC1;
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N1.getValueType();
@@ -874,12 +902,8 @@ SDOperand DAGCombiner::visitAND(SDNode *N) {
   if (N1C && N1C->isAllOnesValue())
     return N0;
   // if (and x, c) is known to be zero, return 0
-  if (N1C && TLI.MaskedValueIsZero(SDOperand(N, 0), ~0ULL >> (64-OpSizeInBits)))
+  if (N1C && TLI.MaskedValueIsZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT)))
     return DAG.getConstant(0, VT);
-  // fold (and x, c) -> x iff (x & ~c) == 0
-  if (N1C && 
-      TLI.MaskedValueIsZero(N0, ~N1C->getValue() & (~0ULL>>(64-OpSizeInBits))))
-    return N0;
   // reassociate and
   SDOperand RAND = ReassociateOps(ISD::AND, N0, N1);
   if (RAND.Val != 0)
@@ -961,38 +985,8 @@ SDOperand DAGCombiner::visitAND(SDNode *N) {
   }
   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
   // fold (and (sra)) -> (and (srl)) when possible.
-  if (TLI.DemandedBitsAreZero(SDOperand(N, 0), ~0ULL >> (64-OpSizeInBits), Old, 
-                              New, DAG)) {
-    WorkList.push_back(N);
-    CombineTo(Old.Val, New);
+  if (DemandedBitsAreZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT)))
     return SDOperand();
-  }
-  // FIXME: DemandedBitsAreZero cannot currently handle AND with non-constant
-  // RHS and propagate known cleared bits to LHS.  For this reason, we must keep
-  // this fold, for now, for the following testcase:
-  //
-  //int %test2(uint %mode.0.i.0) {
-  //  %tmp.79 = cast uint %mode.0.i.0 to int
-  //  %tmp.80 = shr int %tmp.79, ubyte 15
-  //  %tmp.81 = shr uint %mode.0.i.0, ubyte 16
-  //  %tmp.82 = cast uint %tmp.81 to int
-  //  %tmp.83 = and int %tmp.80, %tmp.82
-  //  ret int %tmp.83
-  //}
-  // fold (and (sra)) -> (and (srl)) when possible.
-  if (N0.getOpcode() == ISD::SRA && N0.Val->hasOneUse()) {
-    if (ConstantSDNode *N01C = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
-      // If the RHS of the AND has zeros where the sign bits of the SRA will
-      // land, turn the SRA into an SRL.
-      if (TLI.MaskedValueIsZero(N1, (~0ULL << (OpSizeInBits-N01C->getValue())) &
-                                (~0ULL>>(64-OpSizeInBits)))) {
-        WorkList.push_back(N);
-        CombineTo(N0.Val, DAG.getNode(ISD::SRL, VT, N0.getOperand(0),
-                                      N0.getOperand(1)));
-        return SDOperand();
-      }
-    }
-  }
   // fold (zext_inreg (extload x)) -> (zextload x)
   if (N0.getOpcode() == ISD::EXTLOAD) {
     MVT::ValueType EVT = cast<VTSDNode>(N0.getOperand(3))->getVT();
@@ -1253,8 +1247,6 @@ SDOperand DAGCombiner::visitXOR(SDNode *N) {
 SDOperand DAGCombiner::visitSHL(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   SDOperand N1 = N->getOperand(1);
-  SDOperand Old = SDOperand();
-  SDOperand New = SDOperand();
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N0.getValueType();
@@ -1273,14 +1265,10 @@ SDOperand DAGCombiner::visitSHL(SDNode *N) {
   if (N1C && N1C->isNullValue())
     return N0;
   // if (shl x, c) is known to be zero, return 0
-  if (N1C && TLI.MaskedValueIsZero(SDOperand(N, 0), ~0ULL >> (64-OpSizeInBits)))
+  if (TLI.MaskedValueIsZero(SDOperand(N, 0), MVT::getIntVTBitMask(VT)))
     return DAG.getConstant(0, VT);
-  if (N1C && TLI.DemandedBitsAreZero(SDOperand(N,0), ~0ULL >> (64-OpSizeInBits),
-                                     Old, New, DAG)) {
-    WorkList.push_back(N);
-    CombineTo(Old.Val, New);
+  if (DemandedBitsAreZero(SDOperand(N,0), MVT::getIntVTBitMask(VT)))
     return SDOperand();
-  }
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, c1+c2)
   if (N1C && N0.getOpcode() == ISD::SHL && 
       N0.getOperand(1).getOpcode() == ISD::Constant) {
@@ -1319,7 +1307,6 @@ SDOperand DAGCombiner::visitSRA(SDNode *N) {
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   MVT::ValueType VT = N0.getValueType();
-  unsigned OpSizeInBits = MVT::getSizeInBits(VT);
   
   // fold (sra c1, c2) -> c1>>c2
   if (N0C && N1C)
@@ -1331,13 +1318,29 @@ SDOperand DAGCombiner::visitSRA(SDNode *N) {
   if (N0C && N0C->isAllOnesValue())
     return N0;
   // fold (sra x, c >= size(x)) -> undef
-  if (N1C && N1C->getValue() >= OpSizeInBits)
+  if (N1C && N1C->getValue() >= MVT::getSizeInBits(VT))
     return DAG.getNode(ISD::UNDEF, VT);
   // fold (sra x, 0) -> x
   if (N1C && N1C->isNullValue())
     return N0;
+  // fold (sra (shl x, c1), c1) -> sext_inreg for some c1 and target supports
+  // sext_inreg.
+  if (N1C && N0.getOpcode() == ISD::SHL && N1 == N0.getOperand(1)) {
+    unsigned LowBits = MVT::getSizeInBits(VT) - (unsigned)N1C->getValue();
+    MVT::ValueType EVT;
+    switch (LowBits) {
+    default: EVT = MVT::Other; break;
+    case  1: EVT = MVT::i1;    break;
+    case  8: EVT = MVT::i8;    break;
+    case 16: EVT = MVT::i16;   break;
+    case 32: EVT = MVT::i32;   break;
+    }
+    if (EVT > MVT::Other && TLI.isOperationLegal(ISD::SIGN_EXTEND_INREG, EVT))
+      return DAG.getNode(ISD::SIGN_EXTEND_INREG, VT, N0.getOperand(0),
+                         DAG.getValueType(EVT));
+  }
   // If the sign bit is known to be zero, switch this to a SRL.
-  if (TLI.MaskedValueIsZero(N0, (1ULL << (OpSizeInBits-1))))
+  if (TLI.MaskedValueIsZero(N0, MVT::getIntVTSignBit(VT)))
     return DAG.getNode(ISD::SRL, VT, N0, N1);
   return SDOperand();
 }
@@ -1504,46 +1507,6 @@ SDOperand DAGCombiner::visitSETCC(SDNode *N) {
                        cast<CondCodeSDNode>(N->getOperand(2))->get());
 }
 
-SDOperand DAGCombiner::visitADD_PARTS(SDNode *N) {
-  SDOperand LHSLo = N->getOperand(0);
-  SDOperand RHSLo = N->getOperand(2);
-  MVT::ValueType VT = LHSLo.getValueType();
-  
-  // fold (a_Hi, 0) + (b_Hi, b_Lo) -> (b_Hi + a_Hi, b_Lo)
-  if (TLI.MaskedValueIsZero(LHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::ADD, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, RHSLo, Hi);
-    return SDOperand();
-  }
-  // fold (a_Hi, a_Lo) + (b_Hi, 0) -> (a_Hi + b_Hi, a_Lo)
-  if (TLI.MaskedValueIsZero(RHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::ADD, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, LHSLo, Hi);
-    return SDOperand();
-  }
-  return SDOperand();
-}
-
-SDOperand DAGCombiner::visitSUB_PARTS(SDNode *N) {
-  SDOperand LHSLo = N->getOperand(0);
-  SDOperand RHSLo = N->getOperand(2);
-  MVT::ValueType VT = LHSLo.getValueType();
-  
-  // fold (a_Hi, a_Lo) - (b_Hi, 0) -> (a_Hi - b_Hi, a_Lo)
-  if (TLI.MaskedValueIsZero(RHSLo, (1ULL << MVT::getSizeInBits(VT))-1)) {
-    SDOperand Hi = DAG.getNode(ISD::SUB, VT, N->getOperand(1),
-                               N->getOperand(3));
-    WorkList.push_back(Hi.Val);
-    CombineTo(N, LHSLo, Hi);
-    return SDOperand();
-  }
-  return SDOperand();
-}
-
 SDOperand DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   SDOperand N0 = N->getOperand(0);
   ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
@@ -2140,35 +2103,6 @@ SDOperand DAGCombiner::visitSTORE(SDNode *N) {
   return SDOperand();
 }
 
-SDOperand DAGCombiner::visitLOCATION(SDNode *N) {
-  SDOperand Chain    = N->getOperand(0);
-  
-  // Remove redundant locations (last one holds)
-  if (Chain.getOpcode() == ISD::LOCATION && Chain.hasOneUse()) {
-    return DAG.getNode(ISD::LOCATION, MVT::Other, Chain.getOperand(0),
-                                                  N->getOperand(1),
-                                                  N->getOperand(2),
-                                                  N->getOperand(3),
-                                                  N->getOperand(4));
-  }
-  
-  return SDOperand();
-}
-
-SDOperand DAGCombiner::visitDEBUGLOC(SDNode *N) {
-  SDOperand Chain    = N->getOperand(0);
-  
-  // Remove redundant debug locations (last one holds)
-  if (Chain.getOpcode() == ISD::DEBUG_LOC && Chain.hasOneUse()) {
-    return DAG.getNode(ISD::DEBUG_LOC, MVT::Other, Chain.getOperand(0),
-                                                   N->getOperand(1),
-                                                   N->getOperand(2),
-                                                   N->getOperand(3));
-  }
-  
-  return SDOperand();
-}
-
 SDOperand DAGCombiner::SimplifySelect(SDOperand N0, SDOperand N1, SDOperand N2){
   assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
   
@@ -2522,6 +2456,32 @@ SDOperand DAGCombiner::SimplifySetCC(MVT::ValueType VT, SDOperand N0,
                             DAG.getConstant(C1 & (~0ULL>>(64-ExtSrcTyBits)), 
                                             ExtDstTy),
                             Cond);
+      } else if ((N1C->getValue() == 0 || N1C->getValue() == 1) &&
+                 (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
+                 (N0.getOpcode() == ISD::XOR ||
+                  (N0.getOpcode() == ISD::AND && 
+                   N0.getOperand(0).getOpcode() == ISD::XOR &&
+                   N0.getOperand(1) == N0.getOperand(0).getOperand(1))) &&
+                 isa<ConstantSDNode>(N0.getOperand(1)) &&
+                 cast<ConstantSDNode>(N0.getOperand(1))->getValue() == 1) {
+        // If this is (X^1) == 0/1, swap the RHS and eliminate the xor.  We can
+        // only do this if the top bits are known zero.
+        if (TLI.MaskedValueIsZero(N1, 
+                                  MVT::getIntVTBitMask(N0.getValueType())-1)) {
+          // Okay, get the un-inverted input value.
+          SDOperand Val;
+          if (N0.getOpcode() == ISD::XOR)
+            Val = N0.getOperand(0);
+          else {
+            assert(N0.getOpcode() == ISD::AND && 
+                   N0.getOperand(0).getOpcode() == ISD::XOR);
+            // ((X^1)&1)^1 -> X & 1
+            Val = DAG.getNode(ISD::AND, N0.getValueType(),
+                              N0.getOperand(0).getOperand(0), N0.getOperand(1));
+          }
+          return DAG.getSetCC(VT, Val, N1,
+                              Cond == ISD::SETEQ ? ISD::SETNE : ISD::SETEQ);
+        }
       }
       
       uint64_t MinVal, MaxVal;