Bring r254336 back:
[oota-llvm.git] / lib / Transforms / Utils / SymbolRewriter.cpp
index a678dde26da0bbdfca54073784e4a0915c27fd44..1d1f602b041dc8742231f1e9b563035eda381ad6 100644 (file)
@@ -60,7 +60,8 @@
 #define DEBUG_TYPE "symbol-rewriter"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/Pass.h"
-#include "llvm/PassManager.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/YAMLParser.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Transforms/IPO/PassManagerBuilder.h"
 #include "llvm/Transforms/Utils/SymbolRewriter.h"
 
 using namespace llvm;
+using namespace SymbolRewriter;
 
 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
                                              cl::desc("Symbol Rewrite Map"),
                                              cl::value_desc("filename"));
 
-namespace llvm {
-namespace SymbolRewriter {
+static void rewriteComdat(Module &M, GlobalObject *GO,
+                          const std::string &Source,
+                          const std::string &Target) {
+  if (Comdat *CD = GO->getComdat()) {
+    auto &Comdats = M.getComdatSymbolTable();
+
+    Comdat *C = M.getOrInsertComdat(Target);
+    C->setSelectionKind(CD->getSelectionKind());
+    GO->setComdat(C);
+
+    Comdats.erase(Comdats.find(Source));
+  }
+}
+
+namespace {
 template <RewriteDescriptor::Type DT, typename ValueType,
           ValueType *(llvm::Module::*Get)(StringRef) const>
 class ExplicitRewriteDescriptor : public RewriteDescriptor {
@@ -102,10 +116,14 @@ template <RewriteDescriptor::Type DT, typename ValueType,
 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
   bool Changed = false;
   if (ValueType *S = (M.*Get)(Source)) {
+    if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
+      rewriteComdat(M, GO, Source, Target);
+
     if (Value *T = (M.*Get)(Target))
       S->setValueName(T->getValueName());
     else
       S->setName(Target);
+
     Changed = true;
   }
   return Changed;
@@ -113,7 +131,8 @@ bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
 
 template <RewriteDescriptor::Type DT, typename ValueType,
           ValueType *(llvm::Module::*Get)(StringRef) const,
-          iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()>
+          iterator_range<typename iplist<ValueType>::iterator>
+          (llvm::Module::*Iterator)()>
 class PatternRewriteDescriptor : public RewriteDescriptor {
 public:
   const std::string Pattern;
@@ -131,7 +150,8 @@ public:
 
 template <RewriteDescriptor::Type DT, typename ValueType,
           ValueType *(llvm::Module::*Get)(StringRef) const,
-          iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()>
+          iterator_range<typename iplist<ValueType>::iterator>
+          (llvm::Module::*Iterator)()>
 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
 performOnModule(Module &M) {
   bool Changed = false;
@@ -143,6 +163,12 @@ performOnModule(Module &M) {
       report_fatal_error("unable to transforn " + C.getName() + " in " +
                          M.getModuleIdentifier() + ": " + Error);
 
+    if (C.getName() == Name)
+      continue;
+
+    if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
+      rewriteComdat(M, GO, C.getName(), Name);
+
     if (Value *V = (M.*Get)(Name))
       C.setValueName(V->getValueName());
     else
@@ -201,6 +227,7 @@ typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
                                  &llvm::Module::getNamedAlias,
                                  &llvm::Module::aliases>
     PatternRewriteNamedAliasDescriptor;
+} // namespace
 
 bool RewriteMapParser::parse(const std::string &MapFile,
                              RewriteDescriptorList *DL) {
@@ -464,8 +491,6 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
 
   return true;
 }
-}
-}
 
 namespace {
 class RewriteSymbols : public ModulePass {
@@ -475,7 +500,7 @@ public:
   RewriteSymbols();
   RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
 
-  virtual bool runOnModule(Module &M) override;
+  bool runOnModule(Module &M) override;
 
 private:
   void loadAndParseMapFiles();
@@ -492,7 +517,7 @@ RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
 
 RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
     : ModulePass(ID) {
-  std::swap(Descriptors, DL);
+  Descriptors.splice(Descriptors.begin(), DL);
 }
 
 bool RewriteSymbols::runOnModule(Module &M) {