Divergence analysis for GPU programs
[oota-llvm.git] / include / llvm / Analysis / TargetTransformInfo.h
index 15eab52d6a144d848d3b55466e092bfc3c92e522..1a7f6f6cfc04fdb55a50f6929154e91a57346854 100644 (file)
 #define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
 
 #include "llvm/ADT/Optional.h"
-#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/DataTypes.h"
+#include <functional>
 
 namespace llvm {
 
@@ -189,12 +190,21 @@ public:
   /// comments for a detailed explanation of the cost values.
   unsigned getUserCost(const User *U) const;
 
-  /// \brief hasBranchDivergence - Return true if branch divergence exists.
+  /// \brief Return true if branch divergence exists.
+  ///
   /// Branch divergence has a significantly negative impact on GPU performance
   /// when threads in the same wavefront take different paths due to conditional
   /// branches.
   bool hasBranchDivergence() const;
 
+  /// \brief Returns whether V is a source of divergence.
+  ///
+  /// This function provides the target-dependent information for
+  /// the target-independent DivergenceAnalysis. DivergenceAnalysis first
+  /// builds the dependency graph, and then runs the reachability algorithm
+  /// starting with the sources of divergence.
+  bool isSourceOfDivergence(const Value *V) const;
+
   /// \brief Test whether calls to a function lower to actual program function
   /// calls.
   ///
@@ -313,6 +323,10 @@ public:
   /// by referencing its sub-register AX.
   bool isTruncateFree(Type *Ty1, Type *Ty2) const;
 
+  /// \brief Return true if it is profitable to hoist instruction in the
+  /// then/else to before if.
+  bool isProfitableToHoist(Instruction *I) const;
+
   /// \brief Return true if this type is legal.
   bool isTypeLegal(Type *Ty) const;
 
@@ -326,6 +340,9 @@ public:
   /// target.
   bool shouldBuildLookupTables() const;
 
+  /// \brief Don't restrict interleaved unrolling to small loops.
+  bool enableAggressiveInterleaving(bool LoopHasReductions) const;
+
   /// \brief Return hardware support for population count.
   PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) const;
 
@@ -444,6 +461,10 @@ public:
   unsigned getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
                                  ArrayRef<Type *> Tys) const;
 
+  /// \returns The cost of Call instructions.
+  unsigned getCallInstrCost(Function *F, Type *RetTy,
+                            ArrayRef<Type *> Tys) const;
+
   /// \returns The number of pieces into which the provided type must be
   /// split during legalization. Zero is returned when the answer is unknown.
   unsigned getNumberOfParts(Type *Tp) const;
@@ -508,6 +529,7 @@ public:
                                     ArrayRef<const Value *> Arguments) = 0;
   virtual unsigned getUserCost(const User *U) = 0;
   virtual bool hasBranchDivergence() = 0;
+  virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isLoweredToCall(const Function *F) = 0;
   virtual void getUnrollingPreferences(Loop *L, UnrollingPreferences &UP) = 0;
   virtual bool isLegalAddImmediate(int64_t Imm) = 0;
@@ -521,10 +543,12 @@ public:
                                    int64_t BaseOffset, bool HasBaseReg,
                                    int64_t Scale) = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
+  virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
   virtual unsigned getJumpBufAlignment() = 0;
   virtual unsigned getJumpBufSize() = 0;
   virtual bool shouldBuildLookupTables() = 0;
+  virtual bool enableAggressiveInterleaving(bool LoopHasReductions) = 0;
   virtual PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) = 0;
   virtual bool haveFastSqrt(Type *Ty) = 0;
   virtual unsigned getFPOpCost(Type *Ty) = 0;
@@ -559,6 +583,8 @@ public:
                                     bool IsPairwiseForm) = 0;
   virtual unsigned getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy,
                                          ArrayRef<Type *> Tys) = 0;
+  virtual unsigned getCallInstrCost(Function *F, Type *RetTy,
+                                    ArrayRef<Type *> Tys) = 0;
   virtual unsigned getNumberOfParts(Type *Tp) = 0;
   virtual unsigned getAddressComputationCost(Type *Ty, bool IsComplex) = 0;
   virtual unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) = 0;
@@ -603,6 +629,9 @@ public:
   }
   unsigned getUserCost(const User *U) override { return Impl.getUserCost(U); }
   bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); }
+  bool isSourceOfDivergence(const Value *V) override {
+    return Impl.isSourceOfDivergence(V);
+  }
   bool isLoweredToCall(const Function *F) override {
     return Impl.isLoweredToCall(F);
   }
@@ -633,12 +662,18 @@ public:
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
     return Impl.isTruncateFree(Ty1, Ty2);
   }
+  bool isProfitableToHoist(Instruction *I) override {
+    return Impl.isProfitableToHoist(I);
+  }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
   unsigned getJumpBufAlignment() override { return Impl.getJumpBufAlignment(); }
   unsigned getJumpBufSize() override { return Impl.getJumpBufSize(); }
   bool shouldBuildLookupTables() override {
     return Impl.shouldBuildLookupTables();
   }
+  bool enableAggressiveInterleaving(bool LoopHasReductions) override {
+    return Impl.enableAggressiveInterleaving(LoopHasReductions);
+  }
   PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) override {
     return Impl.getPopcntSupport(IntTyWidthInBit);
   }
@@ -710,6 +745,10 @@ public:
                                  ArrayRef<Type *> Tys) override {
     return Impl.getIntrinsicInstrCost(ID, RetTy, Tys);
   }
+  unsigned getCallInstrCost(Function *F, Type *RetTy,
+                            ArrayRef<Type *> Tys) override {
+    return Impl.getCallInstrCost(F, RetTy, Tys);
+  }
   unsigned getNumberOfParts(Type *Tp) override {
     return Impl.getNumberOfParts(Tp);
   }