Use references now that it is natural to do so.
[oota-llvm.git] / lib / Linker / LinkModules.cpp
index 7ba622f6ee1d34460cabb9380a2d71e710f299bf..c2f472c7e33f5f39834ad5bdca25fd462caa39cc 100644 (file)
@@ -390,7 +390,8 @@ void LinkDiagnosticInfo::print(DiagnosticPrinter &DP) const { DP << Msg; }
 /// This is an implementation class for the LinkModules function, which is the
 /// entrypoint for this file.
 class ModuleLinker {
-  Module *DstM, *SrcM;
+  Module &DstM;
+  Module &SrcM;
 
   TypeMapTy TypeMap;
   ValueMaterializerTy ValMaterializer;
@@ -431,11 +432,11 @@ class ModuleLinker {
   bool HasError = false;
 
 public:
-  ModuleLinker(Module *dstM, Linker::IdentifiedStructTypeSet &Set, Module *srcM,
+  ModuleLinker(Module &DstM, Linker::IdentifiedStructTypeSet &Set, Module &SrcM,
                DiagnosticHandlerFunction DiagnosticHandler, unsigned Flags,
                const FunctionInfoIndex *Index = nullptr,
                Function *FuncToImport = nullptr)
-      : DstM(dstM), SrcM(srcM), TypeMap(Set), ValMaterializer(this),
+      : DstM(DstM), SrcM(SrcM), TypeMap(Set), ValMaterializer(this),
         DiagnosticHandler(DiagnosticHandler), Flags(Flags), ImportIndex(Index),
         ImportFunction(FuncToImport), HasExportedFunctions(false),
         DoneLinkingBodies(false) {
@@ -446,7 +447,7 @@ public:
     // backend compilation, and we need to see if it has functions that
     // may be exported to another backend compilation.
     if (ImportIndex && !ImportFunction)
-      HasExportedFunctions = ImportIndex->hasExportedFunctions(SrcM);
+      HasExportedFunctions = ImportIndex->hasExportedFunctions(&SrcM);
   }
 
   bool run();
@@ -485,7 +486,7 @@ private:
     DiagnosticHandler(LinkDiagnosticInfo(DS_Warning, Message));
   }
 
-  bool getComdatLeader(Module *M, StringRef ComdatName,
+  bool getComdatLeader(Module &M, StringRef ComdatName,
                        const GlobalVariable *&GVar);
   bool computeResultingSelectionKind(StringRef ComdatName,
                                      Comdat::SelectionKind Src,
@@ -508,7 +509,7 @@ private:
       return nullptr;
 
     // Otherwise see if we have a match in the destination module's symtab.
-    GlobalValue *DGV = DstM->getNamedValue(getName(SrcGV));
+    GlobalValue *DGV = DstM.getNamedValue(getName(SrcGV));
     if (!DGV)
       return nullptr;
 
@@ -793,7 +794,7 @@ ModuleLinker::copyGlobalVariableProto(TypeMapTy &TypeMap,
   // identical version of the symbol over in the dest module... the
   // initializer will be filled in later by LinkGlobalInits.
   GlobalVariable *NewDGV = new GlobalVariable(
-      *DstM, TypeMap.get(SGVar->getType()->getElementType()),
+      DstM, TypeMap.get(SGVar->getType()->getElementType()),
       SGVar->isConstant(), getLinkage(SGVar), /*init*/ nullptr, getName(SGVar),
       /*insertbefore*/ nullptr, SGVar->getThreadLocalMode(),
       SGVar->getType()->getAddressSpace());
@@ -808,7 +809,7 @@ Function *ModuleLinker::copyFunctionProto(TypeMapTy &TypeMap,
   // If there is no linkage to be performed or we are linking from the source,
   // bring SF over.
   return Function::Create(TypeMap.get(SF->getFunctionType()), getLinkage(SF),
-                          getName(SF), DstM);
+                          getName(SF), &DstM);
 }
 
 /// Set up prototypes for any aliases that come over from the source module.
@@ -842,7 +843,7 @@ GlobalValue *ModuleLinker::copyGlobalAliasProto(TypeMapTy &TypeMap,
   // bring over SGA.
   auto *Ty = TypeMap.get(SGA->getValueType());
   return GlobalAlias::create(Ty, SGA->getType()->getPointerAddressSpace(),
-                             getLinkage(SGA), getName(SGA), DstM);
+                             getLinkage(SGA), getName(SGA), &DstM);
 }
 
 static GlobalValue::VisibilityTypes
@@ -936,9 +937,9 @@ void ModuleLinker::materializeInitFor(GlobalValue *New, GlobalValue *Old) {
   linkGlobalValueBody(*Old);
 }
 
-bool ModuleLinker::getComdatLeader(Module *M, StringRef ComdatName,
+bool ModuleLinker::getComdatLeader(Module &M, StringRef ComdatName,
                                    const GlobalVariable *&GVar) {
-  const GlobalValue *GVal = M->getNamedValue(ComdatName);
+  const GlobalValue *GVal = M.getNamedValue(ComdatName);
   if (const auto *GA = dyn_cast_or_null<GlobalAlias>(GVal)) {
     GVal = GA->getBaseObject();
     if (!GVal)
@@ -997,8 +998,8 @@ bool ModuleLinker::computeResultingSelectionKind(StringRef ComdatName,
         getComdatLeader(SrcM, ComdatName, SrcGV))
       return true;
 
-    const DataLayout &DstDL = DstM->getDataLayout();
-    const DataLayout &SrcDL = SrcM->getDataLayout();
+    const DataLayout &DstDL = DstM.getDataLayout();
+    const DataLayout &SrcDL = SrcM.getDataLayout();
     uint64_t DstSize =
         DstDL.getTypeAllocSize(DstGV->getType()->getPointerElementType());
     uint64_t SrcSize =
@@ -1030,7 +1031,7 @@ bool ModuleLinker::getComdatResult(const Comdat *SrcC,
                                    bool &LinkFromSrc) {
   Comdat::SelectionKind SSK = SrcC->getSelectionKind();
   StringRef ComdatName = SrcC->getName();
-  Module::ComdatSymTabType &ComdatSymTab = DstM->getComdatSymbolTable();
+  Module::ComdatSymTabType &ComdatSymTab = DstM.getComdatSymbolTable();
   Module::ComdatSymTabType::iterator DstCI = ComdatSymTab.find(ComdatName);
 
   if (DstCI == ComdatSymTab.end()) {
@@ -1158,7 +1159,7 @@ bool ModuleLinker::shouldLinkFromSource(bool &LinkFromSrc,
 /// types 'Foo' but one got renamed when the module was loaded into the same
 /// LLVMContext.
 void ModuleLinker::computeTypeMapping() {
-  for (GlobalValue &SGV : SrcM->globals()) {
+  for (GlobalValue &SGV : SrcM.globals()) {
     GlobalValue *DGV = getLinkedToGlobal(&SGV);
     if (!DGV)
       continue;
@@ -1174,12 +1175,12 @@ void ModuleLinker::computeTypeMapping() {
     TypeMap.addTypeMapping(DAT->getElementType(), SAT->getElementType());
   }
 
-  for (GlobalValue &SGV : *SrcM) {
+  for (GlobalValue &SGV : SrcM) {
     if (GlobalValue *DGV = getLinkedToGlobal(&SGV))
       TypeMap.addTypeMapping(DGV->getType(), SGV.getType());
   }
 
-  for (GlobalValue &SGV : SrcM->aliases()) {
+  for (GlobalValue &SGV : SrcM.aliases()) {
     if (GlobalValue *DGV = getLinkedToGlobal(&SGV))
       TypeMap.addTypeMapping(DGV->getType(), SGV.getType());
   }
@@ -1188,7 +1189,7 @@ void ModuleLinker::computeTypeMapping() {
   // At this point, the destination module may have a type "%foo = { i32 }" for
   // example.  When the source module got loaded into the same LLVMContext, if
   // it had the same type, it would have been renamed to "%foo.42 = { i32 }".
-  std::vector<StructType *> Types = SrcM->getIdentifiedStructTypes();
+  std::vector<StructType *> Types = SrcM.getIdentifiedStructTypes();
   for (StructType *ST : Types) {
     if (!ST->hasName())
       continue;
@@ -1201,7 +1202,7 @@ void ModuleLinker::computeTypeMapping() {
       continue;
 
     // Check to see if the destination module has a struct with the prefix name.
-    StructType *DST = DstM->getTypeByName(ST->getName().substr(0, DotPos));
+    StructType *DST = DstM.getTypeByName(ST->getName().substr(0, DotPos));
     if (!DST)
       continue;
 
@@ -1275,10 +1276,10 @@ static void upgradeGlobalArray(GlobalVariable *GV) {
 
 void ModuleLinker::upgradeMismatchedGlobalArray(StringRef Name) {
   // Look for the global arrays.
-  auto *DstGV = dyn_cast_or_null<GlobalVariable>(DstM->getNamedValue(Name));
+  auto *DstGV = dyn_cast_or_null<GlobalVariable>(DstM.getNamedValue(Name));
   if (!DstGV)
     return;
-  auto *SrcGV = dyn_cast_or_null<GlobalVariable>(SrcM->getNamedValue(Name));
+  auto *SrcGV = dyn_cast_or_null<GlobalVariable>(SrcM.getNamedValue(Name));
   if (!SrcGV)
     return;
 
@@ -1380,7 +1381,7 @@ bool ModuleLinker::linkAppendingVarProto(GlobalVariable *DstGV,
 
   // Create the new global variable.
   GlobalVariable *NG = new GlobalVariable(
-      *DstM, NewType, SrcGV->isConstant(), SrcGV->getLinkage(),
+      DstM, NewType, SrcGV->isConstant(), SrcGV->getLinkage(),
       /*init*/ nullptr, /*name*/ "", DstGV, SrcGV->getThreadLocalMode(),
       SrcGV->getType()->getAddressSpace());
 
@@ -1430,7 +1431,7 @@ bool ModuleLinker::linkGlobalValueProto(GlobalValue *SGV) {
   if (const Comdat *SC = SGV->getComdat()) {
     Comdat::SelectionKind SK;
     std::tie(SK, LinkFromSrc) = ComdatsChosen[SC];
-    C = DstM->getOrInsertComdat(SC->getName());
+    C = DstM.getOrInsertComdat(SC->getName());
     C->setSelectionKind(SK);
     if (SGV->hasInternalLinkage())
       LinkFromSrc = true;
@@ -1606,12 +1607,12 @@ bool ModuleLinker::linkGlobalValueBody(GlobalValue &Src) {
 
 /// Insert all of the named MDNodes in Src into the Dest module.
 void ModuleLinker::linkNamedMDNodes() {
-  const NamedMDNode *SrcModFlags = SrcM->getModuleFlagsMetadata();
-  for (const NamedMDNode &NMD : SrcM->named_metadata()) {
+  const NamedMDNode *SrcModFlags = SrcM.getModuleFlagsMetadata();
+  for (const NamedMDNode &NMD : SrcM.named_metadata()) {
     // Don't link module flags here. Do them separately.
     if (&NMD == SrcModFlags)
       continue;
-    NamedMDNode *DestNMD = DstM->getOrInsertNamedMetadata(NMD.getName());
+    NamedMDNode *DestNMD = DstM.getOrInsertNamedMetadata(NMD.getName());
     // Add Src elements into Dest node.
     for (const MDNode *op : NMD.operands())
       DestNMD->addOperand(MapMetadata(
@@ -1623,12 +1624,12 @@ void ModuleLinker::linkNamedMDNodes() {
 /// Merge the linker flags in Src into the Dest module.
 bool ModuleLinker::linkModuleFlagsMetadata() {
   // If the source module has no module flags, we are done.
-  const NamedMDNode *SrcModFlags = SrcM->getModuleFlagsMetadata();
+  const NamedMDNode *SrcModFlags = SrcM.getModuleFlagsMetadata();
   if (!SrcModFlags) return false;
 
   // If the destination module doesn't have module flags yet, then just copy
   // over the source module's flags.
-  NamedMDNode *DstModFlags = DstM->getOrInsertModuleFlagsMetadata();
+  NamedMDNode *DstModFlags = DstM.getOrInsertModuleFlagsMetadata();
   if (DstModFlags->getNumOperands() == 0) {
     for (unsigned I = 0, E = SrcModFlags->getNumOperands(); I != E; ++I)
       DstModFlags->addOperand(SrcModFlags->getOperand(I));
@@ -1711,7 +1712,7 @@ bool ModuleLinker::linkModuleFlagsMetadata() {
 
     auto replaceDstValue = [&](MDNode *New) {
       Metadata *FlagOps[] = {DstOp->getOperand(0), ID, New};
-      MDNode *Flag = MDNode::get(DstM->getContext(), FlagOps);
+      MDNode *Flag = MDNode::get(DstM.getContext(), FlagOps);
       DstModFlags->setOperand(DstIndex, Flag);
       Flags[ID].first = Flag;
     };
@@ -1744,7 +1745,7 @@ bool ModuleLinker::linkModuleFlagsMetadata() {
       MDs.append(DstValue->op_begin(), DstValue->op_end());
       MDs.append(SrcValue->op_begin(), SrcValue->op_end());
 
-      replaceDstValue(MDNode::get(DstM->getContext(), MDs));
+      replaceDstValue(MDNode::get(DstM.getContext(), MDs));
       break;
     }
     case Module::AppendUnique: {
@@ -1754,7 +1755,7 @@ bool ModuleLinker::linkModuleFlagsMetadata() {
       Elts.insert(DstValue->op_begin(), DstValue->op_end());
       Elts.insert(SrcValue->op_begin(), SrcValue->op_end());
 
-      replaceDstValue(MDNode::get(DstM->getContext(),
+      replaceDstValue(MDNode::get(DstM.getContext(),
                                   makeArrayRef(Elts.begin(), Elts.end())));
       break;
     }
@@ -1833,51 +1834,47 @@ bool ModuleLinker::linkIfNeeded(GlobalValue &GV) {
 }
 
 bool ModuleLinker::run() {
-  assert(DstM && "Null destination module");
-  assert(SrcM && "Null source module");
-
   // Inherit the target data from the source module if the destination module
   // doesn't have one already.
-  if (DstM->getDataLayout().isDefault())
-    DstM->setDataLayout(SrcM->getDataLayout());
+  if (DstM.getDataLayout().isDefault())
+    DstM.setDataLayout(SrcM.getDataLayout());
 
-  if (SrcM->getDataLayout() != DstM->getDataLayout()) {
+  if (SrcM.getDataLayout() != DstM.getDataLayout()) {
     emitWarning("Linking two modules of different data layouts: '" +
-                SrcM->getModuleIdentifier() + "' is '" +
-                SrcM->getDataLayoutStr() + "' whereas '" +
-                DstM->getModuleIdentifier() + "' is '" +
-                DstM->getDataLayoutStr() + "'\n");
+                SrcM.getModuleIdentifier() + "' is '" +
+                SrcM.getDataLayoutStr() + "' whereas '" +
+                DstM.getModuleIdentifier() + "' is '" +
+                DstM.getDataLayoutStr() + "'\n");
   }
 
   // Copy the target triple from the source to dest if the dest's is empty.
-  if (DstM->getTargetTriple().empty() && !SrcM->getTargetTriple().empty())
-    DstM->setTargetTriple(SrcM->getTargetTriple());
+  if (DstM.getTargetTriple().empty() && !SrcM.getTargetTriple().empty())
+    DstM.setTargetTriple(SrcM.getTargetTriple());
 
-  Triple SrcTriple(SrcM->getTargetTriple()), DstTriple(DstM->getTargetTriple());
+  Triple SrcTriple(SrcM.getTargetTriple()), DstTriple(DstM.getTargetTriple());
 
-  if (!SrcM->getTargetTriple().empty() && !triplesMatch(SrcTriple, DstTriple))
+  if (!SrcM.getTargetTriple().empty() && !triplesMatch(SrcTriple, DstTriple))
     emitWarning("Linking two modules of different target triples: " +
-                SrcM->getModuleIdentifier() + "' is '" +
-                SrcM->getTargetTriple() + "' whereas '" +
-                DstM->getModuleIdentifier() + "' is '" +
-                DstM->getTargetTriple() + "'\n");
+                SrcM.getModuleIdentifier() + "' is '" + SrcM.getTargetTriple() +
+                "' whereas '" + DstM.getModuleIdentifier() + "' is '" +
+                DstM.getTargetTriple() + "'\n");
 
-  DstM->setTargetTriple(mergeTriples(SrcTriple, DstTriple));
+  DstM.setTargetTriple(mergeTriples(SrcTriple, DstTriple));
 
   // Append the module inline asm string.
-  if (!SrcM->getModuleInlineAsm().empty()) {
-    if (DstM->getModuleInlineAsm().empty())
-      DstM->setModuleInlineAsm(SrcM->getModuleInlineAsm());
+  if (!SrcM.getModuleInlineAsm().empty()) {
+    if (DstM.getModuleInlineAsm().empty())
+      DstM.setModuleInlineAsm(SrcM.getModuleInlineAsm());
     else
-      DstM->setModuleInlineAsm(DstM->getModuleInlineAsm()+"\n"+
-                               SrcM->getModuleInlineAsm());
+      DstM.setModuleInlineAsm(DstM.getModuleInlineAsm() + "\n" +
+                              SrcM.getModuleInlineAsm());
   }
 
   // Loop over all of the linked values to compute type mappings.
   computeTypeMapping();
 
   ComdatsChosen.clear();
-  for (const auto &SMEC : SrcM->getComdatSymbolTable()) {
+  for (const auto &SMEC : SrcM.getComdatSymbolTable()) {
     const Comdat &C = SMEC.getValue();
     if (ComdatsChosen.count(&C))
       continue;
@@ -1891,37 +1888,37 @@ bool ModuleLinker::run() {
   // Upgrade mismatched global arrays.
   upgradeMismatchedGlobals();
 
-  for (GlobalVariable &GV : SrcM->globals())
+  for (GlobalVariable &GV : SrcM.globals())
     if (const Comdat *SC = GV.getComdat())
       ComdatMembers[SC].push_back(&GV);
 
-  for (Function &SF : *SrcM)
+  for (Function &SF : SrcM)
     if (const Comdat *SC = SF.getComdat())
       ComdatMembers[SC].push_back(&SF);
 
-  for (GlobalAlias &GA : SrcM->aliases())
+  for (GlobalAlias &GA : SrcM.aliases())
     if (const Comdat *SC = GA.getComdat())
       ComdatMembers[SC].push_back(&GA);
 
   // Insert all of the globals in src into the DstM module... without linking
   // initializers (which could refer to functions not yet mapped over).
-  for (GlobalVariable &GV : SrcM->globals())
+  for (GlobalVariable &GV : SrcM.globals())
     if (linkIfNeeded(GV))
       return true;
 
-  for (Function &SF : *SrcM)
+  for (Function &SF : SrcM)
     if (linkIfNeeded(SF))
       return true;
 
-  for (GlobalAlias &GA : SrcM->aliases())
+  for (GlobalAlias &GA : SrcM.aliases())
     if (linkIfNeeded(GA))
       return true;
 
-  for (const auto &Entry : DstM->getComdatSymbolTable()) {
+  for (const auto &Entry : DstM.getComdatSymbolTable()) {
     const Comdat &C = Entry.getValue();
     if (C.getSelectionKind() == Comdat::Any)
       continue;
-    const GlobalValue *GV = SrcM->getNamedValue(C.getName());
+    const GlobalValue *GV = SrcM.getNamedValue(C.getName());
     if (GV)
       MapValue(GV, ValueMap, RF_MoveDistinctMDs, &TypeMap, &ValMaterializer);
   }
@@ -2032,12 +2029,10 @@ bool Linker::IdentifiedStructTypeSet::hasType(StructType *Ty) {
   return *I == Ty;
 }
 
-Linker::Linker(Module *M, DiagnosticHandlerFunction DiagnosticHandler) {
-  this->Composite = M;
-  this->DiagnosticHandler = DiagnosticHandler;
-
+Linker::Linker(Module &M, DiagnosticHandlerFunction DiagnosticHandler)
+    : Composite(M), DiagnosticHandler(DiagnosticHandler) {
   TypeFinder StructTypes;
-  StructTypes.run(*M, true);
+  StructTypes.run(M, true);
   for (StructType *Ty : StructTypes) {
     if (Ty->isOpaque())
       IdentifiedStructTypes.addOpaque(Ty);
@@ -2046,18 +2041,18 @@ Linker::Linker(Module *M, DiagnosticHandlerFunction DiagnosticHandler) {
   }
 }
 
-Linker::Linker(Module *M)
+Linker::Linker(Module &M)
     : Linker(M, [this](const DiagnosticInfo &DI) {
-        Composite->getContext().diagnose(DI);
+        Composite.getContext().diagnose(DI);
       }) {}
 
-bool Linker::linkInModule(Module *Src, unsigned Flags,
+bool Linker::linkInModule(Module &Src, unsigned Flags,
                           const FunctionInfoIndex *Index,
                           Function *FuncToImport) {
   ModuleLinker TheLinker(Composite, IdentifiedStructTypes, Src,
                          DiagnosticHandler, Flags, Index, FuncToImport);
   bool RetCode = TheLinker.run();
-  Composite->dropTriviallyDeadConstantArrays();
+  Composite.dropTriviallyDeadConstantArrays();
   return RetCode;
 }
 
@@ -2070,14 +2065,14 @@ bool Linker::linkInModule(Module *Src, unsigned Flags,
 /// true is returned and ErrorMsg (if not null) is set to indicate the problem.
 /// Upon failure, the Dest module could be in a modified state, and shouldn't be
 /// relied on to be consistent.
-bool Linker::LinkModules(Module *Dest, Module *Src,
+bool Linker::linkModules(Module &Dest, Module &Src,
                          DiagnosticHandlerFunction DiagnosticHandler,
                          unsigned Flags) {
   Linker L(Dest, DiagnosticHandler);
   return L.linkInModule(Src, Flags);
 }
 
-bool Linker::LinkModules(Module *Dest, Module *Src, unsigned Flags) {
+bool Linker::linkModules(Module &Dest, Module &Src, unsigned Flags) {
   Linker L(Dest);
   return L.linkInModule(Src, Flags);
 }
@@ -2093,8 +2088,8 @@ LLVMBool LLVMLinkModules(LLVMModuleRef Dest, LLVMModuleRef Src,
   raw_string_ostream Stream(Message);
   DiagnosticPrinterRawOStream DP(Stream);
 
-  LLVMBool Result = Linker::LinkModules(
-      D, unwrap(Src), [&](const DiagnosticInfo &DI) { DI.print(DP); });
+  LLVMBool Result = Linker::linkModules(
+      *D, *unwrap(Src), [&](const DiagnosticInfo &DI) { DI.print(DP); });
 
   if (OutMessages && Result) {
     Stream.flush();