Nuke the old JIT.
[oota-llvm.git] / unittests / ExecutionEngine / MCJIT / MCJITTestBase.h
index 71f2bc58f4a4012e0c19bd9d713d4b9d49c2b2d4..2c1d518da4cc710885fdfa4baec4d7910e00cf21 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef MCJIT_TEST_BASE_H
 #define MCJIT_TEST_BASE_H
 
+#include "MCJITTestAPICommon.h"
 #include "llvm/Config/config.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/TypeBuilder.h"
 #include "llvm/Support/CodeGen.h"
-#include "MCJITTestAPICommon.h"
 
 namespace llvm {
 
-class MCJITTestBase : public MCJITTestAPICommon {
+/// Helper class that can build very simple Modules
+class TrivialModuleBuilder {
 protected:
+  LLVMContext Context;
+  IRBuilder<> Builder;
+  std::string BuilderTriple;
 
-  MCJITTestBase()
-    : OptLevel(CodeGenOpt::None)
-    , RelocModel(Reloc::Default)
-    , CodeModel(CodeModel::Default)
-    , MArch("")
-    , Builder(Context)
-    , MM(new SectionMemoryManager)
-  {
-    // The architectures below are known to be compatible with MCJIT as they
-    // are copied from test/ExecutionEngine/MCJIT/lit.local.cfg and should be
-    // kept in sync.
-    SupportedArchs.push_back(Triple::aarch64);
-    SupportedArchs.push_back(Triple::arm);
-    SupportedArchs.push_back(Triple::mips);
-    SupportedArchs.push_back(Triple::x86);
-    SupportedArchs.push_back(Triple::x86_64);
+  TrivialModuleBuilder(const std::string &Triple)
+    : Builder(Context), BuilderTriple(Triple) {}
 
-    // The operating systems below are known to be incompatible with MCJIT as
-    // they are copied from the test/ExecutionEngine/MCJIT/lit.local.cfg and
-    // should be kept in sync.
-    UnsupportedOSs.push_back(Triple::Cygwin);
-    UnsupportedOSs.push_back(Triple::Darwin);
-  }
-
-  Module *createEmptyModule(StringRef Name) {
+  Module *createEmptyModule(StringRef Name = StringRef()) {
     Module * M = new Module(Name, Context);
-    M->setTargetTriple(Triple::normalize(HostTriple));
+    M->setTargetTriple(Triple::normalize(BuilderTriple));
     return M;
   }
 
@@ -136,12 +119,13 @@ protected:
   // Inserts an declaration to a function defined elsewhere
   Function *insertExternalReferenceToFunction(Module *M, Function *Func) {
     Function *Result = Function::Create(Func->getFunctionType(),
-                                        GlobalValue::AvailableExternallyLinkage,
+                                        GlobalValue::ExternalLinkage,
                                         Func->getName(), M);
     return Result;
   }
 
   // Inserts a global variable of type int32
+  // FIXME: make this a template function to support any type
   GlobalVariable *insertGlobalInt32(Module *M,
                                     StringRef name,
                                     int32_t InitialValue) {
@@ -156,6 +140,174 @@ protected:
     return Global;
   }
 
+  // Inserts a function
+  //   int32_t recursive_add(int32_t num) {
+  //     if (num == 0) {
+  //       return num;
+  //     } else {
+  //       int32_t recursive_param = num - 1;
+  //       return num + Helper(recursive_param);
+  //     }
+  //   }
+  // NOTE: if Helper is left as the default parameter, Helper == recursive_add.
+  Function *insertAccumulateFunction(Module *M,
+                                              Function *Helper = 0,
+                                              StringRef Name = "accumulate") {
+    Function *Result = startFunction<int32_t(int32_t)>(M, Name);
+    if (Helper == 0)
+      Helper = Result;
+
+    BasicBlock *BaseCase = BasicBlock::Create(Context, "", Result);
+    BasicBlock *RecursiveCase = BasicBlock::Create(Context, "", Result);
+
+    // if (num == 0)
+    Value *Param = Result->arg_begin();
+    Value *Zero = ConstantInt::get(Context, APInt(32, 0));
+    Builder.CreateCondBr(Builder.CreateICmpEQ(Param, Zero),
+                         BaseCase, RecursiveCase);
+
+    //   return num;
+    Builder.SetInsertPoint(BaseCase);
+    Builder.CreateRet(Param);
+
+    //   int32_t recursive_param = num - 1;
+    //   return Helper(recursive_param);
+    Builder.SetInsertPoint(RecursiveCase);
+    Value *One = ConstantInt::get(Context, APInt(32, 1));
+    Value *RecursiveParam = Builder.CreateSub(Param, One);
+    Value *RecursiveReturn = Builder.CreateCall(Helper, RecursiveParam);
+    Value *Accumulator = Builder.CreateAdd(Param, RecursiveReturn);
+    Builder.CreateRet(Accumulator);
+
+    return Result;
+  }
+
+  // Populates Modules A and B:
+  // Module A { Extern FB1, Function FA which calls FB1 },
+  // Module B { Extern FA, Function FB1, Function FB2 which calls FA },
+  void createCrossModuleRecursiveCase(std::unique_ptr<Module> &A, Function *&FA,
+                                      std::unique_ptr<Module> &B,
+                                      Function *&FB1, Function *&FB2) {
+    // Define FB1 in B.
+    B.reset(createEmptyModule("B"));
+    FB1 = insertAccumulateFunction(B.get(), 0, "FB1");
+
+    // Declare FB1 in A (as an external).
+    A.reset(createEmptyModule("A"));
+    Function *FB1Extern = insertExternalReferenceToFunction(A.get(), FB1);
+
+    // Define FA in A (with a call to FB1).
+    FA = insertAccumulateFunction(A.get(), FB1Extern, "FA");
+
+    // Declare FA in B (as an external)
+    Function *FAExtern = insertExternalReferenceToFunction(B.get(), FA);
+
+    // Define FB2 in B (with a call to FA)
+    FB2 = insertAccumulateFunction(B.get(), FAExtern, "FB2");
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA },
+  // Module C { Extern FB, Function FC which calls FB },
+  void
+  createThreeModuleChainedCallsCase(std::unique_ptr<Module> &A, Function *&FA,
+                                    std::unique_ptr<Module> &B, Function *&FB,
+                                    std::unique_ptr<Module> &C, Function *&FC) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(), FAExtern_in_B);
+
+    C.reset(createEmptyModule("C"));
+    Function *FBExtern_in_C = insertExternalReferenceToFunction(C.get(), FB);
+    FC = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(C.get(), FBExtern_in_C);
+  }
+
+
+  // Module A { Function FA },
+  // Populates Modules A and B:
+  // Module B { Function FB }
+  void createTwoModuleCase(std::unique_ptr<Module> &A, Function *&FA,
+                           std::unique_ptr<Module> &B, Function *&FB) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    FB = insertAddFunction(B.get());
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA }
+  void createTwoModuleExternCase(std::unique_ptr<Module> &A, Function *&FA,
+                                 std::unique_ptr<Module> &B, Function *&FB) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(),
+                                                             FAExtern_in_B);
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA },
+  // Module C { Extern FB, Function FC which calls FA },
+  void createThreeModuleCase(std::unique_ptr<Module> &A, Function *&FA,
+                             std::unique_ptr<Module> &B, Function *&FB,
+                             std::unique_ptr<Module> &C, Function *&FC) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(), FAExtern_in_B);
+
+    C.reset(createEmptyModule("C"));
+    Function *FAExtern_in_C = insertExternalReferenceToFunction(C.get(), FA);
+    FC = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(C.get(), FAExtern_in_C);
+  }
+};
+
+
+class MCJITTestBase : public MCJITTestAPICommon, public TrivialModuleBuilder {
+protected:
+
+  MCJITTestBase()
+    : TrivialModuleBuilder(HostTriple)
+    , OptLevel(CodeGenOpt::None)
+    , RelocModel(Reloc::Default)
+    , CodeModel(CodeModel::Default)
+    , MArch("")
+    , MM(new SectionMemoryManager)
+  {
+    // The architectures below are known to be compatible with MCJIT as they
+    // are copied from test/ExecutionEngine/MCJIT/lit.local.cfg and should be
+    // kept in sync.
+    SupportedArchs.push_back(Triple::aarch64);
+    SupportedArchs.push_back(Triple::arm);
+    SupportedArchs.push_back(Triple::mips);
+    SupportedArchs.push_back(Triple::mipsel);
+    SupportedArchs.push_back(Triple::x86);
+    SupportedArchs.push_back(Triple::x86_64);
+
+    // Some architectures have sub-architectures in which tests will fail, like
+    // ARM. These two vectors will define if they do have sub-archs (to avoid
+    // extra work for those who don't), and if so, if they are listed to work
+    HasSubArchs.push_back(Triple::arm);
+    SupportedSubArchs.push_back("armv6");
+    SupportedSubArchs.push_back("armv7");
+
+    // The operating systems below are known to be incompatible with MCJIT as
+    // they are copied from the test/ExecutionEngine/MCJIT/lit.local.cfg and
+    // should be kept in sync.
+    UnsupportedOSs.push_back(Triple::Cygwin);
+    UnsupportedOSs.push_back(Triple::Darwin);
+
+    UnsupportedEnvironments.push_back(Triple::Cygnus);
+  }
+
   void createJIT(Module *M) {
 
     // Due to the EngineBuilder constructor, it is required to have a Module
@@ -165,11 +317,9 @@ protected:
     EngineBuilder EB(M);
     std::string Error;
     TheJIT.reset(EB.setEngineKind(EngineKind::JIT)
-                 .setUseMCJIT(true) /* can this be folded into the EngineKind enum? */
                  .setMCJITMemoryManager(MM)
                  .setErrorStr(&Error)
                  .setOptLevel(CodeGenOpt::None)
-                 .setAllocateGVsWithCode(false) /*does this do anything?*/
                  .setCodeModel(CodeModel::JITDefault)
                  .setRelocationModel(Reloc::Default)
                  .setMArch(MArch)
@@ -180,18 +330,15 @@ protected:
     assert(TheJIT.get() != NULL && "error creating MCJIT with EngineBuilder");
   }
 
-  LLVMContext Context;
   CodeGenOpt::Level OptLevel;
   Reloc::Model RelocModel;
   CodeModel::Model CodeModel;
   StringRef MArch;
   SmallVector<std::string, 1> MAttrs;
-  OwningPtr<TargetMachine> TM;
-  OwningPtr<ExecutionEngine> TheJIT;
-  IRBuilder<> Builder;
+  std::unique_ptr<ExecutionEngine> TheJIT;
   RTDyldMemoryManager *MM;
 
-  OwningPtr<Module> M;
+  std::unique_ptr<Module> M;
 };
 
 } // namespace llvm