Switch lowering: add heuristic for filling leaf nodes in the weight-balanced binary...
[oota-llvm.git] / lib / CodeGen / SelectionDAG / SelectionDAGBuilder.cpp
index 48bf22699b09dce2fb1a766f28b0e1c4e1f66c07..8313a48c34676bb15eadfd498ba27bcf897e0f95 100644 (file)
@@ -78,12 +78,16 @@ LimitFPPrecision("limit-float-precision",
                  cl::location(LimitFloatPrecision),
                  cl::init(0));
 
+static cl::opt<bool>
+EnableFMFInDAG("enable-fmf-dag", cl::init(false), cl::Hidden,
+                cl::desc("Enable fast-math-flags for DAG nodes"));
+
 // Limit the width of DAG chains. This is important in general to prevent
-// prevent DAG-based analysis from blowing up. For example, alias analysis and
+// DAG-based analysis from blowing up. For example, alias analysis and
 // load clustering may not complete in reasonable time. It is difficult to
 // recognize and avoid this situation within each individual analysis, and
 // future analyses are likely to have the same behavior. Limiting DAG width is
-// the safe approach, and will be especially important with global DAGs.
+// the safe approach and will be especially important with global DAGs.
 //
 // MaxParallelChains default is arbitrarily high to avoid affecting
 // optimization, but could be lowered to improve compile time. Any ld-ld-st-st
@@ -1002,7 +1006,16 @@ bool SelectionDAGBuilder::findValue(const Value *V) const {
 SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
   // If we already have an SDValue for this value, use it.
   SDValue &N = NodeMap[V];
-  if (N.getNode()) return N;
+  if (N.getNode()) {
+    if (isa<ConstantSDNode>(N) || isa<ConstantFPSDNode>(N)) {
+      // Remove the debug location from the node as the node is about to be used
+      // in a location which may differ from the original debug location.  This
+      // is relevant to Constant and ConstantFP nodes because they can appear
+      // as constant expressions inside PHI nodes.
+      N->setDebugLoc(DebugLoc());
+    }
+    return N;
+  }
 
   // Otherwise create a new SDValue and remember it.
   SDValue Val = getValueImpl(V);
@@ -2139,6 +2152,8 @@ void SelectionDAGBuilder::visitBinary(const User &I, unsigned OpCode) {
   bool nuw = false;
   bool nsw = false;
   bool exact = false;
+  FastMathFlags FMF;
+
   if (const OverflowingBinaryOperator *OFBinOp =
           dyn_cast<const OverflowingBinaryOperator>(&I)) {
     nuw = OFBinOp->hasNoUnsignedWrap();
@@ -2147,9 +2162,22 @@ void SelectionDAGBuilder::visitBinary(const User &I, unsigned OpCode) {
   if (const PossiblyExactOperator *ExactOp =
           dyn_cast<const PossiblyExactOperator>(&I))
     exact = ExactOp->isExact();
-
+  if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(&I))
+    FMF = FPOp->getFastMathFlags();
+
+  SDNodeFlags Flags;
+  Flags.setExact(exact);
+  Flags.setNoSignedWrap(nsw);
+  Flags.setNoUnsignedWrap(nuw);
+  if (EnableFMFInDAG) {
+    Flags.setAllowReciprocal(FMF.allowReciprocal());
+    Flags.setNoInfs(FMF.noInfs());
+    Flags.setNoNaNs(FMF.noNaNs());
+    Flags.setNoSignedZeros(FMF.noSignedZeros());
+    Flags.setUnsafeAlgebra(FMF.unsafeAlgebra());
+  }
   SDValue BinNodeValue = DAG.getNode(OpCode, getCurSDLoc(), Op1.getValueType(),
-                                     Op1, Op2, nuw, nsw, exact);
+                                     Op1, Op2, &Flags);
   setValue(&I, BinNodeValue);
 }
 
@@ -2197,9 +2225,12 @@ void SelectionDAGBuilder::visitShift(const User &I, unsigned Opcode) {
             dyn_cast<const PossiblyExactOperator>(&I))
       exact = ExactOp->isExact();
   }
-
+  SDNodeFlags Flags;
+  Flags.setExact(exact);
+  Flags.setNoSignedWrap(nsw);
+  Flags.setNoUnsignedWrap(nuw);
   SDValue Res = DAG.getNode(Opcode, getCurSDLoc(), Op1.getValueType(), Op1, Op2,
-                            nuw, nsw, exact);
+                            &Flags);
   setValue(&I, Res);
 }
 
@@ -2282,7 +2313,11 @@ void SelectionDAGBuilder::visitSelect(const User &I) {
     while (TLI.getTypeAction(Ctx, VT) == TargetLoweringBase::TypeSplitVector)
       VT = TLI.getTypeToTransformTo(Ctx, VT);
 
-    if (Opc != ISD::DELETED_NODE && TLI.isOperationLegalOrCustom(Opc, VT)) {
+    if (Opc != ISD::DELETED_NODE && TLI.isOperationLegalOrCustom(Opc, VT) &&
+        // If the underlying comparison instruction is used by any other instruction,
+        // the consumed instructions won't be destroyed, so it is not profitable
+        // to convert to a min/max.
+        cast<SelectInst>(&I)->getCondition()->hasOneUse()) {
       OpCode = Opc;
       LHSVal = getValue(LHS);
       RHSVal = getValue(RHS);
@@ -2848,7 +2883,17 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) {
 
   bool isVolatile = I.isVolatile();
   bool isNonTemporal = I.getMetadata(LLVMContext::MD_nontemporal) != nullptr;
-  bool isInvariant = I.getMetadata(LLVMContext::MD_invariant_load) != nullptr;
+
+  // The IR notion of invariant_load only guarantees that all *non-faulting*
+  // invariant loads result in the same value.  The MI notion of invariant load
+  // guarantees that the load can be legally moved to any location within its
+  // containing function.  The MI notion of invariant_load is stronger than the
+  // IR notion of invariant_load -- an MI invariant_load is an IR invariant_load
+  // with a guarantee that the location being loaded from is dereferenceable
+  // throughout the function's lifetime.
+
+  bool isInvariant = I.getMetadata(LLVMContext::MD_invariant_load) != nullptr &&
+    isDereferenceablePointer(SV, *DAG.getTarget().getDataLayout());
   unsigned Alignment = I.getAlignment();
 
   AAMDNodes AAInfo;
@@ -2869,7 +2914,7 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) {
     // Serialize volatile loads with other side effects.
     Root = getRoot();
   else if (AA->pointsToConstantMemory(
-             AliasAnalysis::Location(SV, AA->getTypeStoreSize(Ty), AAInfo))) {
+               MemoryLocation(SV, AA->getTypeStoreSize(Ty), AAInfo))) {
     // Do not serialize (non-volatile) loads of constant memory with anything.
     Root = DAG.getEntryNode();
     ConstantMemory = true;
@@ -2884,8 +2929,7 @@ void SelectionDAGBuilder::visitLoad(const LoadInst &I) {
     Root = TLI.prepareVolatileOrAtomicLoad(Root, dl, DAG);
 
   SmallVector<SDValue, 4> Values(NumValues);
-  SmallVector<SDValue, 4> Chains(std::min(unsigned(MaxParallelChains),
-                                          NumValues));
+  SmallVector<SDValue, 4> Chains(std::min(MaxParallelChains, NumValues));
   EVT PtrVT = Ptr.getValueType();
   unsigned ChainI = 0;
   for (unsigned i = 0; i != NumValues; ++i, ++ChainI) {
@@ -2949,8 +2993,7 @@ void SelectionDAGBuilder::visitStore(const StoreInst &I) {
   SDValue Ptr = getValue(PtrV);
 
   SDValue Root = getRoot();
-  SmallVector<SDValue, 4> Chains(std::min(unsigned(MaxParallelChains),
-                                          NumValues));
+  SmallVector<SDValue, 4> Chains(std::min(MaxParallelChains, NumValues));
   EVT PtrVT = Ptr.getValueType();
   bool isVolatile = I.isVolatile();
   bool isNonTemporal = I.getMetadata(LLVMContext::MD_nontemporal) != nullptr;
@@ -3118,10 +3161,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
 
   SDValue InChain = DAG.getRoot();
-  if (AA->pointsToConstantMemory(
-      AliasAnalysis::Location(PtrOperand,
-                              AA->getTypeStoreSize(I.getType()),
-                              AAInfo))) {
+  if (AA->pointsToConstantMemory(MemoryLocation(
+          PtrOperand, AA->getTypeStoreSize(I.getType()), AAInfo))) {
     // Do not serialize (non-volatile) loads of constant memory with anything.
     InChain = DAG.getEntryNode();
   }
@@ -3163,10 +3204,9 @@ void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
   Value *BasePtr = Ptr;
   bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
   bool ConstantMemory = false;
-  if (UniformBase && AA->pointsToConstantMemory(
-      AliasAnalysis::Location(BasePtr,
-                                   AA->getTypeStoreSize(I.getType()),
-                              AAInfo))) {
+  if (UniformBase &&
+      AA->pointsToConstantMemory(
+          MemoryLocation(BasePtr, AA->getTypeStoreSize(I.getType()), AAInfo))) {
     // Do not serialize (non-volatile) loads of constant memory with anything.
     Root = DAG.getEntryNode();
     ConstantMemory = true;
@@ -4960,6 +5000,7 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
     assert(Reg && "cannot get exception code on this platform");
     MVT PtrVT = TLI.getPointerTy();
     const TargetRegisterClass *PtrRC = TLI.getRegClassFor(PtrVT);
+    assert(FuncInfo.MBB->isLandingPad() && "eh.exceptioncode in non-lpad");
     unsigned VReg = FuncInfo.MBB->addLiveIn(Reg, PtrRC);
     SDValue N =
         DAG.getCopyFromReg(DAG.getEntryNode(), getCurSDLoc(), VReg, PtrVT);
@@ -4979,7 +5020,7 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
   if (LandingPad) {
     // Insert a label before the invoke call to mark the try range.  This can be
     // used to detect deletion of the invoke via the MachineModuleInfo.
-    BeginLabel = MMI.getContext().CreateTempSymbol();
+    BeginLabel = MMI.getContext().createTempSymbol();
 
     // For SjLj, keep track of which landing pads go with which invokes
     // so as to maintain the ordering of pads in the LSDA.
@@ -5022,7 +5063,7 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
   if (LandingPad) {
     // Insert a label at the end of the invoke call to mark the try range.  This
     // can be used to detect deletion of the invoke via the MachineModuleInfo.
-    MCSymbol *EndLabel = MMI.getContext().CreateTempSymbol();
+    MCSymbol *EndLabel = MMI.getContext().createTempSymbol();
     DAG.setRoot(DAG.getEHLabel(getCurSDLoc(), getRoot(), EndLabel));
 
     // Inform MachineModuleInfo of range.
@@ -5450,7 +5491,7 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
             return;
         }
       }
-      if (unsigned IID = F->getIntrinsicID()) {
+      if (Intrinsic::ID IID = F->getIntrinsicID()) {
         RenameFn = visitIntrinsicCall(I, IID);
         if (!RenameFn)
           return;
@@ -7437,7 +7478,7 @@ bool SelectionDAGBuilder::buildJumpTable(CaseClusterVector &Clusters,
   JumpTableHeader JTH(Clusters[First].Low->getValue(),
                       Clusters[Last].High->getValue(), SI->getCondition(),
                       nullptr, false);
-  JTCases.push_back(JumpTableBlock(JTH, JT));
+  JTCases.emplace_back(std::move(JTH), std::move(JT));
 
   JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                      JTCases.size() - 1, Weight);
@@ -7463,6 +7504,31 @@ void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters,
   const int64_t N = Clusters.size();
   const unsigned MinJumpTableSize = TLI.getMinimumJumpTableEntries();
 
+  // TotalCases[i]: Total nbr of cases in Clusters[0..i].
+  SmallVector<unsigned, 8> TotalCases(N);
+
+  for (unsigned i = 0; i < N; ++i) {
+    APInt Hi = Clusters[i].High->getValue();
+    APInt Lo = Clusters[i].Low->getValue();
+    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
+    if (i != 0)
+      TotalCases[i] += TotalCases[i - 1];
+  }
+
+  if (N >= MinJumpTableSize && isDense(Clusters, &TotalCases[0], 0, N - 1)) {
+    // Cheap case: the whole range might be suitable for jump table.
+    CaseCluster JTCluster;
+    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
+      Clusters[0] = JTCluster;
+      Clusters.resize(1);
+      return;
+    }
+  }
+
+  // The algorithm below is not suitable for -O0.
+  if (TM.getOptLevel() == CodeGenOpt::None)
+    return;
+
   // Split Clusters into minimum number of dense partitions. The algorithm uses
   // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
   // for the Case Statement'" (1994), but builds the MinPartitions array in
@@ -7476,16 +7542,6 @@ void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters,
   SmallVector<unsigned, 8> LastElement(N);
   // NumTables[i]: nbr of >= MinJumpTableSize partitions from Clusters[i..N-1].
   SmallVector<unsigned, 8> NumTables(N);
-  // TotalCases[i]: Total nbr of cases in Clusters[0..i].
-  SmallVector<unsigned, 8> TotalCases(N);
-
-  for (unsigned i = 0; i < N; ++i) {
-    APInt Hi = Clusters[i].High->getValue();
-    APInt Lo = Clusters[i].Low->getValue();
-    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
-    if (i != 0)
-      TotalCases[i] += TotalCases[i - 1];
-  }
 
   // Base case: There is only one way to partition Clusters[N-1].
   MinPartitions[N - 1] = 1;
@@ -7600,7 +7656,7 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
 
   const int BitWidth =
       DAG.getTargetLoweringInfo().getPointerTy().getSizeInBits();
-  assert((High - Low + 1).sle(BitWidth) && "Case range must fit in bit mask!");
+  assert(rangeFitsInWord(Low, High) && "Case range must fit in bit mask!");
 
   if (Low.isNonNegative() && High.slt(BitWidth)) {
     // Optimize the case where all the case values fit in a
@@ -7628,10 +7684,9 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
     // Update Mask, Bits and ExtraWeight.
     uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
     uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
-    for (uint64_t j = Lo; j <= Hi; ++j) {
-      CB->Mask |= 1ULL << j;
-      CB->Bits++;
-    }
+    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
+    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
+    CB->Bits += Hi - Lo + 1;
     CB->ExtraWeight += Clusters[i].Weight;
     TotalWeight += Clusters[i].Weight;
     assert(TotalWeight >= Clusters[i].Weight && "Weight overflow!");
@@ -7650,9 +7705,9 @@ bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
         FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
     BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraWeight));
   }
-  BitTestCases.push_back(BitTestBlock(LowBound, CmpRange, SI->getCondition(),
-                                      -1U, MVT::Other, false, nullptr,
-                                      nullptr, std::move(BTI)));
+  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
+                            SI->getCondition(), -1U, MVT::Other, false, nullptr,
+                            nullptr, std::move(BTI));
 
   BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                     BitTestCases.size() - 1, TotalWeight);
@@ -7674,6 +7729,10 @@ void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters,
     assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
 #endif
 
+  // The algorithm below is not suitable for -O0.
+  if (TM.getOptLevel() == CodeGenOpt::None)
+    return;
+
   // If target does not have legal shift left, do not emit bit tests at all.
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT PTy = TLI.getPointerTy();
@@ -7746,8 +7805,10 @@ void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters,
     if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
       Clusters[DstIndex++] = BitTestCluster;
     } else {
-      for (unsigned I = First; I <= Last; ++I)
-        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
+      size_t NumClusters = Last - First + 1;
+      std::memmove(&Clusters[DstIndex], &Clusters[First],
+                   sizeof(Clusters[0]) * NumClusters);
+      DstIndex += NumClusters;
     }
   }
   Clusters.resize(DstIndex);
@@ -7783,22 +7844,17 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
       const APInt &BigValue = Big.Low->getValue();
 
       // Check that there is only one bit different.
-      if (BigValue.countPopulation() == SmallValue.countPopulation() + 1 &&
-          (SmallValue | BigValue) == BigValue) {
-        // Isolate the common bit.
-        APInt CommonBit = BigValue & ~SmallValue;
-        assert((SmallValue | CommonBit) == BigValue &&
-               CommonBit.countPopulation() == 1 && "Not a common bit?");
-
+      APInt CommonBit = BigValue ^ SmallValue;
+      if (CommonBit.isPowerOf2()) {
         SDValue CondLHS = getValue(Cond);
         EVT VT = CondLHS.getValueType();
         SDLoc DL = getCurSDLoc();
 
         SDValue Or = DAG.getNode(ISD::OR, DL, VT, CondLHS,
                                  DAG.getConstant(CommonBit, DL, VT));
-        SDValue Cond = DAG.getSetCC(DL, MVT::i1, Or,
-                                    DAG.getConstant(BigValue, DL, VT),
-                                    ISD::SETEQ);
+        SDValue Cond = DAG.getSetCC(
+            DL, MVT::i1, Or, DAG.getConstant(BigValue | SmallValue, DL, VT),
+            ISD::SETEQ);
 
         // Update successor info.
         // Both Small and Big will jump to Small.BB, so we sum up the weights.
@@ -7940,6 +7996,18 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
   }
 }
 
+unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC,
+                                              CaseClusterIt First,
+                                              CaseClusterIt Last) {
+  return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
+    if (X.Weight != CC.Weight)
+      return X.Weight > CC.Weight;
+
+    // Ties are broken by comparing the case value.
+    return X.Low->getValue().slt(CC.Low->getValue());
+  });
+}
+
 void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
                                         const SwitchWorkListItem &W,
                                         Value *Cond,
@@ -7969,6 +8037,48 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
       RightWeight += (--FirstRight)->Weight;
     I++;
   }
+
+  for (;;) {
+    // Our binary search tree differs from a typical BST in that ours can have up
+    // to three values in each leaf. The pivot selection above doesn't take that
+    // into account, which means the tree might require more nodes and be less
+    // efficient. We compensate for this here.
+
+    unsigned NumLeft = LastLeft - W.FirstCluster + 1;
+    unsigned NumRight = W.LastCluster - FirstRight + 1;
+
+    if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
+      // If one side has less than 3 clusters, and the other has more than 3,
+      // consider taking a cluster from the other side.
+
+      if (NumLeft < NumRight) {
+        // Consider moving the first cluster on the right to the left side.
+        CaseCluster &CC = *FirstRight;
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        if (LeftSideRank <= RightSideRank) {
+          // Moving the cluster to the left does not demote it.
+          ++LastLeft;
+          ++FirstRight;
+          continue;
+        }
+      } else {
+        assert(NumRight < NumLeft);
+        // Consider moving the last element on the left to the right side.
+        CaseCluster &CC = *LastLeft;
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        if (RightSideRank <= LeftSideRank) {
+          // Moving the cluster to the right does not demot it.
+          --LastLeft;
+          --FirstRight;
+          continue;
+        }
+      }
+    }
+    break;
+  }
+
   assert(LastLeft + 1 == FirstRight);
   assert(LastLeft >= W.FirstCluster);
   assert(FirstRight <= W.LastCluster);
@@ -8092,11 +8202,8 @@ void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) {
     return;
   }
 
-  if (TM.getOptLevel() != CodeGenOpt::None) {
-    findJumpTables(Clusters, &SI, DefaultMBB);
-    findBitTestClusters(Clusters, &SI);
-  }
-
+  findJumpTables(Clusters, &SI, DefaultMBB);
+  findBitTestClusters(Clusters, &SI);
 
   DEBUG({
     dbgs() << "Case clusters: ";