Fix llvm-extract so that it changes the linkage of all GlobalValues to
[oota-llvm.git] / lib / Transforms / IPO / ExtractGV.cpp
index dfbad61cf5d8f7ebc1b47ce519642ff6b064449b..9d432de9fa7b7fd7704ba3291c735e6e8ff85783 100644 (file)
 #include "llvm/Pass.h"
 #include "llvm/Constants.h"
 #include "llvm/Transforms/IPO.h"
-#include "llvm/Support/Compiler.h"
+#include "llvm/ADT/SetVector.h"
 #include <algorithm>
 using namespace llvm;
 
 namespace {
   /// @brief A pass to extract specific functions and their dependencies.
-  class VISIBILITY_HIDDEN GVExtractorPass : public ModulePass {
-    std::vector<GlobalValue*> Named;
+  class GVExtractorPass : public ModulePass {
+    SetVector<GlobalValue *> Named;
     bool deleteStuff;
-    bool reLink;
   public:
     static char ID; // Pass identification, replacement for typeid
 
@@ -34,134 +33,38 @@ namespace {
     /// specified function. Otherwise, it deletes as much of the module as
     /// possible, except for the function specified.
     ///
-    explicit GVExtractorPass(std::vector<GlobalValue*>& GVs, bool deleteS = true,
-                             bool relinkCallees = false)
-      : ModulePass(&ID), Named(GVs), deleteStuff(deleteS),
-        reLink(relinkCallees) {}
+    explicit GVExtractorPass(std::vector<GlobalValue*>& GVs, bool deleteS = true)
+      : ModulePass(ID), Named(GVs.begin(), GVs.end()), deleteStuff(deleteS) {}
 
     bool runOnModule(Module &M) {
-      if (Named.size() == 0) {
-        return false;  // Nothing to extract
+      // Visit the global inline asm.
+      if (!deleteStuff)
+        M.setModuleInlineAsm("");
+
+      // For simplicity, just give all GlobalValues ExternalLinkage. A trickier
+      // implementation could figure out which GlobalValues are actually
+      // referenced by the Named set, and which GlobalValues in the rest of
+      // the module are referenced by the NamedSet, and get away with leaving
+      // more internal and private things internal and private. But for now,
+      // be conservative and simple.
+
+      // Visit the GlobalVariables.
+      for (Module::global_iterator I = M.global_begin(), E = M.global_end();
+           I != E; ++I) {
+        if (I->hasLocalLinkage())
+          I->setVisibility(GlobalValue::HiddenVisibility);
+        I->setLinkage(GlobalValue::ExternalLinkage);
+        if (deleteStuff == (bool)Named.count(I) && !I->isDeclaration())
+          I->setInitializer(0);
       }
-      
-      
-      if (deleteStuff)
-        return deleteGV();
-      M.setModuleInlineAsm("");
-      return isolateGV(M);
-    }
-
-    bool deleteGV() {
-      for (std::vector<GlobalValue*>::iterator GI = Named.begin(), 
-             GE = Named.end(); GI != GE; ++GI) {
-        if (Function* NamedFunc = dyn_cast<Function>(*GI)) {
-         // If we're in relinking mode, set linkage of all internal callees to
-         // external. This will allow us extract function, and then - link
-         // everything together
-         if (reLink) {
-           for (Function::iterator B = NamedFunc->begin(), BE = NamedFunc->end();
-                B != BE; ++B) {
-             for (BasicBlock::iterator I = B->begin(), E = B->end();
-                  I != E; ++I) {
-               if (CallInst* callInst = dyn_cast<CallInst>(&*I)) {
-                 Function* Callee = callInst->getCalledFunction();
-                 if (Callee && Callee->hasLocalLinkage())
-                   Callee->setLinkage(GlobalValue::ExternalLinkage);
-               }
-             }
-           }
-         }
-         
-         NamedFunc->setLinkage(GlobalValue::ExternalLinkage);
-         NamedFunc->deleteBody();
-         assert(NamedFunc->isDeclaration() && "This didn't make the function external!");
-       } else {
-          if (!(*GI)->isDeclaration()) {
-            cast<GlobalVariable>(*GI)->setInitializer(0);  //clear the initializer
-            (*GI)->setLinkage(GlobalValue::ExternalLinkage);
-          }
-        }
-      }
-      return true;
-    }
-
-    bool isolateGV(Module &M) {
-      LLVMContext &Context = M.getContext();
-      
-      // Mark all globals internal
-      // FIXME: what should we do with private linkage?
-      for (Module::global_iterator I = M.global_begin(), E = M.global_end(); I != E; ++I)
-        if (!I->isDeclaration()) {
-          I->setLinkage(GlobalValue::InternalLinkage);
-        }
-      for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
-        if (!I->isDeclaration()) {
-          I->setLinkage(GlobalValue::InternalLinkage);
-        }
-
-      // Make sure our result is globally accessible...
-      // by putting them in the used array
-      {
-        std::vector<Constant *> AUGs;
-        const Type *SBP= PointerType::getUnqual(Type::Int8Ty);
-        for (std::vector<GlobalValue*>::iterator GI = Named.begin(), 
-               GE = Named.end(); GI != GE; ++GI) {
-          (*GI)->setLinkage(GlobalValue::ExternalLinkage);
-          AUGs.push_back(ConstantExpr::getBitCast(*GI, SBP));
-        }
-        ArrayType *AT = ArrayType::get(SBP, AUGs.size());
-        Constant *Init = ConstantArray::get(AT, AUGs);
-        GlobalValue *gv = new GlobalVariable(M, AT, false, 
-                                             GlobalValue::AppendingLinkage, 
-                                             Init, "llvm.used");
-        gv->setSection("llvm.metadata");
-      }
-
-      // All of the functions may be used by global variables or the named
-      // globals.  Loop through them and create a new, external functions that
-      // can be "used", instead of ones with bodies.
-      std::vector<Function*> NewFunctions;
-
-      Function *Last = --M.end();  // Figure out where the last real fn is.
-
-      for (Module::iterator I = M.begin(); ; ++I) {
-        if (std::find(Named.begin(), Named.end(), &*I) == Named.end()) {
-          Function *New = Function::Create(I->getFunctionType(),
-                                           GlobalValue::ExternalLinkage);
-          New->copyAttributesFrom(I);
-
-          // If it's not the named function, delete the body of the function
-          I->dropAllReferences();
-
-          M.getFunctionList().push_back(New);
-          NewFunctions.push_back(New);
-          New->takeName(I);
-        }
-
-        if (&*I == Last) break;  // Stop after processing the last function
-      }
-
-      // Now that we have replacements all set up, loop through the module,
-      // deleting the old functions, replacing them with the newly created
-      // functions.
-      if (!NewFunctions.empty()) {
-        unsigned FuncNum = 0;
-        Module::iterator I = M.begin();
-        do {
-          if (std::find(Named.begin(), Named.end(), &*I) == Named.end()) {
-            // Make everything that uses the old function use the new dummy fn
-            I->replaceAllUsesWith(NewFunctions[FuncNum++]);
-
-            Function *Old = I;
-            ++I;  // Move the iterator to the new function
-
-            // Delete the old function!
-            M.getFunctionList().erase(Old);
 
-          } else {
-            ++I;  // Skip the function we are extracting
-          }
-        } while (&*I != NewFunctions[0]);
+      // Visit the Functions.
+      for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) {
+        if (I->hasLocalLinkage())
+          I->setVisibility(GlobalValue::HiddenVisibility);
+        I->setLinkage(GlobalValue::ExternalLinkage);
+        if (deleteStuff == (bool)Named.count(I) && !I->isDeclaration())
+          I->deleteBody();
       }
 
       return true;
@@ -172,6 +75,6 @@ namespace {
 }
 
 ModulePass *llvm::createGVExtractionPass(std::vector<GlobalValue*>& GVs, 
-                                         bool deleteFn, bool relinkCallees) {
-  return new GVExtractorPass(GVs, deleteFn, relinkCallees);
+                                         bool deleteFn) {
+  return new GVExtractorPass(GVs, deleteFn);
 }