[DivergenceAnalysis] fix a bug in computing influence regions
[oota-llvm.git] / lib / Analysis / DivergenceAnalysis.cpp
index f3fc7844020e66685b8aec5c590287486fe7a452..5ae6d74130a7a770c7d743256d5e80f01d338eb6 100644 (file)
@@ -1,4 +1,4 @@
-//===- DivergenceAnalysis.cpp ------ Divergence Analysis ------------------===//
+//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -7,8 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file defines divergence analysis which determines whether a branch in a
-// GPU program is divergent. It can help branch optimizations such as jump
+// This file implements divergence analysis which determines whether a branch
+// in a GPU program is divergent.It can help branch optimizations such as jump
 // threading and loop unswitching to make better decisions.
 //
 // GPU programs typically use the SIMD execution model, where multiple threads
 // 2. memory as black box. It conservatively considers values loaded from
 //    generic or local address as divergent. This can be improved by leveraging
 //    pointer analysis.
+//
 //===----------------------------------------------------------------------===//
 
-#include <vector>
-#include "llvm/IR/Dominators.h"
-#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/IR/Function.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Value.h"
-#include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include <vector>
 using namespace llvm;
 
-#define DEBUG_TYPE "divergence"
-
-namespace {
-class DivergenceAnalysis : public FunctionPass {
-public:
-  static char ID;
-
-  DivergenceAnalysis() : FunctionPass(ID) {
-    initializeDivergenceAnalysisPass(*PassRegistry::getPassRegistry());
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<PostDominatorTree>();
-    AU.setPreservesAll();
-  }
-
-  bool runOnFunction(Function &F) override;
-
-  // Print all divergent branches in the function.
-  void print(raw_ostream &OS, const Module *) const override;
-
-  // Returns true if V is divergent.
-  bool isDivergent(const Value *V) const { return DivergentValues.count(V); }
-  // Returns true if V is uniform/non-divergent.
-  bool isUniform(const Value *V) const { return !isDivergent(V); }
-
-private:
-  // Stores all divergent values.
-  DenseSet<const Value *> DivergentValues;
-};
-} // End of anonymous namespace
-
-// Register this pass.
-char DivergenceAnalysis::ID = 0;
-INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                      false, true)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
-INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
-                    false, true)
-
 namespace {
 
 class DivergencePropagator {
 public:
-  DivergencePropagator(Function &F, TargetTransformInfo &TTI,
-                       DominatorTree &DT, PostDominatorTree &PDT,
-                       DenseSet<const Value *> &DV)
+  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT,
+                       PostDominatorTree &PDT, DenseSet<const Value *> &DV)
       : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {}
   void populateWithSourcesOfDivergence();
   void propagate();
@@ -140,7 +96,7 @@ private:
   // A helper function that explores sync dependents of TI.
   void exploreSyncDependency(TerminatorInst *TI);
   // Computes the influence region from Start to End. This region includes all
-  // basic blocks on any path from Start to End.
+  // basic blocks on any simple path from Start to End.
   void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End,
                               DenseSet<BasicBlock *> &InfluenceRegion);
   // Finds all users of I that are outside the influence region, and add these
@@ -153,7 +109,7 @@ private:
   DominatorTree &DT;
   PostDominatorTree &PDT;
   std::vector<Value *> Worklist; // Stack for DFS.
-  DenseSet<const Value *> &DV; // Stores all divergent values.
+  DenseSet<const Value *> &DV;   // Stores all divergent values.
 };
 
 void DivergencePropagator::populateWithSourcesOfDivergence() {
@@ -191,8 +147,8 @@ void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) {
   for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) {
     // A PHINode is uniform if it returns the same value no matter which path is
     // taken.
-    if (!cast<PHINode>(I)->hasConstantValue() && DV.insert(I).second)
-      Worklist.push_back(I);
+    if (!cast<PHINode>(I)->hasConstantValue() && DV.insert(&*I).second)
+      Worklist.push_back(&*I);
   }
 
   // Propagation rule 2: if a value defined in a loop is used outside, the user
@@ -242,21 +198,33 @@ void DivergencePropagator::findUsersOutsideInfluenceRegion(
   }
 }
 
+// A helper function for computeInfluenceRegion that adds successors of "ThisBB"
+// to the influence region.
+static void
+addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End,
+                               DenseSet<BasicBlock *> &InfluenceRegion,
+                               std::vector<BasicBlock *> &InfluenceStack) {
+  for (BasicBlock *Succ : successors(ThisBB)) {
+    if (Succ != End && InfluenceRegion.insert(Succ).second)
+      InfluenceStack.push_back(Succ);
+  }
+}
+
 void DivergencePropagator::computeInfluenceRegion(
     BasicBlock *Start, BasicBlock *End,
     DenseSet<BasicBlock *> &InfluenceRegion) {
   assert(PDT.properlyDominates(End, Start) &&
          "End does not properly dominate Start");
+
+  // The influence region starts from the end of "Start" to the beginning of
+  // "End". Therefore, "Start" should not be in the region unless "Start" is in
+  // a loop that doesn't contain "End".
   std::vector<BasicBlock *> InfluenceStack;
-  InfluenceStack.push_back(Start);
-  InfluenceRegion.insert(Start);
+  addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack);
   while (!InfluenceStack.empty()) {
     BasicBlock *BB = InfluenceStack.back();
     InfluenceStack.pop_back();
-    for (BasicBlock *Succ : successors(BB)) {
-      if (End != Succ && InfluenceRegion.insert(Succ).second)
-        InfluenceStack.push_back(Succ);
-    }
+    addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack);
   }
 }
 
@@ -286,10 +254,25 @@ void DivergencePropagator::propagate() {
 
 } /// end namespace anonymous
 
+// Register this pass.
+char DivergenceAnalysis::ID = 0;
+INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                      false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTree)
+INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis",
+                    false, true)
+
 FunctionPass *llvm::createDivergenceAnalysisPass() {
   return new DivergenceAnalysis();
 }
 
+void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTree>();
+  AU.setPreservesAll();
+}
+
 bool DivergenceAnalysis::runOnFunction(Function &F) {
   auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
   if (TTIWP == nullptr)