Rangify for loop in Inliner.cpp. NFC.
[oota-llvm.git] / lib / Transforms / IPO / MergeFunctions.cpp
index 7580830576b875c7c46b85457039e703fc6f8c65..2e3519eac6a524bbb4fa916d999b946380328b3d 100644 (file)
@@ -127,9 +127,8 @@ namespace {
 /// side of claiming that two functions are different).
 class FunctionComparator {
 public:
-  FunctionComparator(const DataLayout *DL, const Function *F1,
-                     const Function *F2)
-      : FnL(F1), FnR(F2), DL(DL) {}
+  FunctionComparator(const Function *F1, const Function *F2)
+      : FnL(F1), FnR(F2) {}
 
   /// Test whether the two functions have equivalent behaviour.
   int compare();
@@ -292,8 +291,7 @@ private:
   /// Parts to be compared for each comparison stage,
   /// most significant stage first:
   /// 1. Address space. As numbers.
-  /// 2. Constant offset, (if "DataLayout *DL" field is not NULL,
-  /// using GEPOperator::accumulateConstantOffset method).
+  /// 2. Constant offset, (using GEPOperator::accumulateConstantOffset method).
   /// 3. Pointer operand type (using cmpType method).
   /// 4. Number of operands.
   /// 5. Compare operands, using cmpValues method.
@@ -354,8 +352,6 @@ private:
   // The two functions undergoing comparison.
   const Function *FnL, *FnR;
 
-  const DataLayout *DL;
-
   /// Assign serial numbers to values from left function, and values from
   /// right function.
   /// Explanation:
@@ -392,16 +388,25 @@ private:
   DenseMap<const Value*, int> sn_mapL, sn_mapR;
 };
 
-class FunctionPtr {
-  AssertingVH<Function> F;
-  const DataLayout *DL;
+class FunctionNode {
+  mutable AssertingVH<Function> F;
 
 public:
-  FunctionPtr(Function *F, const DataLayout *DL) : F(F), DL(DL) {}
+  FunctionNode(Function *F) : F(F) {}
   Function *getFunc() const { return F; }
+
+  /// Replace the reference to the function F by the function G, assuming their
+  /// implementations are equal.
+  void replaceBy(Function *G) const {
+    assert(!(*this < FunctionNode(G)) && !(FunctionNode(G) < *this) &&
+           "The two functions must be equal");
+
+    F = G;
+  }
+
   void release() { F = 0; }
-  bool operator<(const FunctionPtr &RHS) const {
-    return (FunctionComparator(DL, F, RHS.getFunc()).compare()) == -1;
+  bool operator<(const FunctionNode &RHS) const {
+    return (FunctionComparator(F, RHS.getFunc()).compare()) == -1;
   }
 };
 }
@@ -620,10 +625,11 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
   PointerType *PTyL = dyn_cast<PointerType>(TyL);
   PointerType *PTyR = dyn_cast<PointerType>(TyR);
 
-  if (DL) {
-    if (PTyL && PTyL->getAddressSpace() == 0) TyL = DL->getIntPtrType(TyL);
-    if (PTyR && PTyR->getAddressSpace() == 0) TyR = DL->getIntPtrType(TyR);
-  }
+  const DataLayout &DL = FnL->getParent()->getDataLayout();
+  if (PTyL && PTyL->getAddressSpace() == 0)
+    TyL = DL.getIntPtrType(TyL);
+  if (PTyR && PTyR->getAddressSpace() == 0)
+    TyR = DL.getIntPtrType(TyR);
 
   if (TyL == TyR)
     return 0;
@@ -723,6 +729,15 @@ int FunctionComparator::cmpOperations(const Instruction *L,
                            R->getRawSubclassOptionalData()))
     return Res;
 
+  if (const AllocaInst *AI = dyn_cast<AllocaInst>(L)) {
+    if (int Res = cmpTypes(AI->getAllocatedType(),
+                           cast<AllocaInst>(R)->getAllocatedType()))
+      return Res;
+    if (int Res =
+            cmpNumbers(AI->getAlignment(), cast<AllocaInst>(R)->getAlignment()))
+      return Res;
+  }
+
   // We have two instructions of identical opcode and #operands.  Check to see
   // if all operands are the same type
   for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i) {
@@ -855,13 +870,12 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
 
   // When we have target data, we can reduce the GEP down to the value in bytes
   // added to the address.
-  if (DL) {
-    unsigned BitWidth = DL->getPointerSizeInBits(ASL);
-    APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0);
-    if (GEPL->accumulateConstantOffset(*DL, OffsetL) &&
-        GEPR->accumulateConstantOffset(*DL, OffsetR))
-      return cmpAPInts(OffsetL, OffsetR);
-  }
+  const DataLayout &DL = FnL->getParent()->getDataLayout();
+  unsigned BitWidth = DL.getPointerSizeInBits(ASL);
+  APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0);
+  if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
+      GEPR->accumulateConstantOffset(DL, OffsetR))
+    return cmpAPInts(OffsetL, OffsetR);
 
   if (int Res = cmpNumbers((uint64_t)GEPL->getPointerOperand()->getType(),
                            (uint64_t)GEPR->getPointerOperand()->getType()))
@@ -1049,7 +1063,7 @@ int FunctionComparator::compare() {
 
     assert(TermL->getNumSuccessors() == TermR->getNumSuccessors());
     for (unsigned i = 0, e = TermL->getNumSuccessors(); i != e; ++i) {
-      if (!VisitedBBs.insert(TermL->getSuccessor(i)))
+      if (!VisitedBBs.insert(TermL->getSuccessor(i)).second)
         continue;
 
       FnLBBs.push_back(TermL->getSuccessor(i));
@@ -1077,7 +1091,7 @@ public:
   bool runOnModule(Module &M) override;
 
 private:
-  typedef std::set<FunctionPtr> FnTreeType;
+  typedef std::set<FunctionNode> FnTreeType;
 
   /// A work queue of functions that may have been modified and should be
   /// analyzed again.
@@ -1118,13 +1132,13 @@ private:
   /// Replace G with an alias to F. Deletes G.
   void writeAlias(Function *F, Function *G);
 
+  /// Replace function F with function G in the function tree.
+  void replaceFunctionInTree(FnTreeType::iterator &IterToF, Function *G);
+
   /// The set of all distinct functions. Use the insert() and remove() methods
   /// to modify it.
   FnTreeType FnTree;
 
-  /// DataLayout for more accurate GEP comparisons. May be NULL.
-  const DataLayout *DL;
-
   /// Whether or not the target supports global aliases.
   bool HasGlobalAliases;
 };
@@ -1152,8 +1166,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
       for (std::vector<WeakVH>::iterator J = I; J != E && j < Max; ++J, ++j) {
         Function *F1 = cast<Function>(*I);
         Function *F2 = cast<Function>(*J);
-        int Res1 = FunctionComparator(DL, F1, F2).compare();
-        int Res2 = FunctionComparator(DL, F2, F1).compare();
+        int Res1 = FunctionComparator(F1, F2).compare();
+        int Res2 = FunctionComparator(F2, F1).compare();
 
         // If F1 <= F2, then F2 >= F1, otherwise report failure.
         if (Res1 != -Res2) {
@@ -1174,8 +1188,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
             continue;
 
           Function *F3 = cast<Function>(*K);
-          int Res3 = FunctionComparator(DL, F1, F3).compare();
-          int Res4 = FunctionComparator(DL, F2, F3).compare();
+          int Res3 = FunctionComparator(F1, F3).compare();
+          int Res4 = FunctionComparator(F2, F3).compare();
 
           bool Transitive = true;
 
@@ -1212,8 +1226,6 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) {
 
 bool MergeFunctions::runOnModule(Module &M) {
   bool Changed = false;
-  DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
-  DL = DLP ? &DLP->getDataLayout() : nullptr;
 
   for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) {
     if (!I->isDeclaration() && !I->hasAvailableExternallyLinkage())
@@ -1368,8 +1380,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
 // Replace G with an alias to F and delete G.
 void MergeFunctions::writeAlias(Function *F, Function *G) {
   PointerType *PTy = G->getType();
-  auto *GA = GlobalAlias::create(PTy->getElementType(), PTy->getAddressSpace(),
-                                 G->getLinkage(), "", F);
+  auto *GA = GlobalAlias::create(PTy, G->getLinkage(), "", F);
   F->setAlignment(std::max(F->getAlignment(), G->getAlignment()));
   GA->takeName(G);
   GA->setVisibility(G->getVisibility());
@@ -1386,28 +1397,26 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
   if (F->mayBeOverridden()) {
     assert(G->mayBeOverridden());
 
-    if (HasGlobalAliases) {
-      // Make them both thunks to the same internal function.
-      Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "",
-                                     F->getParent());
-      H->copyAttributesFrom(F);
-      H->takeName(F);
-      removeUsers(F);
-      F->replaceAllUsesWith(H);
+    // Make them both thunks to the same internal function.
+    Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "",
+                                   F->getParent());
+    H->copyAttributesFrom(F);
+    H->takeName(F);
+    removeUsers(F);
+    F->replaceAllUsesWith(H);
 
-      unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
+    unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment());
 
+    if (HasGlobalAliases) {
       writeAlias(F, G);
       writeAlias(F, H);
-
-      F->setAlignment(MaxAlignment);
-      F->setLinkage(GlobalValue::PrivateLinkage);
     } else {
-      // We can't merge them. Instead, pick one and update all direct callers
-      // to call it and hope that we improve the instruction cache hit rate.
-      replaceDirectCallers(G, F);
+      writeThunk(F, G);
+      writeThunk(F, H);
     }
 
+    F->setAlignment(MaxAlignment);
+    F->setLinkage(GlobalValue::PrivateLinkage);
     ++NumDoubleWeak;
   } else {
     writeThunkOrAlias(F, G);
@@ -1416,18 +1425,33 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
   ++NumFunctionsMerged;
 }
 
+/// Replace function F for function G in the map.
+void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
+                                           Function *G) {
+  Function *F = IterToF->getFunc();
+
+  // A total order is already guaranteed otherwise because we process strong
+  // functions before weak functions.
+  assert(((F->mayBeOverridden() && G->mayBeOverridden()) ||
+          (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
+         "Only change functions if both are strong or both are weak");
+  (void)F;
+
+  IterToF->replaceBy(G);
+}
+
 // Insert a ComparableFunction into the FnTree, or merge it away if equal to one
 // that was already inserted.
 bool MergeFunctions::insert(Function *NewFunction) {
   std::pair<FnTreeType::iterator, bool> Result =
-      FnTree.insert(FunctionPtr(NewFunction, DL));
+      FnTree.insert(FunctionNode(NewFunction));
 
   if (Result.second) {
     DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n');
     return false;
   }
 
-  const FunctionPtr &OldF = *Result.first;
+  const FunctionNode &OldF = *Result.first;
 
   // Don't merge tiny functions, since it can just end up making the function
   // larger.
@@ -1441,6 +1465,22 @@ bool MergeFunctions::insert(Function *NewFunction) {
     }
   }
 
+  // Impose a total order (by name) on the replacement of functions. This is
+  // important when operating on more than one module independently to prevent
+  // cycles of thunks calling each other when the modules are linked together.
+  //
+  // When one function is weak and the other is strong there is an order imposed
+  // already. We process strong functions before weak functions.
+  if ((OldF.getFunc()->mayBeOverridden() && NewFunction->mayBeOverridden()) ||
+      (!OldF.getFunc()->mayBeOverridden() && !NewFunction->mayBeOverridden()))
+    if (OldF.getFunc()->getName() > NewFunction->getName()) {
+      // Swap the two functions.
+      Function *F = OldF.getFunc();
+      replaceFunctionInTree(Result.first, NewFunction);
+      NewFunction = F;
+      assert(OldF.getFunc() != F && "Must have swapped the functions.");
+    }
+
   // Never thunk a strong function to a weak function.
   assert(!OldF.getFunc()->mayBeOverridden() || NewFunction->mayBeOverridden());
 
@@ -1457,7 +1497,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
 void MergeFunctions::remove(Function *F) {
   // We need to make sure we remove F, not a function "equal" to F per the
   // function equality comparator.
-  FnTreeType::iterator found = FnTree.find(FunctionPtr(F, DL));
+  FnTreeType::iterator found = FnTree.find(FunctionNode(F));
   size_t Erased = 0;
   if (found != FnTree.end() && found->getFunc() == F) {
     Erased = 1;
@@ -1467,7 +1507,7 @@ void MergeFunctions::remove(Function *F) {
   if (Erased) {
     DEBUG(dbgs() << "Removed " << F->getName()
                  << " from set and deferred it.\n");
-    Deferred.push_back(F);
+    Deferred.emplace_back(F);
   }
 }