move combineRepeatedFPDivisors logic into a helper function; NFCI
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 52d620b1d540fc5958cf7cc8d63c645c6bea8203..b06d53310bbb67ef4ea67dc07ba6048a9b68e461 100644 (file)
@@ -338,6 +338,7 @@ namespace {
                                          unsigned HiOp);
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue CombineExtLoad(SDNode *N);
+    SDValue combineRepeatedFPDivisors(SDNode *N);
     SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
@@ -1216,9 +1217,8 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
   LegalTypes = Level >= AfterLegalizeTypes;
 
   // Add all the dag nodes to the worklist.
-  for (SelectionDAG::allnodes_iterator I = DAG.allnodes_begin(),
-       E = DAG.allnodes_end(); I != E; ++I)
-    AddToWorklist(I);
+  for (SDNode &Node : DAG.allnodes())
+    AddToWorklist(&Node);
 
   // Create a dummy node (which is not added to allnodes), that adds a reference
   // to the root node, preventing it from being deleted, and tracking any
@@ -2192,8 +2192,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
       return SDValue();
 
     // Target-specific implementation of sdiv x, pow2.
-    SDValue Res = BuildSDIVPow2(N);
-    if (Res.getNode())
+    if (SDValue Res = BuildSDIVPow2(N))
       return Res;
 
     unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
@@ -2229,10 +2228,9 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
 
   // If integer divide is expensive and we satisfy the requirements, emit an
   // alternate sequence.
-  if (N1C && !TLI.isIntDivCheap()) {
-    SDValue Op = BuildSDIV(N);
-    if (Op.getNode()) return Op;
-  }
+  if (N1C && !TLI.isIntDivCheap())
+    if (SDValue Op = BuildSDIV(N))
+      return Op;
 
   // undef / X -> 0
   if (N0.getOpcode() == ISD::UNDEF)
@@ -2285,10 +2283,9 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
     }
   }
   // fold (udiv x, c) -> alternate
-  if (N1C && !TLI.isIntDivCheap()) {
-    SDValue Op = BuildUDIV(N);
-    if (Op.getNode()) return Op;
-  }
+  if (N1C && !TLI.isIntDivCheap())
+    if (SDValue Op = BuildUDIV(N))
+      return Op;
 
   // undef / X -> 0
   if (N0.getOpcode() == ISD::UNDEF)
@@ -2532,8 +2529,8 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
 }
 
 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
+    return Res;
 
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
@@ -2563,8 +2560,8 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
+    return Res;
 
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
@@ -2614,15 +2611,15 @@ SDValue DAGCombiner::visitUMULO(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitSDIVREM(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::SDIV, ISD::SREM);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::SDIV, ISD::SREM))
+    return Res;
 
   return SDValue();
 }
 
 SDValue DAGCombiner::visitUDIVREM(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::UDIV, ISD::UREM);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::UDIV, ISD::UREM))
+    return Res;
 
   return SDValue();
 }
@@ -3142,10 +3139,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
     return Combined;
 
   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
   // fold (and (sra)) -> (and (srl)) when possible.
@@ -3665,11 +3661,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     return Combined;
 
   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
-  SDValue BSwap = MatchBSwapHWord(N, N0, N1);
-  if (BSwap.getNode())
+  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
     return BSwap;
-  BSwap = MatchBSwapHWordLow(N, N0, N1);
-  if (BSwap.getNode())
+  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
     return BSwap;
 
   // reassociate or
@@ -3690,10 +3684,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     }
   }
   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // See if this is some rotate idiom.
   if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
@@ -4112,10 +4105,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
   }
 
   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // Simplify the expression using non-local knowledge.
   if (!VT.isVector() &&
@@ -4434,11 +4426,9 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     return DAG.getNode(ISD::ADD, SDLoc(N), VT, Shl0, Shl1);
   }
 
-  if (N1C && !N1C->isOpaque()) {
-    SDValue NewSHL = visitShiftByConstant(N, N1C);
-    if (NewSHL.getNode())
+  if (N1C && !N1C->isOpaque())
+    if (SDValue NewSHL = visitShiftByConstant(N, N1C))
       return NewSHL;
-  }
 
   return SDValue();
 }
@@ -4583,11 +4573,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (DAG.SignBitIsZero(N0))
     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
 
-  if (N1C && !N1C->isOpaque()) {
-    SDValue NewSRA = visitShiftByConstant(N, N1C);
-    if (NewSRA.getNode())
+  if (N1C && !N1C->isOpaque())
+    if (SDValue NewSRA = visitShiftByConstant(N, N1C))
       return NewSRA;
-  }
 
   return SDValue();
 }
@@ -5837,8 +5825,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) {
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (sext (truncate (load x))) -> (sext (smaller load x))
     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
-    SDValue NarrowLoad = ReduceLoadWidth(N0.getNode());
-    if (NarrowLoad.getNode()) {
+    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
       SDNode* oye = N0.getNode()->getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -6120,8 +6107,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
   // fold (zext (truncate (load x))) -> (zext (smaller load x))
   // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
-    SDValue NarrowLoad = ReduceLoadWidth(N0.getNode());
-    if (NarrowLoad.getNode()) {
+    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
       SDNode* oye = N0.getNode()->getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -6138,8 +6124,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
 
     // fold (zext (truncate (load x))) -> (zext (smaller load x))
     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
-    SDValue NarrowLoad = ReduceLoadWidth(N0.getNode());
-    if (NarrowLoad.getNode()) {
+    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
       SDNode* oye = N0.getNode()->getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -6546,8 +6531,7 @@ SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) {
       // Watch out for shift count overflow though.
       if (Amt >= Mask.getBitWidth()) break;
       APInt NewMask = Mask << Amt;
-      SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask);
-      if (SimplifyLHS.getNode())
+      if (SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask))
         return DAG.getNode(ISD::SRL, SDLoc(V), V.getValueType(),
                            SimplifyLHS, V.getOperand(1));
     }
@@ -6771,8 +6755,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
 
   // fold (sext_in_reg (load x)) -> (smaller sextload x)
   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
-  SDValue NarrowLoad = ReduceLoadWidth(N);
-  if (NarrowLoad.getNode())
+  if (SDValue NarrowLoad = ReduceLoadWidth(N))
     return NarrowLoad;
 
   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
@@ -6999,9 +6982,9 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // fold (truncate (load x)) -> (smaller load x)
   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
-    SDValue Reduced = ReduceLoadWidth(N);
-    if (Reduced.getNode())
+    if (SDValue Reduced = ReduceLoadWidth(N))
       return Reduced;
+
     // Handle the case where the load remains an extending load even
     // after truncation.
     if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) {
@@ -7240,11 +7223,9 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
   }
 
   // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive.
-  if (N0.getOpcode() == ISD::BUILD_PAIR) {
-    SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT);
-    if (CombineLD.getNode())
+  if (N0.getOpcode() == ISD::BUILD_PAIR)
+    if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT))
       return CombineLD;
-  }
 
   // Remove double bitcasts from shuffles - this is often a legacy of
   // XformToShuffleWithZero being used to combine bitmaskings (of
@@ -7257,10 +7238,10 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
     ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(N0);
 
     // If operands are a bitcast, peek through if it casts the original VT.
-    // If operands are a UNDEF or constant, just bitcast back to original VT.
+    // If operands are a constant, just bitcast back to original VT.
     auto PeekThroughBitcast = [&](SDValue Op) {
       if (Op.getOpcode() == ISD::BITCAST &&
-          Op.getOperand(0)->getValueType(0) == VT)
+          Op.getOperand(0).getValueType() == VT)
         return SDValue(Op.getOperand(0));
       if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
           ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
@@ -8139,7 +8120,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
     // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs
     // during an early run of DAGCombiner can prevent folding with fmuls
     // inserted during lowering.
-    if (N0.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1)) {
+    if (N0.getOpcode() == ISD::FADD &&
+        (N0.getOperand(0) == N0.getOperand(1)) &&
+        N0.hasOneUse()) {
       const SDValue Two = DAG.getConstantFP(2.0, DL, VT);
       SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1);
       return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts);
@@ -8253,6 +8236,60 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
   return SDValue();
 }
 
+// Combine multiple FDIVs with the same divisor into multiple FMULs by the
+// reciprocal.
+// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
+// Notice that this is not always beneficial. One reason is different target
+// may have different costs for FDIV and FMUL, so sometimes the cost of two
+// FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
+// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
+SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) {
+  if (!DAG.getTarget().Options.UnsafeFPMath)
+    return SDValue();
+  
+  SDValue N0 = N->getOperand(0);
+  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
+
+  // Skip if current node is a reciprocal.
+  if (N0CFP && N0CFP->isExactlyValue(1.0))
+    return SDValue();
+  
+  SDValue N1 = N->getOperand(1);
+  SmallVector<SDNode *, 4> Users;
+
+  // Find all FDIV users of the same divisor.
+  for (auto *U : N1->uses()) {
+    if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1)
+      Users.push_back(U);
+  }
+
+  if (!TLI.combineRepeatedFPDivisors(Users.size()))
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+  SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
+  // FIXME: This optimization requires some level of fast-math, so the
+  // created reciprocal node should at least have the 'allowReciprocal'
+  // fast-math-flag set.
+  SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1);
+
+  // Dividend / Divisor -> Dividend * Reciprocal
+  for (auto *U : Users) {
+    SDValue Dividend = U->getOperand(0);
+    if (Dividend != FPOne) {
+      SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
+                                    Reciprocal);
+      CombineTo(U, NewNode);
+    } else if (U != Reciprocal.getNode()) {
+      // In the absence of fast-math-flags, this user node is always the
+      // same node as Reciprocal, but with FMF they may be different nodes.
+      CombineTo(U, Reciprocal);
+    }
+  }
+  return SDValue(N, 0);  // N was replaced.
+}
+
 SDValue DAGCombiner::visitFDIV(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -8353,48 +8390,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
     }
   }
 
-  // Combine multiple FDIVs with the same divisor into multiple FMULs by the
-  // reciprocal.
-  // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip)
-  // Notice that this is not always beneficial. One reason is different target
-  // may have different costs for FDIV and FMUL, so sometimes the cost of two
-  // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason
-  // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL".
-  if (Options.UnsafeFPMath) {
-    // Skip if current node is a reciprocal.
-    if (N0CFP && N0CFP->isExactlyValue(1.0))
-      return SDValue();
-
-    SmallVector<SDNode *, 4> Users;
-    // Find all FDIV users of the same divisor.
-    for (auto *U : N1->uses()) {
-      if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1)
-        Users.push_back(U);
-    }
-
-    if (TLI.combineRepeatedFPDivisors(Users.size())) {
-      SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
-      // FIXME: This optimization requires some level of fast-math, so the
-      // created reciprocal node should at least have the 'allowReciprocal'
-      // fast-math-flag set.
-      SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1);
-
-      // Dividend / Divisor -> Dividend * Reciprocal
-      for (auto *U : Users) {
-        SDValue Dividend = U->getOperand(0);
-        if (Dividend != FPOne) {
-          SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend,
-                                        Reciprocal);
-          CombineTo(U, NewNode);
-        } else if (U != Reciprocal.getNode()) {
-          // In the absence of fast-math-flags, this user node is always the
-          // same node as Reciprocal, but with FMF they may be different nodes.
-          CombineTo(U, Reciprocal);
-        }
-      }
-      return SDValue(N, 0);  // N was replaced.
-    }
-  }
+  if (SDValue CombineRepeatedDivisors = combineRepeatedFPDivisors(N))
+    return CombineRepeatedDivisors;
 
   return SDValue();
 }