[DivergenceAnalysis] Separated definition of class into header.
[oota-llvm.git] / lib / Analysis / DivergenceAnalysis.cpp
index e5ee2959c15d9e6c99e529ae7cf54fcc10a042a4..c24f38a9c61712e8a6ee3d4aa47a94acd2918398 100644 (file)
@@ -1,4 +1,4 @@
-//===- DivergenceAnalysis.cpp ------ Divergence Analysis ------------------===//
+//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines divergence analysis which determines whether a branch in a
-// GPU program is divergent. It can help branch optimizations such as jump
+// This file implements divergence analysis which determines whether a branch
+// in a GPU program is divergent.It can help branch optimizations such as jump
 // threading and loop unswitching to make better decisions.
 //
 // GPU programs typically use the SIMD execution model, where multiple threads
 // 2. memory as black box. It conservatively considers values loaded from
 //    generic or local address as divergent. This can be improved by leveraging
 //    pointer analysis.
+//
 //===----------------------------------------------------------------------===//
 
-#include <vector>
-#include "llvm/IR/Dominators.h"
-#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/IR/Function.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Value.h"
-#include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include <vector>
 using namespace llvm;
 
-#define DEBUG_TYPE "divergence"
-
-namespace {
-class DivergenceAnalysis : public FunctionPass {
-public:
-  static char ID;
-
-  DivergenceAnalysis() : FunctionPass(ID) {
-    initializeDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<PostDominatorTree>();
-    AU.setPreservesAll();
-  }
-
-  bool runOnFunction(Function &F) override;
-
-  // Print all divergent branches in the function.
-  void print(raw_ostream &OS, const Module *) const override;
-
-  // Returns true if V is divergent.
-  bool isDivergent(const Value *V) const { return DivergentValues.count(V); }
-  // Returns true if V is uniform/non-divergent.
-  bool isUniform(const Value *V) const { return !isDivergent(V); }
-
-private:
-  // Stores all divergent values.
-  DenseSet<const Value *> DivergentValues;
-};
-} // End of anonymous namespace
-
-// Register this pass.
-char DivergenceAnalysis::ID = 0;
-INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                      false, true)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
-INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                    false, true)
-
 namespace {
 
 class DivergencePropagator {
 public:
-  DivergencePropagator(Function &F, TargetTransformInfo &TTI,
-                       DominatorTree &DT, PostDominatorTree &PDT,
-                       DenseSet<const Value *> &DV)
+  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
+                       PostDominatorTree &PDT, DenseSet<const Value *> &DV)
       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
   void populateWithSourcesOfDivergence();
   void propagate();
@@ -153,13 +109,13 @@ private:
   DominatorTree &DT;
   PostDominatorTree &PDT;
   std::vector<Value *> Worklist; // Stack for DFS.
-  DenseSet<const Value *> &DV; // Stores all divergent values.
+  DenseSet<const Value *> &DV;   // Stores all divergent values.
 };
 
 void DivergencePropagator::populateWithSourcesOfDivergence() {
   Worklist.clear();
   DV.clear();
-  for (auto &I : inst_range(F)) {
+  for (auto &I : instructions(F)) {
     if (TTI.isSourceOfDivergence(&I)) {
       Worklist.push_back(&I);
       DV.insert(&I);
@@ -286,10 +242,25 @@ void DivergencePropagator::propagate() {
 
 } /// end namespace anonymous
 
+// Register this pass.
+char DivergenceAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                      false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
+INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                    false, true)
+
 FunctionPass *llvm::createDivergenceAnalysisPass() {
   return new DivergenceAnalysis();
 }
 
+void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTree>();
+  AU.setPreservesAll();
+}
+
 bool DivergenceAnalysis::runOnFunction(Function &F) {
   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
   if (TTIWP == nullptr)
@@ -329,8 +300,8 @@ void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const {
     if (DivergentValues.count(&Arg))
       OS << "DIVERGENT:  " << Arg << "\n";
   }
-  // Iterate instructions using inst_range to ensure a deterministic order.
-  for (auto &I : inst_range(F)) {
+  // Iterate instructions using instructions() to ensure a deterministic order.
+  for (auto &I : instructions(F)) {
     if (DivergentValues.count(&I))
       OS << "DIVERGENT:" << I << "\n";
   }