Remove Merge Functions pointer comparisons
[oota-llvm.git] / lib / Transforms / IPO / MergeFunctions.cpp
index 9ffd6534a65a4511a84919878896e2f37755dea0..5b7198535a42aa08d33f613987169833d64e0541 100644 (file)
@@ -397,12 +397,12 @@ private:
   int cmpTypes(Type *TyL, Type *TyR) const;
 
   int cmpNumbers(uint64_t L, uint64_t R) const;
-
   int cmpAPInts(const APInt &L, const APInt &R) const;
   int cmpAPFloats(const APFloat &L, const APFloat &R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpMem(StringRef L, StringRef R) const;
   int cmpAttrs(const AttributeSet L, const AttributeSet R) const;
+  int cmpRangeMetadata(const MDNode* L, const MDNode* R) const;
 
   // The two functions undergoing comparison.
   const Function *FnL, *FnR;
@@ -481,13 +481,21 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const {
 }
 
 int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
-  // TODO: This correctly handles all existing fltSemantics, because they all
-  // have different precisions. This isn't very robust, however, if new types
-  // with different exponent ranges are introduced.
+  // Floats are ordered first by semantics (i.e. float, double, half, etc.),
+  // then by value interpreted as a bitstring (aka APInt).
   const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
   if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
                            APFloat::semanticsPrecision(SR)))
     return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL),
+                           APFloat::semanticsMaxExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL),
+                           APFloat::semanticsMinExponent(SR)))
+    return Res;
+  if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL),
+                           APFloat::semanticsSizeInBits(SR)))
+    return Res;
   return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
 }
 
@@ -524,6 +532,32 @@ int FunctionComparator::cmpAttrs(const AttributeSet L,
   }
   return 0;
 }
+int FunctionComparator::cmpRangeMetadata(const MDNode* L,
+                                         const MDNode* R) const {
+  if (L == R)
+    return 0;
+  if (!L)
+    return -1;
+  if (!R)
+    return 1;
+  // Range metadata is a sequence of numbers. Make sure they are the same
+  // sequence. 
+  // TODO: Note that as this is metadata, it is possible to drop and/or merge
+  // this data when considering functions to merge. Thus this comparison would
+  // return 0 (i.e. equivalent), but merging would become more complicated
+  // because the ranges would need to be unioned. It is not likely that
+  // functions differ ONLY in this metadata if they are actually the same
+  // function semantically.
+  if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
+    return Res;
+  for (size_t I = 0; I < L->getNumOperands(); ++I) {
+    ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
+    ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
+    if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
+      return Res;
+  }
+  return 0;
+}
 
 /// Constants comparison:
 /// 1. Check whether type of L constant could be losslessly bitcasted to R
@@ -607,7 +641,7 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return Res;
 
   if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
-    const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
+    const auto *SeqR = cast<ConstantDataSequential>(R);
     // This handles ConstantDataArray and ConstantDataVector. Note that we
     // compare the two raw data arrays, which might differ depending on the host
     // endianness. This isn't a problem though, because the endiness of a module
@@ -685,10 +719,38 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) {
     return 0;
   }
   case Value::BlockAddressVal: {
-    // FIXME: This still uses a pointer comparison. It isn't clear how to remove
-    // this. This only affects programs which take BlockAddresses and store them
-    // as constants, which is limited to interepreters, etc.
-    return cmpNumbers((uint64_t)L, (uint64_t)R);
+    const BlockAddress *LBA = cast<BlockAddress>(L);
+    const BlockAddress *RBA = cast<BlockAddress>(R);
+    if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction()))
+      return Res;
+    if (LBA->getFunction() == RBA->getFunction()) {
+      // They are BBs in the same function. Order by which comes first in the
+      // BB order of the function. This order is deterministic.
+      Function* F = LBA->getFunction();
+      BasicBlock *LBB = LBA->getBasicBlock();
+      BasicBlock *RBB = RBA->getBasicBlock();
+      if (LBB == RBB)
+        return 0;
+      for(BasicBlock &BB : F->getBasicBlockList()) {
+        if (&BB == LBB) {
+          assert(&BB != RBB);
+          return -1;
+        }
+        if (&BB == RBB)
+          return 1;
+      }
+      llvm_unreachable("Basic Block Address does not point to a basic block in "
+                       "its function.");
+      return -1;
+    } else {
+      // cmpValues said the functions are the same. So because they aren't
+      // literally the same pointer, they must respectively be the left and
+      // right functions.
+      assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR);
+      // cmpValues will tell us if these are equivalent BasicBlocks, in the
+      // context of their respective functions.
+      return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock());
+    }
   }
   default: // Unknown constant, abort.
     DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
@@ -849,8 +911,8 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope()))
       return Res;
-    return cmpNumbers((uint64_t)LI->getMetadata(LLVMContext::MD_range),
-                      (uint64_t)cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range),
+        cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
     if (int Res =
@@ -873,9 +935,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InvokeInst *CI = dyn_cast<InvokeInst>(L)) {
     if (int Res = cmpNumbers(CI->getCallingConv(),
@@ -884,9 +946,9 @@ int FunctionComparator::cmpOperations(const Instruction *L,
     if (int Res =
             cmpAttrs(CI->getAttributes(), cast<InvokeInst>(R)->getAttributes()))
       return Res;
-    return cmpNumbers(
-        (uint64_t)CI->getMetadata(LLVMContext::MD_range),
-        (uint64_t)cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
+    return cmpRangeMetadata(
+        CI->getMetadata(LLVMContext::MD_range),
+        cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();