Fix broken type legalization of min/max
[oota-llvm.git] / lib / CodeGen / SelectionDAG / LegalizeIntegerTypes.cpp
index 2cfcf77b17a58ac7434528630c1a493555e68359..3131ca1014587111ab8b25546613181bbcdc67ae 100644 (file)
@@ -66,16 +66,20 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::CTTZ:        Res = PromoteIntRes_CTTZ(N); break;
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
-  case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N));break;
-  case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));break;
+  case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
+  case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));
+    break;
+  case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
+    break;
   case ISD::SELECT:      Res = PromoteIntRes_SELECT(N); break;
   case ISD::VSELECT:     Res = PromoteIntRes_VSELECT(N); break;
   case ISD::SELECT_CC:   Res = PromoteIntRes_SELECT_CC(N); break;
   case ISD::SETCC:       Res = PromoteIntRes_SETCC(N); break;
   case ISD::SMIN:
-  case ISD::SMAX:
+  case ISD::SMAX:        Res = PromoteIntRes_SExtOrZExtIntBinOp(N, true); break;
   case ISD::UMIN:
-  case ISD::UMAX:        Res = PromoteIntRes_SimpleIntBinOp(N); break;
+  case ISD::UMAX:        Res = PromoteIntRes_SExtOrZExtIntBinOp(N, false); break;
+
   case ISD::SHL:         Res = PromoteIntRes_SHL(N); break;
   case ISD::SIGN_EXTEND_INREG:
                          Res = PromoteIntRes_SIGN_EXTEND_INREG(N); break;
@@ -181,7 +185,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic0(AtomicSDNode *N) {
                               N->getChain(), N->getBasePtr(),
                               N->getMemOperand(), N->getOrdering(),
                               N->getSynchScope());
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -194,7 +198,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_Atomic1(AtomicSDNode *N) {
                               N->getChain(), N->getBasePtr(),
                               Op2, N->getMemOperand(), N->getOrdering(),
                               N->getSynchScope());
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -479,7 +483,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_LOAD(LoadSDNode *N) {
   SDValue Res = DAG.getExtLoad(ExtType, dl, NVT, N->getChain(), N->getBasePtr(),
                                N->getMemoryVT(), N->getMemOperand());
 
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
@@ -489,20 +493,34 @@ SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue ExtSrc0 = GetPromotedInteger(N->getSrc0());
 
-  SDValue Mask = N->getMask();
-  EVT NewMaskVT = getSetCCResultType(NVT);
-  if (NewMaskVT != N->getMask().getValueType())
-    Mask = PromoteTargetBoolean(Mask, NewMaskVT);
   SDLoc dl(N);
-
   SDValue Res = DAG.getMaskedLoad(NVT, dl, N->getChain(), N->getBasePtr(),
-                                  Mask, ExtSrc0, N->getMemoryVT(),
+                                  N->getMask(), ExtSrc0, N->getMemoryVT(),
                                   N->getMemOperand(), ISD::SEXTLOAD);
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
+SDValue DAGTypeLegalizer::PromoteIntRes_MGATHER(MaskedGatherSDNode *N) {
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDValue ExtSrc0 = GetPromotedInteger(N->getValue());
+  assert(NVT == ExtSrc0.getValueType() &&
+      "Gather result type and the passThru agrument type should be the same");
+
+  SDLoc dl(N);
+  SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(),
+                   N->getIndex()};
+  SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
+                                    N->getMemoryVT(), dl, Ops,
+                                    N->getMemOperand()); 
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   return Res;
 }
+
 /// Promote the overflow flag of an overflowing arithmetic node.
 SDValue DAGTypeLegalizer::PromoteIntRes_Overflow(SDNode *N) {
   // Simply change the return type of the boolean result.
@@ -643,6 +661,22 @@ SDValue DAGTypeLegalizer::PromoteIntRes_SimpleIntBinOp(SDNode *N) {
                      LHS.getValueType(), LHS, RHS);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntRes_SExtOrZExtIntBinOp(SDNode *N,
+                                                           bool Signed) {
+  SDValue LHS, RHS;
+
+  if (Signed) {
+    LHS = SExtPromotedInteger(N->getOperand(0));
+    RHS = SExtPromotedInteger(N->getOperand(1));
+  } else {
+    LHS = ZExtPromotedInteger(N->getOperand(0));
+    RHS = ZExtPromotedInteger(N->getOperand(1));
+  }
+
+  return DAG.getNode(N->getOpcode(), SDLoc(N),
+                     LHS.getValueType(), LHS, RHS);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_SRA(SDNode *N) {
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
@@ -889,6 +923,10 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
                                                     OpNo); break;
   case ISD::MLOAD:        Res = PromoteIntOp_MLOAD(cast<MaskedLoadSDNode>(N),
                                                     OpNo); break;
+  case ISD::MGATHER:  Res = PromoteIntOp_MGATHER(cast<MaskedGatherSDNode>(N),
+                                                 OpNo); break;
+  case ISD::MSCATTER: Res = PromoteIntOp_MSCATTER(cast<MaskedScatterSDNode>(N),
+                                                  OpNo); break;
   case ISD::TRUNCATE:     Res = PromoteIntOp_TRUNCATE(N); break;
   case ISD::FP16_TO_FP:
   case ISD::UINT_TO_FP:   Res = PromoteIntOp_UINT_TO_FP(N); break;
@@ -1157,56 +1195,49 @@ SDValue DAGTypeLegalizer::PromoteIntOp_STORE(StoreSDNode *N, unsigned OpNo){
                            N->getMemoryVT(), N->getMemOperand());
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo){
+SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
+                                              unsigned OpNo) {
 
   SDValue DataOp = N->getValue();
   EVT DataVT = DataOp.getValueType();
   SDValue Mask = N->getMask();
-  EVT MaskVT = Mask.getValueType();
   SDLoc dl(N);
 
   bool TruncateStore = false;
-  if (!TLI.isTypeLegal(DataVT)) {
-    if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger) {
-      DataOp = GetPromotedInteger(DataOp);
-      if (!TLI.isTypeLegal(MaskVT))
-        Mask = PromoteTargetBoolean(Mask, DataOp.getValueType());
-      TruncateStore = true;
-    }
+  if (OpNo == 2) {
+    // Mask comes before the data operand. If the data operand is legal, we just
+    // promote the mask.
+    // When the data operand has illegal type, we should legalize the data
+    // operand first. The mask will be promoted/splitted/widened according to
+    // the data operand type.
+    if (TLI.isTypeLegal(DataVT))
+      Mask = PromoteTargetBoolean(Mask, DataVT);
     else {
-      assert(getTypeAction(DataVT) == TargetLowering::TypeWidenVector &&
-             "Unexpected data legalization in MSTORE");
-      DataOp = GetWidenedVector(DataOp);
+      if (getTypeAction(DataVT) == TargetLowering::TypePromoteInteger)
+        return PromoteIntOp_MSTORE(N, 3);
 
-      if (getTypeAction(MaskVT) == TargetLowering::TypeWidenVector)
-        Mask = GetWidenedVector(Mask);
-      else {
-        EVT BoolVT = getSetCCResultType(DataOp.getValueType());
-
-        // We can't use ModifyToType() because we should fill the mask with
-        // zeroes
-        unsigned WidenNumElts = BoolVT.getVectorNumElements();
-        unsigned MaskNumElts = MaskVT.getVectorNumElements();
+      else if (getTypeAction(DataVT) == TargetLowering::TypeWidenVector)
+        return WidenVecOp_MSTORE(N, 3);
 
-        unsigned NumConcat = WidenNumElts / MaskNumElts;
-        SmallVector<SDValue, 16> Ops(NumConcat);
-        SDValue ZeroVal = DAG.getConstant(0, dl, MaskVT);
-        Ops[0] = Mask;
-        for (unsigned i = 1; i != NumConcat; ++i)
-          Ops[i] = ZeroVal;
-
-        Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, BoolVT, Ops);
+      else {
+        assert (getTypeAction(DataVT) == TargetLowering::TypeSplitVector);
+        return SplitVecOp_MSTORE(N, 3);
       }
     }
+  } else { // Data operand
+    assert(OpNo == 3 && "Unexpected operand for promotion");
+    DataOp = GetPromotedInteger(DataOp);
+    Mask = PromoteTargetBoolean(Mask, DataOp.getValueType());
+    TruncateStore = true;
   }
-  else
-    Mask = PromoteTargetBoolean(N->getMask(), DataOp.getValueType());
+
   return DAG.getMaskedStore(N->getChain(), dl, DataOp, N->getBasePtr(), Mask,
                             N->getMemoryVT(), N->getMemOperand(),
                             TruncateStore);
 }
 
-SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo){
+SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N,
+                                             unsigned OpNo) {
   assert(OpNo == 2 && "Only know how to promote the mask!");
   EVT DataVT = N->getValueType(0);
   SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
@@ -1215,6 +1246,31 @@ SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo)
   return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
 }
 
+SDValue DAGTypeLegalizer::PromoteIntOp_MGATHER(MaskedGatherSDNode *N,
+                                               unsigned OpNo) {
+
+  SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+  if (OpNo == 2) {
+    // The Mask
+    EVT DataVT = N->getValueType(0);
+    NewOps[OpNo] = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
+  } else
+    NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
+SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
+                                                unsigned OpNo) {
+  SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+  if (OpNo == 2) {
+    // The Mask
+    EVT DataVT = N->getValue().getValueType();
+    NewOps[OpNo] = PromoteTargetBoolean(N->getOperand(OpNo), DataVT);
+  } else
+    NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
+  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   return DAG.getNode(ISD::TRUNCATE, SDLoc(N), N->getValueType(0), Op);
@@ -2071,7 +2127,7 @@ void DAGTypeLegalizer::ExpandIntRes_LOAD(LoadSDNode *N,
     }
   }
 
-  // Legalized the chain result - switch anything that used the old chain to
+  // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Ch);
 }