[IR] Add support for empty tokens
[oota-llvm.git] / lib / Transforms / IPO / MergeFunctions.cpp
index a31a08039796300ce075e17913c02e4c5e8b66b1..bb75ab6ece16aa0b3c3e94025dc8e0d949a65b61 100644 (file)
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include <vector>
+
 using namespace llvm;
 
 #define DEBUG_TYPE "mergefunc"
@@ -164,6 +165,9 @@ class GlobalNumberState {
         NextNumber++;
       return MapIter->second;
     }
+    void clear() {
+      GlobalNumbers.clear();
+    }
 };
 
 /// FunctionComparator - Compares two functions to determine whether or not
@@ -397,12 +401,12 @@ private:
   int cmpTypes(Type *TyL, Type *TyR) const;
 
   int cmpNumbers(uint64_t L, uint64_t R) const;
-
   int cmpAPInts(const APInt &L, const APInt &R) const;
   int cmpAPFloats(const APFloat &L, const APFloat &R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpMem(StringRef L, StringRef R) const;
   int cmpAttrs(const AttributeSet L, const AttributeSet R) const;
+  int cmpRangeMetadata(const MDNode* L, const MDNode* R) const;
 
   // The two functions undergoing comparison.
   const Function *FnL, *FnR;
@@ -462,9 +466,9 @@ public:
     F = G;
   }
 
-  void release() { F = 0; }
+  void release() { F = nullptr; }
 };
-}
+} // end anonymous namespace
 
 int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const {
   if (L < R) return -1;
@@ -481,13 +485,21 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
 }
 
 int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
-  // TODO: This correctly handles all existing fltSemantics, because they all
-  // have different precisions. This isn't very robust, however, if new types
-  // with different exponent ranges are introduced.
+  // Floats are ordered first by semantics (i.e. float, double, half, etc.),
+  // then by value interpreted as a bitstring (aka APInt).
   const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
   if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
                            APFloat::semanticsPrecision(SR)))
     return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL),
+                           APFloat::semanticsMaxExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL),
+                           APFloat::semanticsMinExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL),
+                           APFloat::semanticsSizeInBits(SR)))
+    return Res;
   return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
 }
 
@@ -525,6 +537,33 @@ int FunctionComparator::cmpAttrs(const AttributeSet L,
   return 0;
 }
 
+int FunctionComparator::cmpRangeMetadata(const MDNode* L,
+                                         const MDNode* R) const {
+  if (L == R)
+    return 0;
+  if (!L)
+    return -1;
+  if (!R)
+    return 1;
+  // Range metadata is a sequence of numbers. Make sure they are the same
+  // sequence. 
+  // TODO: Note that as this is metadata, it is possible to drop and/or merge
+  // this data when considering functions to merge. Thus this comparison would
+  // return 0 (i.e. equivalent), but merging would become more complicated
+  // because the ranges would need to be unioned. It is not likely that
+  // functions differ ONLY in this metadata if they are actually the same
+  // function semantically.
+  if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
+    return Res;
+  for (size_t I = 0; I < L->getNumOperands(); ++I) {
+    ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
+    ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
+    if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
+      return Res;
+  }
+  return 0;
+}
+
 /// Constants comparison:
 /// 1. Check whether type of L constant could be losslessly bitcasted to R
 /// type.
@@ -607,7 +646,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return Res;
 
   if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
-    const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
+    const auto *SeqR = cast<ConstantDataSequential>(R);
     // This handles ConstantDataArray and ConstantDataVector. Note that we
     // compare the two raw data arrays, which might differ depending on the host
     // endianness. This isn't a problem though, because the endiness of a module
@@ -617,7 +656,9 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
   }
 
   switch (L->getValueID()) {
-  case Value::UndefValueVal: return TypesRes;
+  case Value::UndefValueVal:
+  case Value::ConstantTokenNoneVal:
+    return TypesRes;
   case Value::ConstantIntVal: {
     const APInt &LInt = cast<ConstantInt>(L)->getValue();
     const APInt &RInt = cast<ConstantInt>(R)->getValue();
@@ -685,10 +726,38 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return 0;
   }
   case Value::BlockAddressVal: {
-    // FIXME: This still uses a pointer comparison. It isn't clear how to remove
-    // this. This only affects programs which take BlockAddresses and store them
-    // as constants, which is limited to interepreters, etc.
-    return cmpNumbers((uint64_t)L, (uint64_t)R);
+    const BlockAddress *LBA = cast<BlockAddress>(L);
+    const BlockAddress *RBA = cast<BlockAddress>(R);
+    if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction()))
+      return Res;
+    if (LBA->getFunction() == RBA->getFunction()) {
+      // They are BBs in the same function. Order by which comes first in the
+      // BB order of the function. This order is deterministic.
+      Function* F = LBA->getFunction();
+      BasicBlock *LBB = LBA->getBasicBlock();
+      BasicBlock *RBB = RBA->getBasicBlock();
+      if (LBB == RBB)
+        return 0;
+      for(BasicBlock &BB : F->getBasicBlockList()) {
+        if (&BB == LBB) {
+          assert(&BB != RBB);
+          return -1;
+        }
+        if (&BB == RBB)
+          return 1;
+      }
+      llvm_unreachable("Basic Block Address does not point to a basic block in "
+                       "its function.");
+      return -1;
+    } else {
+      // cmpValues said the functions are the same. So because they aren't
+      // literally the same pointer, they must respectively be the left and
+      // right functions.
+      assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR);
+      // cmpValues will tell us if these are equivalent BasicBlocks, in the
+      // context of their respective functions.
+      return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock());
+    }
   }
   default: // Unknown constant, abort.
     DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
@@ -705,7 +774,6 @@ int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue* R) {
 /// defines total ordering among the types set.
 /// See method declaration comments for more details.
 int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
-
   PointerType *PTyL = dyn_cast<PointerType>(TyL);
   PointerType *PTyR = dyn_cast<PointerType>(TyR);
 
@@ -849,8 +917,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope()))
       return Res;
-    return cmpNumbers((uint64_t)LI->getMetadata(LLVMContext::MD_range),
-                      (uint64_t)cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range),
+        cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
     if (int Res =
@@ -873,9 +941,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InvokeInst *CI = dyn_cast<InvokeInst>(L)) {
     if (int Res = cmpNumbers(CI->getCallingConv(),
@@ -884,9 +952,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<InvokeInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -966,8 +1034,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL,
   if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
       GEPR->accumulateConstantOffset(DL, OffsetR))
     return cmpAPInts(OffsetL, OffsetR);
-  if (int Res = cmpTypes(GEPL->getPointerOperand()->getType(),
-                         GEPR->getPointerOperand()->getType()))
+  if (int Res = cmpTypes(GEPL->getSourceElementType(),
+                         GEPR->getSourceElementType()))
     return Res;
 
   if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands()))
@@ -1055,7 +1123,7 @@ int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
   BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end();
 
   do {
-    if (int Res = cmpValues(InstL, InstR))
+    if (int Res = cmpValues(&*InstL, &*InstR))
       return Res;
 
     const GetElementPtrInst *GEPL = dyn_cast<GetElementPtrInst>(InstL);
@@ -1073,7 +1141,7 @@ int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
       if (int Res = cmpGEPs(GEPL, GEPR))
         return Res;
     } else {
-      if (int Res = cmpOperations(InstL, InstR))
+      if (int Res = cmpOperations(&*InstL, &*InstR))
         return Res;
       assert(InstL->getNumOperands() == InstR->getNumOperands());
 
@@ -1082,11 +1150,8 @@ int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
         Value *OpR = InstR->getOperand(i);
         if (int Res = cmpValues(OpL, OpR))
           return Res;
-        if (int Res = cmpNumbers(OpL->getValueID(), OpR->getValueID()))
-          return Res;
-        // TODO: Already checked in cmpOperation
-        if (int Res = cmpTypes(OpL->getType(), OpR->getType()))
-          return Res;
+        // cmpValues should ensure this is true.
+        assert(cmpTypes(OpL->getType(), OpR->getType()) == 0);
       }
     }
 
@@ -1102,7 +1167,6 @@ int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL,
 
 // Test whether the two functions have equivalent behaviour.
 int FunctionComparator::compare() {
-
   sn_mapL.clear();
   sn_mapR.clear();
 
@@ -1145,7 +1209,7 @@ int FunctionComparator::compare() {
                                     ArgRI = FnR->arg_begin(),
                                     ArgLE = FnL->arg_end();
        ArgLI != ArgLE; ++ArgLI, ++ArgRI) {
-    if (cmpValues(ArgLI, ArgRI) != 0)
+    if (cmpValues(&*ArgLI, &*ArgRI) != 0)
       llvm_unreachable("Arguments repeat!");
   }
 
@@ -1256,7 +1320,7 @@ class MergeFunctions : public ModulePass {
 public:
   static char ID;
   MergeFunctions()
-    : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)),
+    : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)), FNodesInTree(),
       HasGlobalAliases(false) {
     initializeMergeFunctionsPass(*PassRegistry::getPassRegistry());
   }
@@ -1322,17 +1386,23 @@ private:
   void writeAlias(Function *F, Function *G);
 
   /// Replace function F with function G in the function tree.
-  void replaceFunctionInTree(FnTreeType::iterator &IterToF, Function *G);
+  void replaceFunctionInTree(const FunctionNode &FN, Function *G);
 
   /// The set of all distinct functions. Use the insert() and remove() methods
-  /// to modify it.
+  /// to modify it. The map allows efficient lookup and deferring of Functions.
   FnTreeType FnTree;
+  // Map functions to the iterators of the FunctionNode which contains them
+  // in the FnTree. This must be updated carefully whenever the FnTree is
+  // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
+  // dangling iterators into FnTree. The invariant that preserves this is that
+  // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
+  ValueMap<Function*, FnTreeType::iterator> FNodesInTree;
 
   /// Whether or not the target supports global aliases.
   bool HasGlobalAliases;
 };
 
-}  // end anonymous namespace
+} // end anonymous namespace
 
 char MergeFunctions::ID = 0;
 INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false)
@@ -1481,6 +1551,7 @@ bool MergeFunctions::runOnModule(Module &M) {
   } while (!Deferred.empty());
 
   FnTree.clear();
+  GlobalNumbers.clear();
 
   return Changed;
 }
@@ -1494,9 +1565,16 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
     CallSite CS(U->getUser());
     if (CS && CS.isCallee(U)) {
       // Transfer the called function's attributes to the call site. Due to the
-      // bitcast we will 'loose' ABI changing attributes because the 'called
+      // bitcast we will 'lose' ABI changing attributes because the 'called
       // function' is no longer a Function* but the bitcast. Code that looks up
       // the attributes from the called function will fail.
+
+      // FIXME: This is not actually true, at least not anymore. The callsite
+      // will always have the same ABI affecting attributes as the callee,
+      // because otherwise the original input has UB. Note that Old and New
+      // always have matching ABI, so no attributes need to be changed.
+      // Transferring other attributes may help other optimizations, but that
+      // should be done uniformly and not in this ad-hoc way.
       auto &Context = New->getContext();
       auto NewFuncAttrs = New->getAttributes();
       auto CallSiteAttrs = CS.getAttributes();
@@ -1582,15 +1660,15 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   SmallVector<Value *, 16> Args;
   unsigned i = 0;
   FunctionType *FFTy = F->getFunctionType();
-  for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end();
-       AI != AE; ++AI) {
-    Args.push_back(createCast(Builder, (Value*)AI, FFTy->getParamType(i)));
+  for (Argument & AI : NewG->args()) {
+    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
     ++i;
   }
 
   CallInst *CI = Builder.CreateCall(F, Args);
   CI->setTailCall();
   CI->setCallingConv(F->getCallingConv());
+  CI->setAttributes(F->getAttributes());
   if (NewG->getReturnType()->isVoidTy()) {
     Builder.CreateRetVoid();
   } else {
@@ -1609,8 +1687,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
 
 // Replace G with an alias to F and delete G.
 void MergeFunctions::writeAlias(Function *F, Function *G) {
-  PointerType *PTy = G->getType();
-  auto *GA = GlobalAlias::create(PTy, G->getLinkage(), "", F);
+  auto *GA = GlobalAlias::create(G->getLinkage(), "", F);
   F->setAlignment(std::max(F->getAlignment(), G->getAlignment()));
   GA->takeName(G);
   GA->setVisibility(G->getVisibility());
@@ -1655,21 +1732,24 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) {
   ++NumFunctionsMerged;
 }
 
-/// Replace function F for function G in the map.
-void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
+/// Replace function F by function G.
+void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
                                            Function *G) {
-  Function *F = IterToF->getFunc();
-
-  // A total order is already guaranteed otherwise because we process strong
-  // functions before weak functions.
-  assert(((F->mayBeOverridden() && G->mayBeOverridden()) ||
-          (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
-         "Only change functions if both are strong or both are weak");
-  (void)F;
+  Function *F = FN.getFunc();
   assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
          "The two functions must be equal");
-
-  IterToF->replaceBy(G);
+  
+  auto I = FNodesInTree.find(F);
+  assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
+  assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
+  
+  FnTreeType::iterator IterToFNInFnTree = I->second;
+  assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
+  // Remove F -> FN and insert G -> FN
+  FNodesInTree.erase(I);
+  FNodesInTree.insert({G, IterToFNInFnTree});
+  // Replace F with G in FN, which is stored inside the FnTree.
+  FN.replaceBy(G);
 }
 
 // Insert a ComparableFunction into the FnTree, or merge it away if equal to one
@@ -1679,6 +1759,8 @@ bool MergeFunctions::insert(Function *NewFunction) {
       FnTree.insert(FunctionNode(NewFunction));
 
   if (Result.second) {
+    assert(FNodesInTree.count(NewFunction) == 0);
+    FNodesInTree.insert({NewFunction, Result.first});
     DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n');
     return false;
   }
@@ -1708,7 +1790,7 @@ bool MergeFunctions::insert(Function *NewFunction) {
     if (OldF.getFunc()->getName() > NewFunction->getName()) {
       // Swap the two functions.
       Function *F = OldF.getFunc();
-      replaceFunctionInTree(Result.first, NewFunction);
+      replaceFunctionInTree(*Result.first, NewFunction);
       NewFunction = F;
       assert(OldF.getFunc() != F && "Must have swapped the functions.");
     }
@@ -1727,18 +1809,13 @@ bool MergeFunctions::insert(Function *NewFunction) {
 // Remove a function from FnTree. If it was already in FnTree, add
 // it to Deferred so that we'll look at it in the next round.
 void MergeFunctions::remove(Function *F) {
-  // We need to make sure we remove F, not a function "equal" to F per the
-  // function equality comparator.
-  FnTreeType::iterator found = FnTree.find(FunctionNode(F));
-  size_t Erased = 0;
-  if (found != FnTree.end() && found->getFunc() == F) {
-    Erased = 1;
-    FnTree.erase(found);
-  }
-
-  if (Erased) {
-    DEBUG(dbgs() << "Removed " << F->getName()
-                 << " from set and deferred it.\n");
+  auto I = FNodesInTree.find(F);
+  if (I != FNodesInTree.end()) {
+    DEBUG(dbgs() << "Deferred " << F->getName()<< ".\n");
+    FnTree.erase(I->second);
+    // I->second has been invalidated, remove it from the FNodesInTree map to
+    // preserve the invariant.
+    FNodesInTree.erase(I);
     Deferred.emplace_back(F);
   }
 }