Remove functions from the FnSet when one of their callee's is being merged. This
authorNick Lewycky <nicholas@mxc.ca>
Sun, 2 Jan 2011 02:46:33 +0000 (02:46 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Sun, 2 Jan 2011 02:46:33 +0000 (02:46 +0000)
maintains the guarantee that the DenseSet expects two elements it contains to
not go from inequal to equal under its nose.

As a side-effect, this also lets us switch from iterating to a fixed-point to
actually maintaining a work queue of functions to look at again, and we don't
add thunks to our work queue so we don't need to detect and ignore them.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@122677 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/IPO/MergeFunctions.cpp

index 9cfbcc8dae9d908800674ee22b7ed66c628a2110..74d410515161da62522d9d76d5531e38789349e5 100644 (file)
@@ -160,20 +160,36 @@ public:
 private:
   typedef DenseSet<ComparableFunction> FnSetType;
 
+  /// A work queue of functions that may have been modified and should be
+  /// analyzed again.
+  std::vector<WeakVH> Deferred;
 
   /// Insert a ComparableFunction into the FnSet, or merge it away if it's
   /// equal to one that's already present.
-  bool Insert(FnSetType &FnSet, ComparableFunction &NewF);
+  bool Insert(ComparableFunction &NewF);
+
+  /// Remove a Function from the FnSet and queue it up for a second sweep of
+  /// analysis.
+  void Remove(Function *F);
+
+  /// Find the functions that use this Value and remove them from FnSet and
+  /// queue the functions.
+  void RemoveUsers(Value *V);
 
   /// MergeTwoFunctions - Merge two equivalent functions. Upon completion, G
   /// may be deleted, or may be converted into a thunk. In either case, it
   /// should never be visited again.
-  void MergeTwoFunctions(Function *F, Function *G) const;
+  void MergeTwoFunctions(Function *F, Function *G);
 
   /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also
   /// replace direct uses of G with bitcast(F). Deletes G.
-  void WriteThunk(Function *F, Function *G) const;
+  void WriteThunk(Function *F, Function *G);
 
+  /// The set of all distinct functions. Use the Insert and Remove methods to
+  /// modify it.
+  FnSetType FnSet;
+
+  /// TargetData for more accurate GEP comparisons. May be NULL.
   TargetData *TD;
 };
 
@@ -560,7 +576,7 @@ bool FunctionComparator::Compare() {
 
 /// WriteThunk - Replace G with a simple tail call to bitcast(F). Also replace
 /// direct uses of G with bitcast(F). Deletes G.
-void MergeFunctions::WriteThunk(Function *F, Function *G) const {
+void MergeFunctions::WriteThunk(Function *F, Function *G) {
   if (!G->mayBeOverridden()) {
     // Redirect direct callers of G to F.
     Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType());
@@ -569,8 +585,10 @@ void MergeFunctions::WriteThunk(Function *F, Function *G) const {
       Value::use_iterator TheIter = UI;
       ++UI;
       CallSite CS(*TheIter);
-      if (CS && CS.isCallee(TheIter))
+      if (CS && CS.isCallee(TheIter)) {
+        Remove(CS.getInstruction()->getParent()->getParent());
         TheIter.getUse().set(BitcastF);
+      }
     }
   }
 
@@ -606,6 +624,7 @@ void MergeFunctions::WriteThunk(Function *F, Function *G) const {
 
   NewG->copyAttributesFrom(G);
   NewG->takeName(G);
+  RemoveUsers(G);
   G->replaceAllUsesWith(NewG);
   G->eraseFromParent();
 
@@ -615,7 +634,7 @@ void MergeFunctions::WriteThunk(Function *F, Function *G) const {
 
 /// MergeTwoFunctions - Merge two equivalent functions. Upon completion,
 /// Function G is deleted.
-void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const {
+void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) {
   if (F->mayBeOverridden()) {
     assert(G->mayBeOverridden());
 
@@ -624,6 +643,7 @@ void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const {
                                    F->getParent());
     H->copyAttributesFrom(F);
     H->takeName(F);
+    RemoveUsers(F);
     F->replaceAllUsesWith(H);
 
     unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
@@ -632,7 +652,7 @@ void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const {
     WriteThunk(F, H);
 
     F->setAlignment(MaxAlignment);
-    F->setLinkage(GlobalValue::InternalLinkage);
+    F->setLinkage(GlobalValue::PrivateLinkage);
 
     ++NumDoubleWeak;
   } else {
@@ -644,7 +664,7 @@ void MergeFunctions::MergeTwoFunctions(Function *F, Function *G) const {
 
 // Insert - Insert a ComparableFunction into the FnSet, or merge it away if
 // equal to one that's already inserted.
-bool MergeFunctions::Insert(FnSetType &FnSet, ComparableFunction &NewF) {
+bool MergeFunctions::Insert(ComparableFunction &NewF) {
   std::pair<FnSetType::iterator, bool> Result = FnSet.insert(NewF);
   if (Result.second)
     return false;
@@ -664,91 +684,52 @@ bool MergeFunctions::Insert(FnSetType &FnSet, ComparableFunction &NewF) {
   return true;
 }
 
-// IsThunk - This method determines whether or not a given Function is a thunk\// like the ones emitted by this pass and therefore not subject to further
-// merging.
-static bool IsThunk(const Function *F) {
-  // The safe direction to fail is to return true. In that case, the function
-  // will be removed from merging analysis. If we failed to including functions
-  // then we may try to merge unmergable thing (ie., identical weak functions)
-  // which will push us into an infinite loop.
-
-  assert(!F->isDeclaration() && "Expected a function definition.");
-
-  const BasicBlock *BB = &F->front();
-  // A thunk is:
-  //   bitcast-inst*
-  //   optional-reg tail call @thunkee(args...*)
-  //   ret void|optional-reg
-  // where the args are in the same order as the arguments.
-
-  // Put this at the top since it triggers most often.
-  const ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator());
-  if (!RI) return false;
-
-  // Verify that the sequence of bitcast-inst's are all casts of arguments and
-  // that there aren't any extras (ie. no repeated casts).
-  int LastArgNo = -1;
-  BasicBlock::const_iterator I = BB->begin();
-  while (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
-    const Argument *A = dyn_cast<Argument>(BCI->getOperand(0));
-    if (!A) return false;
-    if ((int)A->getArgNo() <= LastArgNo) return false;
-    LastArgNo = A->getArgNo();
-    ++I;
+// Remove - Remove a function from FnSet. If it was already in FnSet, add it to
+// Deferred so that we'll look at it in the next round.
+void MergeFunctions::Remove(Function *F) {
+  ComparableFunction CF = ComparableFunction(F, TD);
+  if (FnSet.erase(CF)) {
+    Deferred.push_back(F);
   }
+}
 
-  // Verify that we have a direct tail call and that the calling conventions
-  // and number of arguments match.
-  const CallInst *CI = dyn_cast<CallInst>(I++);
-  if (!CI || !CI->isTailCall() || !CI->getCalledFunction() || 
-      CI->getCallingConv() != CI->getCalledFunction()->getCallingConv() ||
-      CI->getNumArgOperands() != F->arg_size())
-    return false;
-
-  // Verify that the call instruction has the same arguments as this function
-  // and that they're all either the incoming argument or a cast of the right
-  // argument.
-  for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) {
-    const Value *V = CI->getArgOperand(i);
-    const Argument *A = dyn_cast<Argument>(V);
-    if (!A) {
-      const BitCastInst *BCI = dyn_cast<BitCastInst>(V);
-      if (!BCI) return false;
-      A = cast<Argument>(BCI->getOperand(0));
+// RemoveUsers - For each instruction used by the value, Remove() the function
+// that contains the instruction. This should happen right before a call to RAUW.
+void MergeFunctions::RemoveUsers(Value *V) {
+  for (Value::use_iterator UI = V->use_begin(), UE = V->use_end();
+       UI != UE; ++UI) {
+    Use &U = UI.getUse();
+    if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
+      Remove(I->getParent()->getParent());
     }
-    if (A->getArgNo() != i) return false;
   }
-
-  // Verify that the terminator is a ret void (if we're void) or a ret of the
-  // call's return, or a ret of a bitcast of the call's return.
-  if (const BitCastInst *BCI = dyn_cast<BitCastInst>(I)) {
-    ++I;
-    if (BCI->getOperand(0) != CI) return false;
-  }
-  if (RI != I) return false;
-  if (RI->getNumOperands() == 0)
-    return CI->getType()->isVoidTy();
-  return RI->getReturnValue() == CI;
 }
 
 bool MergeFunctions::runOnModule(Module &M) {
   bool Changed = false;
   TD = getAnalysisIfAvailable<TargetData>();
 
-  bool LocalChanged;
+  for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) {
+    Deferred.push_back(WeakVH(I));
+  }
+
   do {
+    std::vector<WeakVH> Worklist;
+    Deferred.swap(Worklist);
+
     DEBUG(dbgs() << "size of module: " << M.size() << '\n');
-    LocalChanged = false;
-    FnSetType FnSet;
+    DEBUG(dbgs() << "size of worklist: " << Worklist.size() << '\n');
 
     // Insert only strong functions and merge them. Strong function merging
     // always deletes one of them.
-    for (Module::iterator I = M.begin(), E = M.end(); I != E;) {
-      Function *F = I++;
+    for (std::vector<WeakVH>::iterator I = Worklist.begin(),
+           E = Worklist.end(); I != E; ++I) {
+      if (!*I) continue;
+      Function *F = cast<Function>(*I);
       if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() &&
-          !F->mayBeOverridden() && !IsThunk(F)) {
+          !F->mayBeOverridden()) {
         ComparableFunction CF = ComparableFunction(F, TD);
-        LocalChanged |= Insert(FnSet, CF);
+        Changed |= Insert(CF);
       }
     }
 
@@ -756,17 +737,20 @@ bool MergeFunctions::runOnModule(Module &M) {
     // create thunks to the strong function when possible. When two weak
     // functions are identical, we create a new strong function with two weak
     // weak thunks to it which are identical but not mergable.
-    for (Module::iterator I = M.begin(), E = M.end(); I != E;) {
-      Function *F = I++;
+    for (std::vector<WeakVH>::iterator I = Worklist.begin(),
+           E = Worklist.end(); I != E; ++I) {
+      if (!*I) continue;
+      Function *F = cast<Function>(*I);
       if (!F->isDeclaration() && !F->hasAvailableExternallyLinkage() &&
-          F->mayBeOverridden() && !IsThunk(F)) {
+          F->mayBeOverridden()) {
         ComparableFunction CF = ComparableFunction(F, TD);
-        LocalChanged |= Insert(FnSet, CF);
+        Changed |= Insert(CF);
       }
     }
     DEBUG(dbgs() << "size of FnSet: " << FnSet.size() << '\n');
-    Changed |= LocalChanged;
-  } while (LocalChanged);
+  } while (!Deferred.empty());
+
+  FnSet.clear();
 
   return Changed;
 }