[MergeFuncs] Fix callsite attributes in thunk generation
[oota-llvm.git] / lib / Transforms / IPO / MergeFunctions.cpp
index 9ffd6534a65a4511a84919878896e2f37755dea0..a8a85982e18d5fc4e6f4a65677f3a4b0557aa4ee 100644 (file)
@@ -397,12 +397,12 @@ private:
   int cmpTypes(Type *TyL, Type *TyR) const;
 
   int cmpNumbers(uint64_t L, uint64_t R) const;
-
   int cmpAPInts(const APInt &L, const APInt &R) const;
   int cmpAPFloats(const APFloat &L, const APFloat &R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpMem(StringRef L, StringRef R) const;
   int cmpAttrs(const AttributeSet L, const AttributeSet R) const;
+  int cmpRangeMetadata(const MDNode* L, const MDNode* R) const;
 
   // The two functions undergoing comparison.
   const Function *FnL, *FnR;
@@ -481,13 +481,21 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
 }
 
 int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
-  // TODO: This correctly handles all existing fltSemantics, because they all
-  // have different precisions. This isn't very robust, however, if new types
-  // with different exponent ranges are introduced.
+  // Floats are ordered first by semantics (i.e. float, double, half, etc.),
+  // then by value interpreted as a bitstring (aka APInt).
   const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
   if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
                            APFloat::semanticsPrecision(SR)))
     return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL),
+                           APFloat::semanticsMaxExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL),
+                           APFloat::semanticsMinExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL),
+                           APFloat::semanticsSizeInBits(SR)))
+    return Res;
   return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
 }
 
@@ -524,6 +532,32 @@ int FunctionComparator::cmpAttrs(const AttributeSet L,
   }
   return 0;
 }
+int FunctionComparator::cmpRangeMetadata(const MDNode* L,
+                                         const MDNode* R) const {
+  if (L == R)
+    return 0;
+  if (!L)
+    return -1;
+  if (!R)
+    return 1;
+  // Range metadata is a sequence of numbers. Make sure they are the same
+  // sequence. 
+  // TODO: Note that as this is metadata, it is possible to drop and/or merge
+  // this data when considering functions to merge. Thus this comparison would
+  // return 0 (i.e. equivalent), but merging would become more complicated
+  // because the ranges would need to be unioned. It is not likely that
+  // functions differ ONLY in this metadata if they are actually the same
+  // function semantically.
+  if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
+    return Res;
+  for (size_t I = 0; I < L->getNumOperands(); ++I) {
+    ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
+    ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
+    if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
+      return Res;
+  }
+  return 0;
+}
 
 /// Constants comparison:
 /// 1. Check whether type of L constant could be losslessly bitcasted to R
@@ -607,7 +641,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return Res;
 
   if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
-    const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
+    const auto *SeqR = cast<ConstantDataSequential>(R);
     // This handles ConstantDataArray and ConstantDataVector. Note that we
     // compare the two raw data arrays, which might differ depending on the host
     // endianness. This isn't a problem though, because the endiness of a module
@@ -685,10 +719,38 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return 0;
   }
   case Value::BlockAddressVal: {
-    // FIXME: This still uses a pointer comparison. It isn't clear how to remove
-    // this. This only affects programs which take BlockAddresses and store them
-    // as constants, which is limited to interepreters, etc.
-    return cmpNumbers((uint64_t)L, (uint64_t)R);
+    const BlockAddress *LBA = cast<BlockAddress>(L);
+    const BlockAddress *RBA = cast<BlockAddress>(R);
+    if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction()))
+      return Res;
+    if (LBA->getFunction() == RBA->getFunction()) {
+      // They are BBs in the same function. Order by which comes first in the
+      // BB order of the function. This order is deterministic.
+      Function* F = LBA->getFunction();
+      BasicBlock *LBB = LBA->getBasicBlock();
+      BasicBlock *RBB = RBA->getBasicBlock();
+      if (LBB == RBB)
+        return 0;
+      for(BasicBlock &BB : F->getBasicBlockList()) {
+        if (&BB == LBB) {
+          assert(&BB != RBB);
+          return -1;
+        }
+        if (&BB == RBB)
+          return 1;
+      }
+      llvm_unreachable("Basic Block Address does not point to a basic block in "
+                       "its function.");
+      return -1;
+    } else {
+      // cmpValues said the functions are the same. So because they aren't
+      // literally the same pointer, they must respectively be the left and
+      // right functions.
+      assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR);
+      // cmpValues will tell us if these are equivalent BasicBlocks, in the
+      // context of their respective functions.
+      return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock());
+    }
   }
   default: // Unknown constant, abort.
     DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
@@ -849,8 +911,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope()))
       return Res;
-    return cmpNumbers((uint64_t)LI->getMetadata(LLVMContext::MD_range),
-                      (uint64_t)cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range),
+        cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
     if (int Res =
@@ -873,9 +935,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InvokeInst *CI = dyn_cast<InvokeInst>(L)) {
     if (int Res = cmpNumbers(CI->getCallingConv(),
@@ -884,9 +946,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<InvokeInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -1253,7 +1315,7 @@ class MergeFunctions : public ModulePass {
 public:
   static char ID;
   MergeFunctions()
-    : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)),
+    : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)), FNodesInTree(),
       HasGlobalAliases(false) {
     initializeMergeFunctionsPass(*PassRegistry::getPassRegistry());
   }
@@ -1319,11 +1381,17 @@ private:
   void writeAlias(Function *F, Function *G);
 
   /// Replace function F with function G in the function tree.
-  void replaceFunctionInTree(FnTreeType::iterator &IterToF, Function *G);
+  void replaceFunctionInTree(const FunctionNode &FN, Function *G);
 
   /// The set of all distinct functions. Use the insert() and remove() methods
-  /// to modify it.
+  /// to modify it. The map allows efficient lookup and deferring of Functions.
   FnTreeType FnTree;
+  // Map functions to the iterators of the FunctionNode which contains them
+  // in the FnTree. This must be updated carefully whenever the FnTree is
+  // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
+  // dangling iterators into FnTree. The invariant that preserves this is that
+  // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
+  ValueMap<Function*, FnTreeType::iterator> FNodesInTree;
 
   /// Whether or not the target supports global aliases.
   bool HasGlobalAliases;
@@ -1491,9 +1559,16 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
     CallSite CS(U->getUser());
     if (CS && CS.isCallee(U)) {
       // Transfer the called function's attributes to the call site. Due to the
-      // bitcast we will 'loose' ABI changing attributes because the 'called
+      // bitcast we will 'lose' ABI changing attributes because the 'called
       // function' is no longer a Function* but the bitcast. Code that looks up
       // the attributes from the called function will fail.
+
+      // FIXME: This is not actually true, at least not anymore. The callsite
+      // will always have the same ABI affecting attributes as the callee,
+      // because otherwise the original input has UB. Note that Old and New
+      // always have matching ABI, so no attributes need to be changed.
+      // Transferring other attributes may help other optimizations, but that
+      // should be done uniformly and not in this ad-hoc way.
       auto &Context = New->getContext();
       auto NewFuncAttrs = New->getAttributes();
       auto CallSiteAttrs = CS.getAttributes();
@@ -1588,6 +1663,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   CallInst *CI = Builder.CreateCall(F, Args);
   CI->setTailCall();
   CI->setCallingConv(F->getCallingConv());
+  CI->setAttributes(F->getAttributes());
   if (NewG->getReturnType()->isVoidTy()) {
     Builder.CreateRetVoid();
   } else {
@@ -1652,21 +1728,24 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
   ++NumFunctionsMerged;
 }
 
-/// Replace function F for function G in the map.
-void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
+/// Replace function F by function G.
+void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
                                            Function *G) {
-  Function *F = IterToF->getFunc();
-
-  // A total order is already guaranteed otherwise because we process strong
-  // functions before weak functions.
-  assert(((F->mayBeOverridden() && G->mayBeOverridden()) ||
-          (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
-         "Only change functions if both are strong or both are weak");
-  (void)F;
+  Function *F = FN.getFunc();
   assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
          "The two functions must be equal");
-
-  IterToF->replaceBy(G);
+  
+  auto I = FNodesInTree.find(F);
+  assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
+  assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
+  
+  FnTreeType::iterator IterToFNInFnTree = I->second;
+  assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
+  // Remove F -> FN and insert G -> FN
+  FNodesInTree.erase(I);
+  FNodesInTree.insert({G, IterToFNInFnTree});
+  // Replace F with G in FN, which is stored inside the FnTree.
+  FN.replaceBy(G);
 }
 
 // Insert a ComparableFunction into the FnTree, or merge it away if equal to one
@@ -1676,6 +1755,8 @@ bool MergeFunctions::insert(Function *NewFunction) {
       FnTree.insert(FunctionNode(NewFunction));
 
   if (Result.second) {
+    assert(FNodesInTree.count(NewFunction) == 0);
+    FNodesInTree.insert({NewFunction, Result.first});
     DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n');
     return false;
   }
@@ -1705,7 +1786,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
     if (OldF.getFunc()->getName() > NewFunction->getName()) {
       // Swap the two functions.
       Function *F = OldF.getFunc();
-      replaceFunctionInTree(Result.first, NewFunction);
+      replaceFunctionInTree(*Result.first, NewFunction);
       NewFunction = F;
       assert(OldF.getFunc() != F && "Must have swapped the functions.");
     }
@@ -1724,18 +1805,13 @@ bool MergeFunctions::insert(Function *NewFunction) {
 // Remove a function from FnTree. If it was already in FnTree, add
 // it to Deferred so that we'll look at it in the next round.
 void MergeFunctions::remove(Function *F) {
-  // We need to make sure we remove F, not a function "equal" to F per the
-  // function equality comparator.
-  FnTreeType::iterator found = FnTree.find(FunctionNode(F));
-  size_t Erased = 0;
-  if (found != FnTree.end() && found->getFunc() == F) {
-    Erased = 1;
-    FnTree.erase(found);
-  }
-
-  if (Erased) {
-    DEBUG(dbgs() << "Removed " << F->getName()
-                 << " from set and deferred it.\n");
+  auto I = FNodesInTree.find(F);
+  if (I != FNodesInTree.end()) {
+    DEBUG(dbgs() << "Deferred " << F->getName()<< ".\n");
+    FnTree.erase(I->second);
+    // I->second has been invalidated, remove it from the FNodesInTree map to
+    // preserve the invariant.
+    FNodesInTree.erase(I);
     Deferred.emplace_back(F);
   }
 }