SelectionDAG: Cleanup integer bin op promotion functions.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / LegalizeIntegerTypes.cpp
index ea537fff1688840dd704abd347cd412302fba166..cd114d668e20371468b7b0689c5e95f48f9d3696 100644 (file)
@@ -66,16 +66,20 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::CTTZ:        Res = PromoteIntRes_CTTZ(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
-  case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N));break;
-  case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));break;
+  case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
+  case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));
+    break;
+  case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
+    break;
   case ISD::SELECT:      Res = PromoteIntRes_SELECT(N); break;
   case ISD::VSELECT:     Res = PromoteIntRes_VSELECT(N); break;
   case ISD::SELECT_CC:   Res = PromoteIntRes_SELECT_CC(N); break;
   case ISD::SETCC:       Res = PromoteIntRes_SETCC(N); break;
   case ISD::SMIN:
-  case ISD::SMAX:
+  case ISD::SMAX:        Res = PromoteIntRes_SExtIntBinOp(N); break;
   case ISD::UMIN:
-  case ISD::UMAX:        Res = PromoteIntRes_SimpleIntBinOp(N); break;
+  case ISD::UMAX:        Res = PromoteIntRes_ZExtIntBinOp(N); break;
+
   case ISD::SHL:         Res = PromoteIntRes_SHL(N); break;
   case ISD::SIGN_EXTEND_INREG:
                          Res = PromoteIntRes_SIGN_EXTEND_INREG(N); break;
@@ -115,10 +119,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::MUL:         Res = PromoteIntRes_SimpleIntBinOp(N); break;
 
   case ISD::SDIV:
-  case ISD::SREM:        Res = PromoteIntRes_SDIV(N); break;
+  case ISD::SREM:        Res = PromoteIntRes_SExtIntBinOp(N); break;
 
   case ISD::UDIV:
-  case ISD::UREM:        Res = PromoteIntRes_UDIV(N); break;
+  case ISD::UREM:        Res = PromoteIntRes_ZExtIntBinOp(N); break;
 
   case ISD::SADDO:
   case ISD::SSUBO:       Res = PromoteIntRes_SADDSUBO(N, ResNo); break;
@@ -147,10 +151,6 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::ATOMIC_CMP_SWAP_WITH_SUCCESS:
     Res = PromoteIntRes_AtomicCmpSwap(cast<AtomicSDNode>(N), ResNo);
     break;
-  case ISD::UABSDIFF:
-  case ISD::SABSDIFF:
-    Res = PromoteIntRes_SimpleIntBinOp(N);
-    break;
   }
 
   // If the result is null then the sub-method took care of registering it.
@@ -185,7 +185,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
                               N->getChain(), N->getBasePtr(),
                               N->getMemOperand(), N->getOrdering(),
                               N->getSynchScope());
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -198,7 +198,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) {
                               N->getChain(), N->getBasePtr(),
                               Op2, N->getMemOperand(), N->getOrdering(),
                               N->getSynchScope());
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -483,7 +483,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_LOAD(LoadSDNode *N) {
   SDValue Res = DAG.getExtLoad(ExtType, dl, NVT, N->getChain(), N->getBasePtr(),
                                N->getMemoryVT(), N->getMemOperand());
 
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -493,20 +493,34 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue ExtSrc0 = GetPromotedInteger(N->getSrc0());
 
-  SDValue Mask = N->getMask();
-  EVT NewMaskVT = getSetCCResultType(NVT);
-  if (NewMaskVT != N->getMask().getValueType())
-    Mask = PromoteTargetBoolean(Mask, NewMaskVT);
   SDLoc dl(N);
-
   SDValue Res = DAG.getMaskedLoad(NVT, dl, N->getChain(), N->getBasePtr(),
-                                  Mask, ExtSrc0, N->getMemoryVT(),
+                                  N->getMask(), ExtSrc0, N->getMemoryVT(),
                                   N->getMemOperand(), ISD::SEXTLOAD);
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
+SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDValue ExtSrc0 = GetPromotedInteger(N->getValue());
+  assert(NVT == ExtSrc0.getValueType() &&
+      "Gather result type and the passThru agrument type should be the same");
+
+  SDLoc dl(N);
+  SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(),
+                   N->getIndex()};
+  SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
+                                    N->getMemoryVT(), dl, Ops,
+                                    N->getMemOperand()); 
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
 }
+
 /// Promote the overflow flag of an overflowing arithmetic node.
 SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
   // Simply change the return type of the boolean result.
@@ -552,14 +566,6 @@ SDValue DAGTypeLegalizer::PromoteIntRes_SADDSUBO(SDNode *N, unsigned ResNo) {
   return Res;
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_SDIV(SDNode *N) {
-  // Sign extend the input.
-  SDValue LHS = SExtPromotedInteger(N->getOperand(0));
-  SDValue RHS = SExtPromotedInteger(N->getOperand(1));
-  return DAG.getNode(N->getOpcode(), SDLoc(N),
-                     LHS.getValueType(), LHS, RHS);
-}
-
 SDValue DAGTypeLegalizer::PromoteIntRes_SELECT(SDNode *N) {
   SDValue LHS = GetPromotedInteger(N->getOperand(1));
   SDValue RHS = GetPromotedInteger(N->getOperand(2));
@@ -647,6 +653,22 @@ SDValue DAGTypeLegalizer::PromoteIntRes_SimpleIntBinOp(SDNode *N) {
                      LHS.getValueType(), LHS, RHS);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_SExtIntBinOp(SDNode *N) {
+  // Sign extend the input.
+  SDValue LHS = SExtPromotedInteger(N->getOperand(0));
+  SDValue RHS = SExtPromotedInteger(N->getOperand(1));
+  return DAG.getNode(N->getOpcode(), SDLoc(N),
+                     LHS.getValueType(), LHS, RHS);
+}
+
+SDValue DAGTypeLegalizer::PromoteIntRes_ZExtIntBinOp(SDNode *N) {
+  // Zero extend the input.
+  SDValue LHS = ZExtPromotedInteger(N->getOperand(0));
+  SDValue RHS = ZExtPromotedInteger(N->getOperand(1));
+  return DAG.getNode(N->getOpcode(), SDLoc(N),
+                     LHS.getValueType(), LHS, RHS);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_SRA(SDNode *N) {
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
@@ -788,14 +810,6 @@ SDValue DAGTypeLegalizer::PromoteIntRes_XMULO(SDNode *N, unsigned ResNo) {
   return Mul;
 }
 
-SDValue DAGTypeLegalizer::PromoteIntRes_UDIV(SDNode *N) {
-  // Zero extend the input.
-  SDValue LHS = ZExtPromotedInteger(N->getOperand(0));
-  SDValue RHS = ZExtPromotedInteger(N->getOperand(1));
-  return DAG.getNode(N->getOpcode(), SDLoc(N),
-                     LHS.getValueType(), LHS, RHS);
-}
-
 SDValue DAGTypeLegalizer::PromoteIntRes_UNDEF(SDNode *N) {
   return DAG.getUNDEF(TLI.getTypeToTransformTo(*DAG.getContext(),
                                                N->getValueType(0)));
@@ -893,6 +907,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
                                                     OpNo); break;
   case ISD::MLOAD:        Res = PromoteIntOp_MLOAD(cast<MaskedLoadSDNode>(N),
                                                     OpNo); break;
+  case ISD::MGATHER:  Res = PromoteIntOp_MGATHER(cast<MaskedGatherSDNode>(N),
+                                                 OpNo); break;
+  case ISD::MSCATTER: Res = PromoteIntOp_MSCATTER(cast<MaskedScatterSDNode>(N),
+                                                  OpNo); break;
   case ISD::TRUNCATE:     Res = PromoteIntOp_TRUNCATE(N); break;
   case ISD::FP16_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
@@ -1161,56 +1179,49 @@ SDValue DAGTypeLegalizer::PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo){
                            N->getMemoryVT(), N->getMemOperand());
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo){
+SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
+                                              unsigned OpNo) {
 
   SDValue DataOp = N->getValue();
   EVT DataVT = DataOp.getValueType();
   SDValue Mask = N->getMask();
-  EVT MaskVT = Mask.getValueType();
   SDLoc dl(N);
 
   bool TruncateStore = false;
-  if (!TLI.isTypeLegal(DataVT)) {
-    if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger) {
-      DataOp = GetPromotedInteger(DataOp);
-      if (!TLI.isTypeLegal(MaskVT))
-        Mask = PromoteTargetBoolean(Mask, DataOp.getValueType());
-      TruncateStore = true;
-    }
+  if (OpNo == 2) {
+    // Mask comes before the data operand. If the data operand is legal, we just
+    // promote the mask.
+    // When the data operand has illegal type, we should legalize the data
+    // operand first. The mask will be promoted/splitted/widened according to
+    // the data operand type.
+    if (TLI.isTypeLegal(DataVT))
+      Mask = PromoteTargetBoolean(Mask, DataVT);
     else {
-      assert(getTypeAction(DataVT) == TargetLowering::TypeWidenVector &&
-             "Unexpected data legalization in MSTORE");
-      DataOp = GetWidenedVector(DataOp);
-
-      if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
-        Mask = GetWidenedVector(Mask);
-      else {
-        EVT BoolVT = getSetCCResultType(DataOp.getValueType());
+      if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger)
+        return PromoteIntOp_MSTORE(N, 3);
 
-        // We can't use ModifyToType() because we should fill the mask with
-        // zeroes
-        unsigned WidenNumElts = BoolVT.getVectorNumElements();
-        unsigned MaskNumElts = MaskVT.getVectorNumElements();
+      else if (getTypeAction(DataVT) == TargetLowering::TypeWidenVector)
+        return WidenVecOp_MSTORE(N, 3);
 
-        unsigned NumConcat = WidenNumElts / MaskNumElts;
-        SmallVector<SDValue, 16> Ops(NumConcat);
-        SDValue ZeroVal = DAG.getConstant(0, dl, MaskVT);
-        Ops[0] = Mask;
-        for (unsigned i = 1; i != NumConcat; ++i)
-          Ops[i] = ZeroVal;
-
-        Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
+      else {
+        assert (getTypeAction(DataVT) == TargetLowering::TypeSplitVector);
+        return SplitVecOp_MSTORE(N, 3);
       }
     }
+  } else { // Data operand
+    assert(OpNo == 3 && "Unexpected operand for promotion");
+    DataOp = GetPromotedInteger(DataOp);
+    Mask = PromoteTargetBoolean(Mask, DataOp.getValueType());
+    TruncateStore = true;
   }
-  else
-    Mask = PromoteTargetBoolean(N->getMask(), DataOp.getValueType());
+
   return DAG.getMaskedStore(N->getChain(), dl, DataOp, N->getBasePtr(), Mask,
                             N->getMemoryVT(), N->getMemOperand(),
                             TruncateStore);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo){
+SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N,
+                                             unsigned OpNo) {
   assert(OpNo == 2 && "Only know how to promote the mask!");
   EVT DataVT = N->getValueType(0);
   SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
@@ -1219,6 +1230,31 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo)
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_MGATHER(MaskedGatherSDNode *N,
+                                               unsigned OpNo) {
+
+  SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+  if (OpNo == 2) {
+    // The Mask
+    EVT DataVT = N->getValueType(0);
+    NewOps[OpNo] = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
+  } else
+    NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
+SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
+                                                unsigned OpNo) {
+  SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+  if (OpNo == 2) {
+    // The Mask
+    EVT DataVT = N->getValue().getValueType();
+    NewOps[OpNo] = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
+  } else
+    NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   return DAG.getNode(ISD::TRUNCATE, SDLoc(N), N->getValueType(0), Op);
@@ -2075,7 +2111,7 @@ void DAGTypeLegalizer::ExpandIntRes_LOAD(LoadSDNode *N,
     }
   }
 
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Ch);
 }