#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>
+
using namespace llvm;
#define DEBUG_TYPE "mergefunc"
NextNumber++;
return MapIter->second;
}
+ void clear() {
+ GlobalNumbers.clear();
+ }
};
/// FunctionComparator - Compares two functions to determine whether or not
int cmpTypes(Type *TyL, Type *TyR) const;
int cmpNumbers(uint64_t L, uint64_t R) const;
-
int cmpAPInts(const APInt &L, const APInt &R) const;
int cmpAPFloats(const APFloat &L, const APFloat &R) const;
int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
int cmpMem(StringRef L, StringRef R) const;
int cmpAttrs(const AttributeSet L, const AttributeSet R) const;
+ int cmpRangeMetadata(const MDNode* L, const MDNode* R) const;
// The two functions undergoing comparison.
const Function *FnL, *FnR;
F = G;
}
- void release() { F = 0; }
+ void release() { F = nullptr; }
};
-}
+} // end anonymous namespace
int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const {
if (L < R) return -1;
}
int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const {
- // TODO: This correctly handles all existing fltSemantics, because they all
- // have different precisions. This isn't very robust, however, if new types
- // with different exponent ranges are introduced.
+ // Floats are ordered first by semantics (i.e. float, double, half, etc.),
+ // then by value interpreted as a bitstring (aka APInt).
const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics();
if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL),
APFloat::semanticsPrecision(SR)))
return Res;
+ if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL),
+ APFloat::semanticsMaxExponent(SR)))
+ return Res;
+ if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL),
+ APFloat::semanticsMinExponent(SR)))
+ return Res;
+ if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL),
+ APFloat::semanticsSizeInBits(SR)))
+ return Res;
return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt());
}
return 0;
}
+int FunctionComparator::cmpRangeMetadata(const MDNode* L,
+ const MDNode* R) const {
+ if (L == R)
+ return 0;
+ if (!L)
+ return -1;
+ if (!R)
+ return 1;
+ // Range metadata is a sequence of numbers. Make sure they are the same
+ // sequence.
+ // TODO: Note that as this is metadata, it is possible to drop and/or merge
+ // this data when considering functions to merge. Thus this comparison would
+ // return 0 (i.e. equivalent), but merging would become more complicated
+ // because the ranges would need to be unioned. It is not likely that
+ // functions differ ONLY in this metadata if they are actually the same
+ // function semantically.
+ if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
+ return Res;
+ for (size_t I = 0; I < L->getNumOperands(); ++I) {
+ ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I));
+ ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I));
+ if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue()))
+ return Res;
+ }
+ return 0;
+}
+
/// Constants comparison:
/// 1. Check whether type of L constant could be losslessly bitcasted to R
/// type.
return Res;
if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) {
- const auto *SeqR = dyn_cast<ConstantDataSequential>(R);
+ const auto *SeqR = cast<ConstantDataSequential>(R);
// This handles ConstantDataArray and ConstantDataVector. Note that we
// compare the two raw data arrays, which might differ depending on the host
// endianness. This isn't a problem though, because the endiness of a module
}
switch (L->getValueID()) {
- case Value::UndefValueVal: return TypesRes;
+ case Value::UndefValueVal:
+ case Value::ConstantTokenNoneVal:
+ return TypesRes;
case Value::ConstantIntVal: {
const APInt &LInt = cast<ConstantInt>(L)->getValue();
const APInt &RInt = cast<ConstantInt>(R)->getValue();
return 0;
}
case Value::BlockAddressVal: {
- // FIXME: This still uses a pointer comparison. It isn't clear how to remove
- // this. This only affects programs which take BlockAddresses and store them
- // as constants, which is limited to interepreters, etc.
- return cmpNumbers((uint64_t)L, (uint64_t)R);
+ const BlockAddress *LBA = cast<BlockAddress>(L);
+ const BlockAddress *RBA = cast<BlockAddress>(R);
+ if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction()))
+ return Res;
+ if (LBA->getFunction() == RBA->getFunction()) {
+ // They are BBs in the same function. Order by which comes first in the
+ // BB order of the function. This order is deterministic.
+ Function* F = LBA->getFunction();
+ BasicBlock *LBB = LBA->getBasicBlock();
+ BasicBlock *RBB = RBA->getBasicBlock();
+ if (LBB == RBB)
+ return 0;
+ for(BasicBlock &BB : F->getBasicBlockList()) {
+ if (&BB == LBB) {
+ assert(&BB != RBB);
+ return -1;
+ }
+ if (&BB == RBB)
+ return 1;
+ }
+ llvm_unreachable("Basic Block Address does not point to a basic block in "
+ "its function.");
+ return -1;
+ } else {
+ // cmpValues said the functions are the same. So because they aren't
+ // literally the same pointer, they must respectively be the left and
+ // right functions.
+ assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR);
+ // cmpValues will tell us if these are equivalent BasicBlocks, in the
+ // context of their respective functions.
+ return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock());
+ }
}
default: // Unknown constant, abort.
DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n");
/// defines total ordering among the types set.
/// See method declaration comments for more details.
int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const {
-
PointerType *PTyL = dyn_cast<PointerType>(TyL);
PointerType *PTyR = dyn_cast<PointerType>(TyR);
if (int Res =
cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope()))
return Res;
- return cmpNumbers((uint64_t)LI->getMetadata(LLVMContext::MD_range),
- (uint64_t)cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
+ return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range),
+ cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range));
}
if (const StoreInst *SI = dyn_cast<StoreInst>(L)) {
if (int Res =
if (int Res =
cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes()))
return Res;
- return cmpNumbers(
- (uint64_t)CI->getMetadata(LLVMContext::MD_range),
- (uint64_t)cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
+ return cmpRangeMetadata(
+ CI->getMetadata(LLVMContext::MD_range),
+ cast<CallInst>(R)->getMetadata(LLVMContext::MD_range));
}
if (const InvokeInst *CI = dyn_cast<InvokeInst>(L)) {
if (int Res = cmpNumbers(CI->getCallingConv(),
if (int Res =
cmpAttrs(CI->getAttributes(), cast<InvokeInst>(R)->getAttributes()))
return Res;
- return cmpNumbers(
- (uint64_t)CI->getMetadata(LLVMContext::MD_range),
- (uint64_t)cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
+ return cmpRangeMetadata(
+ CI->getMetadata(LLVMContext::MD_range),
+ cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range));
}
if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
ArrayRef<unsigned> LIndices = IVI->getIndices();
if (GEPL->accumulateConstantOffset(DL, OffsetL) &&
GEPR->accumulateConstantOffset(DL, OffsetR))
return cmpAPInts(OffsetL, OffsetR);
- if (int Res = cmpTypes(GEPL->getPointerOperand()->getType(),
- GEPR->getPointerOperand()->getType()))
+ if (int Res = cmpTypes(GEPL->getSourceElementType(),
+ GEPR->getSourceElementType()))
return Res;
if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands()))
BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end();
do {
- if (int Res = cmpValues(InstL, InstR))
+ if (int Res = cmpValues(&*InstL, &*InstR))
return Res;
const GetElementPtrInst *GEPL = dyn_cast<GetElementPtrInst>(InstL);
if (int Res = cmpGEPs(GEPL, GEPR))
return Res;
} else {
- if (int Res = cmpOperations(InstL, InstR))
+ if (int Res = cmpOperations(&*InstL, &*InstR))
return Res;
assert(InstL->getNumOperands() == InstR->getNumOperands());
Value *OpR = InstR->getOperand(i);
if (int Res = cmpValues(OpL, OpR))
return Res;
- if (int Res = cmpNumbers(OpL->getValueID(), OpR->getValueID()))
- return Res;
- // TODO: Already checked in cmpOperation
- if (int Res = cmpTypes(OpL->getType(), OpR->getType()))
- return Res;
+ // cmpValues should ensure this is true.
+ assert(cmpTypes(OpL->getType(), OpR->getType()) == 0);
}
}
// Test whether the two functions have equivalent behaviour.
int FunctionComparator::compare() {
-
sn_mapL.clear();
sn_mapR.clear();
ArgRI = FnR->arg_begin(),
ArgLE = FnL->arg_end();
ArgLI != ArgLE; ++ArgLI, ++ArgRI) {
- if (cmpValues(ArgLI, ArgRI) != 0)
+ if (cmpValues(&*ArgLI, &*ArgRI) != 0)
llvm_unreachable("Arguments repeat!");
}
public:
static char ID;
MergeFunctions()
- : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)),
+ : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)), FNodesInTree(),
HasGlobalAliases(false) {
initializeMergeFunctionsPass(*PassRegistry::getPassRegistry());
}
void writeAlias(Function *F, Function *G);
/// Replace function F with function G in the function tree.
- void replaceFunctionInTree(FnTreeType::iterator &IterToF, Function *G);
+ void replaceFunctionInTree(const FunctionNode &FN, Function *G);
/// The set of all distinct functions. Use the insert() and remove() methods
- /// to modify it.
+ /// to modify it. The map allows efficient lookup and deferring of Functions.
FnTreeType FnTree;
+ // Map functions to the iterators of the FunctionNode which contains them
+ // in the FnTree. This must be updated carefully whenever the FnTree is
+ // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid
+ // dangling iterators into FnTree. The invariant that preserves this is that
+ // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree.
+ ValueMap<Function*, FnTreeType::iterator> FNodesInTree;
/// Whether or not the target supports global aliases.
bool HasGlobalAliases;
};
-} // end anonymous namespace
+} // end anonymous namespace
char MergeFunctions::ID = 0;
INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false)
} while (!Deferred.empty());
FnTree.clear();
+ GlobalNumbers.clear();
return Changed;
}
CallSite CS(U->getUser());
if (CS && CS.isCallee(U)) {
// Transfer the called function's attributes to the call site. Due to the
- // bitcast we will 'loose' ABI changing attributes because the 'called
+ // bitcast we will 'lose' ABI changing attributes because the 'called
// function' is no longer a Function* but the bitcast. Code that looks up
// the attributes from the called function will fail.
+
+ // FIXME: This is not actually true, at least not anymore. The callsite
+ // will always have the same ABI affecting attributes as the callee,
+ // because otherwise the original input has UB. Note that Old and New
+ // always have matching ABI, so no attributes need to be changed.
+ // Transferring other attributes may help other optimizations, but that
+ // should be done uniformly and not in this ad-hoc way.
auto &Context = New->getContext();
auto NewFuncAttrs = New->getAttributes();
auto CallSiteAttrs = CS.getAttributes();
SmallVector<Value *, 16> Args;
unsigned i = 0;
FunctionType *FFTy = F->getFunctionType();
- for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end();
- AI != AE; ++AI) {
- Args.push_back(createCast(Builder, (Value*)AI, FFTy->getParamType(i)));
+ for (Argument & AI : NewG->args()) {
+ Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
++i;
}
CallInst *CI = Builder.CreateCall(F, Args);
CI->setTailCall();
CI->setCallingConv(F->getCallingConv());
+ CI->setAttributes(F->getAttributes());
if (NewG->getReturnType()->isVoidTy()) {
Builder.CreateRetVoid();
} else {
// Replace G with an alias to F and delete G.
void MergeFunctions::writeAlias(Function *F, Function *G) {
- PointerType *PTy = G->getType();
- auto *GA = GlobalAlias::create(PTy, G->getLinkage(), "", F);
+ auto *GA = GlobalAlias::create(G->getLinkage(), "", F);
F->setAlignment(std::max(F->getAlignment(), G->getAlignment()));
GA->takeName(G);
GA->setVisibility(G->getVisibility());
++NumFunctionsMerged;
}
-/// Replace function F for function G in the map.
-void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF,
+/// Replace function F by function G.
+void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN,
Function *G) {
- Function *F = IterToF->getFunc();
-
- // A total order is already guaranteed otherwise because we process strong
- // functions before weak functions.
- assert(((F->mayBeOverridden() && G->mayBeOverridden()) ||
- (!F->mayBeOverridden() && !G->mayBeOverridden())) &&
- "Only change functions if both are strong or both are weak");
- (void)F;
+ Function *F = FN.getFunc();
assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 &&
"The two functions must be equal");
-
- IterToF->replaceBy(G);
+
+ auto I = FNodesInTree.find(F);
+ assert(I != FNodesInTree.end() && "F should be in FNodesInTree");
+ assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G");
+
+ FnTreeType::iterator IterToFNInFnTree = I->second;
+ assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree.");
+ // Remove F -> FN and insert G -> FN
+ FNodesInTree.erase(I);
+ FNodesInTree.insert({G, IterToFNInFnTree});
+ // Replace F with G in FN, which is stored inside the FnTree.
+ FN.replaceBy(G);
}
// Insert a ComparableFunction into the FnTree, or merge it away if equal to one
FnTree.insert(FunctionNode(NewFunction));
if (Result.second) {
+ assert(FNodesInTree.count(NewFunction) == 0);
+ FNodesInTree.insert({NewFunction, Result.first});
DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n');
return false;
}
if (OldF.getFunc()->getName() > NewFunction->getName()) {
// Swap the two functions.
Function *F = OldF.getFunc();
- replaceFunctionInTree(Result.first, NewFunction);
+ replaceFunctionInTree(*Result.first, NewFunction);
NewFunction = F;
assert(OldF.getFunc() != F && "Must have swapped the functions.");
}
// Remove a function from FnTree. If it was already in FnTree, add
// it to Deferred so that we'll look at it in the next round.
void MergeFunctions::remove(Function *F) {
- // We need to make sure we remove F, not a function "equal" to F per the
- // function equality comparator.
- FnTreeType::iterator found = FnTree.find(FunctionNode(F));
- size_t Erased = 0;
- if (found != FnTree.end() && found->getFunc() == F) {
- Erased = 1;
- FnTree.erase(found);
- }
-
- if (Erased) {
- DEBUG(dbgs() << "Removed " << F->getName()
- << " from set and deferred it.\n");
+ auto I = FNodesInTree.find(F);
+ if (I != FNodesInTree.end()) {
+ DEBUG(dbgs() << "Deferred " << F->getName()<< ".\n");
+ FnTree.erase(I->second);
+ // I->second has been invalidated, remove it from the FNodesInTree map to
+ // preserve the invariant.
+ FNodesInTree.erase(I);
Deferred.emplace_back(F);
}
}