Let SelectionDAG start to use probability-based interface to add successors.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / SelectionDAGBuilder.h
index ba6d9743d74ac0529f2fc9f2f314ad709fd0b42c..1171f0aad00ff9718eae33453342982f9389a1f1 100644 (file)
@@ -17,6 +17,8 @@
 #include "StatepointLowering.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/IR/CallSite.h"
@@ -29,7 +31,6 @@
 namespace llvm {
 
 class AddrSpaceCastInst;
-class AliasAnalysis;
 class AllocaInst;
 class BasicBlock;
 class BitCastInst;
@@ -153,39 +154,39 @@ private:
       unsigned JTCasesIndex;
       unsigned BTCasesIndex;
     };
-    uint32_t Weight;
+    BranchProbability Prob;
 
     static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
-                             MachineBasicBlock *MBB, uint32_t Weight) {
+                             MachineBasicBlock *MBB, BranchProbability Prob) {
       CaseCluster C;
       C.Kind = CC_Range;
       C.Low = Low;
       C.High = High;
       C.MBB = MBB;
-      C.Weight = Weight;
+      C.Prob = Prob;
       return C;
     }
 
     static CaseCluster jumpTable(const ConstantInt *Low,
                                  const ConstantInt *High, unsigned JTCasesIndex,
-                                 uint32_t Weight) {
+                                 BranchProbability Prob) {
       CaseCluster C;
       C.Kind = CC_JumpTable;
       C.Low = Low;
       C.High = High;
       C.JTCasesIndex = JTCasesIndex;
-      C.Weight = Weight;
+      C.Prob = Prob;
       return C;
     }
 
     static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
-                                unsigned BTCasesIndex, uint32_t Weight) {
+                                unsigned BTCasesIndex, BranchProbability Prob) {
       CaseCluster C;
       C.Kind = CC_BitTests;
       C.Low = Low;
       C.High = High;
       C.BTCasesIndex = BTCasesIndex;
-      C.Weight = Weight;
+      C.Prob = Prob;
       return C;
     }
   };
@@ -197,13 +198,13 @@ private:
     uint64_t Mask;
     MachineBasicBlock* BB;
     unsigned Bits;
-    uint32_t ExtraWeight;
+    BranchProbability ExtraProb;
 
     CaseBits(uint64_t mask, MachineBasicBlock* bb, unsigned bits,
-             uint32_t Weight):
-      Mask(mask), BB(bb), Bits(bits), ExtraWeight(Weight) { }
+             BranchProbability Prob):
+      Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) { }
 
-    CaseBits() : Mask(0), BB(nullptr), Bits(0), ExtraWeight(0) {}
+    CaseBits() : Mask(0), BB(nullptr), Bits(0) {}
   };
 
   typedef std::vector<CaseBits> CaseBitsVector;
@@ -216,13 +217,13 @@ private:
   /// blocks needed by multi-case switch statements.
   struct CaseBlock {
     CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
-              const Value *cmpmiddle,
-              MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
-              MachineBasicBlock *me,
-              uint32_t trueweight = 0, uint32_t falseweight = 0)
-      : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
-        TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
-        TrueWeight(trueweight), FalseWeight(falseweight) { }
+              const Value *cmpmiddle, MachineBasicBlock *truebb,
+              MachineBasicBlock *falsebb, MachineBasicBlock *me,
+              BranchProbability trueprob = BranchProbability::getUnknown(),
+              BranchProbability falseprob = BranchProbability::getUnknown())
+        : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
+          TrueBB(truebb), FalseBB(falsebb), ThisBB(me), TrueProb(trueprob),
+          FalseProb(falseprob) {}
 
     // CC - the condition code to use for the case block's setcc node
     ISD::CondCode CC;
@@ -238,8 +239,8 @@ private:
     // ThisBB - the block into which to emit the code for the setcc and branches
     MachineBasicBlock *ThisBB;
 
-    // TrueWeight/FalseWeight - branch weights.
-    uint32_t TrueWeight, FalseWeight;
+    // TrueProb/FalseProb - branch weights.
+    BranchProbability TrueProb, FalseProb;
   };
 
   struct JumpTable {
@@ -271,32 +272,35 @@ private:
 
   struct BitTestCase {
     BitTestCase(uint64_t M, MachineBasicBlock* T, MachineBasicBlock* Tr,
-                uint32_t Weight):
-      Mask(M), ThisBB(T), TargetBB(Tr), ExtraWeight(Weight) { }
+                BranchProbability Prob):
+      Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) { }
     uint64_t Mask;
     MachineBasicBlock *ThisBB;
     MachineBasicBlock *TargetBB;
-    uint32_t ExtraWeight;
+    BranchProbability ExtraProb;
   };
 
   typedef SmallVector<BitTestCase, 3> BitTestInfo;
 
   struct BitTestBlock {
-    BitTestBlock(APInt F, APInt R, const Value* SV,
-                 unsigned Rg, MVT RgVT, bool E,
-                 MachineBasicBlock* P, MachineBasicBlock* D,
-                 BitTestInfo C):
-      First(F), Range(R), SValue(SV), Reg(Rg), RegVT(RgVT), Emitted(E),
-      Parent(P), Default(D), Cases(std::move(C)) { }
+    BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT,
+                 bool E, bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
+                 BitTestInfo C, BranchProbability Pr)
+        : First(F), Range(R), SValue(SV), Reg(Rg), RegVT(RgVT), Emitted(E),
+          ContiguousRange(CR), Parent(P), Default(D), Cases(std::move(C)),
+          Prob(Pr) {}
     APInt First;
     APInt Range;
     const Value *SValue;
     unsigned Reg;
     MVT RegVT;
     bool Emitted;
+    bool ContiguousRange;
     MachineBasicBlock *Parent;
     MachineBasicBlock *Default;
     BitTestInfo Cases;
+    BranchProbability Prob;
+    BranchProbability DefaultProb;
   };
 
   /// Minimum jump table density, in percent.
@@ -338,9 +342,15 @@ private:
     CaseClusterIt LastCluster;
     const ConstantInt *GE;
     const ConstantInt *LT;
+    BranchProbability DefaultProb;
   };
   typedef SmallVector<SwitchWorkListItem, 4> SwitchWorkList;
 
+  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
+  /// than each cluster in the range, its rank is 0.
+  static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
+                                  CaseClusterIt Last);
+
   /// Emit comparison and split W into two subtrees.
   void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W,
                      Value *Cond, MachineBasicBlock *SwitchMBB);
@@ -509,6 +519,7 @@ private:
     void resetPerFunctionState() {
       FailureMBB = nullptr;
       Guard = nullptr;
+      GuardReg = 0;
     }
 
     MachineBasicBlock *getParentMBB() { return ParentMBB; }
@@ -586,10 +597,6 @@ public:
   ///
   FunctionLoweringInfo &FuncInfo;
 
-  /// OptLevel - What optimization level we're generating code for.
-  ///
-  CodeGenOpt::Level OptLevel;
-
   /// GFI - Garbage collection metadata for the function.
   GCFunctionInfo *GFI;
 
@@ -607,7 +614,7 @@ public:
   SelectionDAGBuilder(SelectionDAG &dag, FunctionLoweringInfo &funcinfo,
                       CodeGenOpt::Level ol)
     : CurInst(nullptr), SDNodeOrder(LowestSDNodeOrder), TM(dag.getTarget()),
-      DAG(dag), FuncInfo(funcinfo), OptLevel(ol),
+      DAG(dag), FuncInfo(funcinfo),
       HasTailCall(false) {
   }
 
@@ -667,6 +674,8 @@ public:
   // generate the debug data structures now that we've seen its definition.
   void resolveDanglingDebugInfo(const Value *V, SDValue Val);
   SDValue getValue(const Value *V);
+  bool findValue(const Value *V) const;
+
   SDValue getNonRegisterValue(const Value *V);
   SDValue getValueImpl(const Value *V);
 
@@ -676,12 +685,6 @@ public:
     N = NewN;
   }
 
-  void removeValue(const Value *V) {
-    // This is to support hack in lowerCallFromStatepoint
-    // Should be removed when hack is resolved
-    NodeMap.erase(V);
-  }
-
   void setUnusedArgValue(const Value *V, SDValue NewN) {
     SDValue &N = UnusedArgNodeMap[V];
     assert(!N.getNode() && "Already set a value for this node!");
@@ -690,27 +693,28 @@ public:
 
   void FindMergedConditions(const Value *Cond, MachineBasicBlock *TBB,
                             MachineBasicBlock *FBB, MachineBasicBlock *CurBB,
-                            MachineBasicBlock *SwitchBB, unsigned Opc,
-                            uint32_t TW, uint32_t FW);
+                            MachineBasicBlock *SwitchBB,
+                            Instruction::BinaryOps Opc, BranchProbability TW,
+                            BranchProbability FW);
   void EmitBranchForMergedCondition(const Value *Cond, MachineBasicBlock *TBB,
                                     MachineBasicBlock *FBB,
                                     MachineBasicBlock *CurBB,
                                     MachineBasicBlock *SwitchBB,
-                                    uint32_t TW, uint32_t FW);
+                                    BranchProbability TW, BranchProbability FW);
   bool ShouldEmitAsBranches(const std::vector<CaseBlock> &Cases);
   bool isExportableFromCurrentBlock(const Value *V, const BasicBlock *FromBB);
   void CopyToExportRegsIfNeeded(const Value *V);
   void ExportFromCurrentBlock(const Value *V);
   void LowerCallTo(ImmutableCallSite CS, SDValue Callee, bool IsTailCall,
-                   MachineBasicBlock *LandingPad = nullptr);
+                   const BasicBlock *EHPadBB = nullptr);
 
   std::pair<SDValue, SDValue> lowerCallOperands(
           ImmutableCallSite CS,
           unsigned ArgIdx,
           unsigned NumArgs,
           SDValue Callee,
-          bool UseVoidTy = false,
-          MachineBasicBlock *LandingPad = nullptr,
+          Type *ReturnTy,
+          const BasicBlock *EHPadBB = nullptr,
           bool IsPatchPoint = false);
 
   /// UpdateSplitBlock - When an MBB was split during scheduling, update the
@@ -720,11 +724,11 @@ public:
   // This function is responsible for the whole statepoint lowering process.
   // It uniformly handles invoke and call statepoints.
   void LowerStatepoint(ImmutableStatepoint Statepoint,
-                       MachineBasicBlock *LandingPad = nullptr);
+                       const BasicBlock *EHPadBB = nullptr);
 private:
-  std::pair<SDValue, SDValue> lowerInvokable(
-          TargetLowering::CallLoweringInfo &CLI,
-          MachineBasicBlock *LandingPad);
+  std::pair<SDValue, SDValue>
+  lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
+                 const BasicBlock *EHPadBB = nullptr);
 
   // Terminator instructions.
   void visitRet(const ReturnInst &I);
@@ -732,11 +736,20 @@ private:
   void visitSwitch(const SwitchInst &I);
   void visitIndirectBr(const IndirectBrInst &I);
   void visitUnreachable(const UnreachableInst &I);
+  void visitCleanupEndPad(const CleanupEndPadInst &I);
+  void visitCleanupRet(const CleanupReturnInst &I);
+  void visitCatchEndPad(const CatchEndPadInst &I);
+  void visitCatchRet(const CatchReturnInst &I);
+  void visitCatchPad(const CatchPadInst &I);
+  void visitTerminatePad(const TerminatePadInst &TPI);
+  void visitCleanupPad(const CleanupPadInst &CPI);
+
+  BranchProbability getEdgeProbability(const MachineBasicBlock *Src,
+                                       const MachineBasicBlock *Dst) const;
+  void addSuccessorWithProb(
+      MachineBasicBlock *Src, MachineBasicBlock *Dst,
+      BranchProbability Prob = BranchProbability::getUnknown());
 
-  uint32_t getEdgeWeight(const MachineBasicBlock *Src,
-                         const MachineBasicBlock *Dst) const;
-  void addSuccessorWithWeight(MachineBasicBlock *Src, MachineBasicBlock *Dst,
-                              uint32_t Weight = 0);
 public:
   void visitSwitchCase(CaseBlock &CB,
                        MachineBasicBlock *SwitchBB);
@@ -746,15 +759,13 @@ public:
   void visitBitTestHeader(BitTestBlock &B, MachineBasicBlock *SwitchBB);
   void visitBitTestCase(BitTestBlock &BB,
                         MachineBasicBlock* NextMBB,
-                        uint32_t BranchWeightToNext,
+                        BranchProbability BranchProbToNext,
                         unsigned Reg,
                         BitTestCase &B,
                         MachineBasicBlock *SwitchBB);
   void visitJumpTable(JumpTable &JT);
   void visitJumpTableHeader(JumpTable &JT, JumpTableHeader &JTH,
                             MachineBasicBlock *SwitchBB);
-  unsigned visitLandingPadClauseBB(GlobalValue *ClauseGV,
-                                   MachineBasicBlock *LPadMBB);
 
 private:
   // These all get lowered before this pass.
@@ -814,6 +825,8 @@ private:
   void visitStore(const StoreInst &I);
   void visitMaskedLoad(const CallInst &I);
   void visitMaskedStore(const CallInst &I);
+  void visitMaskedGather(const CallInst &I);
+  void visitMaskedScatter(const CallInst &I);
   void visitAtomicCmpXchg(const AtomicCmpXchgInst &I);
   void visitAtomicRMW(const AtomicRMWInst &I);
   void visitFence(const FenceInst &I);
@@ -840,7 +853,7 @@ private:
   void visitVACopy(const CallInst &I);
   void visitStackmap(const CallInst &I);
   void visitPatchpoint(ImmutableCallSite CS,
-                       MachineBasicBlock *LandingPad = nullptr);
+                       const BasicBlock *EHPadBB = nullptr);
 
   // These three are implemented in StatepointLowering.cpp
   void visitStatepoint(const CallInst &I);
@@ -862,8 +875,8 @@ private:
   /// EmitFuncArgumentDbgValue - If V is an function argument then create
   /// corresponding DBG_VALUE machine instruction for it now. At the end of
   /// instruction selection, they will be inserted to the entry BB.
-  bool EmitFuncArgumentDbgValue(const Value *V, MDLocalVariable *Variable,
-                                MDExpression *Expr, MDLocation *DL,
+  bool EmitFuncArgumentDbgValue(const Value *V, DILocalVariable *Variable,
+                                DIExpression *Expr, DILocation *DL,
                                 int64_t Offset, bool IsIndirect,
                                 const SDValue &N);
 
@@ -875,6 +888,80 @@ private:
   void updateDAGForMaybeTailCall(SDValue MaybeTC);
 };
 
+/// RegsForValue - This struct represents the registers (physical or virtual)
+/// that a particular set of values is assigned, and the type information about
+/// the value. The most common situation is to represent one value at a time,
+/// but struct or array values are handled element-wise as multiple values.  The
+/// splitting of aggregates is performed recursively, so that we never have
+/// aggregate-typed registers. The values at this point do not necessarily have
+/// legal types, so each value may require one or more registers of some legal
+/// type.
+///
+struct RegsForValue {
+  /// ValueVTs - The value types of the values, which may not be legal, and
+  /// may need be promoted or synthesized from one or more registers.
+  ///
+  SmallVector<EVT, 4> ValueVTs;
+
+  /// RegVTs - The value types of the registers. This is the same size as
+  /// ValueVTs and it records, for each value, what the type of the assigned
+  /// register or registers are. (Individual values are never synthesized
+  /// from more than one type of register.)
+  ///
+  /// With virtual registers, the contents of RegVTs is redundant with TLI's
+  /// getRegisterType member function, however when with physical registers
+  /// it is necessary to have a separate record of the types.
+  ///
+  SmallVector<MVT, 4> RegVTs;
+
+  /// Regs - This list holds the registers assigned to the values.
+  /// Each legal or promoted value requires one register, and each
+  /// expanded value requires multiple registers.
+  ///
+  SmallVector<unsigned, 4> Regs;
+
+  RegsForValue();
+
+  RegsForValue(const SmallVector<unsigned, 4> &regs, MVT regvt, EVT valuevt);
+
+  RegsForValue(LLVMContext &Context, const TargetLowering &TLI,
+               const DataLayout &DL, unsigned Reg, Type *Ty);
+
+  /// append - Add the specified values to this one.
+  void append(const RegsForValue &RHS) {
+    ValueVTs.append(RHS.ValueVTs.begin(), RHS.ValueVTs.end());
+    RegVTs.append(RHS.RegVTs.begin(), RHS.RegVTs.end());
+    Regs.append(RHS.Regs.begin(), RHS.Regs.end());
+  }
+
+  /// getCopyFromRegs - Emit a series of CopyFromReg nodes that copies from
+  /// this value and returns the result as a ValueVTs value.  This uses
+  /// Chain/Flag as the input and updates them for the output Chain/Flag.
+  /// If the Flag pointer is NULL, no flag is used.
+  SDValue getCopyFromRegs(SelectionDAG &DAG, FunctionLoweringInfo &FuncInfo,
+                          SDLoc dl,
+                          SDValue &Chain, SDValue *Flag,
+                          const Value *V = nullptr) const;
+
+  /// getCopyToRegs - Emit a series of CopyToReg nodes that copies the specified
+  /// value into the registers specified by this object.  This uses Chain/Flag
+  /// as the input and updates them for the output Chain/Flag.  If the Flag
+  /// pointer is nullptr, no flag is used.  If V is not nullptr, then it is used
+  /// in printing better diagnostic messages on error.
+  void
+  getCopyToRegs(SDValue Val, SelectionDAG &DAG, SDLoc dl, SDValue &Chain,
+                SDValue *Flag, const Value *V = nullptr,
+                ISD::NodeType PreferredExtendType = ISD::ANY_EXTEND) const;
+
+  /// AddInlineAsmOperands - Add this value to the specified inlineasm node
+  /// operand list.  This adds the code marker, matching input operand index
+  /// (if applicable), and includes the number of values added into it.
+  void AddInlineAsmOperands(unsigned Kind,
+                            bool HasMatching, unsigned MatchingIdx, SDLoc dl,
+                            SelectionDAG &DAG,
+                            std::vector<SDValue> &Ops) const;
+};
+
 } // end namespace llvm
 
 #endif