Tests for MCJIT multiple module support
[oota-llvm.git] / unittests / ExecutionEngine / MCJIT / MCJITTestBase.h
index 5debb8b57858aea93abfdf0d57a2bb1adb89cb39..b42a9c0980db1ea26ab0dfaf14e81b7dd6ba3c8c 100644 (file)
@@ -119,12 +119,13 @@ protected:
   // Inserts an declaration to a function defined elsewhere
   Function *insertExternalReferenceToFunction(Module *M, Function *Func) {
     Function *Result = Function::Create(Func->getFunctionType(),
-                                        GlobalValue::AvailableExternallyLinkage,
+                                        GlobalValue::ExternalLinkage,
                                         Func->getName(), M);
     return Result;
   }
 
   // Inserts a global variable of type int32
+  // FIXME: make this a template function to support any type
   GlobalVariable *insertGlobalInt32(Module *M,
                                     StringRef name,
                                     int32_t InitialValue) {
@@ -138,11 +139,148 @@ protected:
                                                 name);
     return Global;
   }
+
+  // Inserts a function
+  //   int32_t recursive_add(int32_t num) {
+  //     if (num == 0) {
+  //       return num;
+  //     } else {
+  //       int32_t recursive_param = num - 1;
+  //       return num + Helper(recursive_param);
+  //     }
+  //   }
+  // NOTE: if Helper is left as the default parameter, Helper == recursive_add.
+  Function *insertAccumulateFunction(Module *M,
+                                              Function *Helper = 0,
+                                              StringRef Name = "accumulate") {
+    Function *Result = startFunction<int32_t(int32_t)>(M, Name);
+    if (Helper == 0)
+      Helper = Result;
+
+    BasicBlock *BaseCase = BasicBlock::Create(Context, "", Result);
+    BasicBlock *RecursiveCase = BasicBlock::Create(Context, "", Result);
+
+    // if (num == 0)
+    Value *Param = Result->arg_begin();
+    Value *Zero = ConstantInt::get(Context, APInt(32, 0));
+    Builder.CreateCondBr(Builder.CreateICmpEQ(Param, Zero),
+                         BaseCase, RecursiveCase);
+
+    //   return num;
+    Builder.SetInsertPoint(BaseCase);
+    Builder.CreateRet(Param);
+
+    //   int32_t recursive_param = num - 1;
+    //   return Helper(recursive_param);
+    Builder.SetInsertPoint(RecursiveCase);
+    Value *One = ConstantInt::get(Context, APInt(32, 1));
+    Value *RecursiveParam = Builder.CreateSub(Param, One);
+    Value *RecursiveReturn = Builder.CreateCall(Helper, RecursiveParam);
+    Value *Accumulator = Builder.CreateAdd(Param, RecursiveReturn);
+    Builder.CreateRet(Accumulator);
+
+    return Result;
+  }
+
+  // Populates Modules A and B:
+  // Module A { Extern FB1, Function FA which calls FB1 },
+  // Module B { Extern FA, Function FB1, Function FB2 which calls FA },
+  void createCrossModuleRecursiveCase(OwningPtr<Module> &A,
+                                      Function *&FA,
+                                      OwningPtr<Module> &B,
+                                      Function *&FB1,
+                                      Function *&FB2) {
+    // Define FB1 in B.
+    B.reset(createEmptyModule("B"));
+    FB1 = insertAccumulateFunction(B.get(), 0, "FB1");
+
+    // Declare FB1 in A (as an external).
+    A.reset(createEmptyModule("A"));
+    Function *FB1Extern = insertExternalReferenceToFunction(A.get(), FB1);
+
+    // Define FA in A (with a call to FB1).
+    FA = insertAccumulateFunction(A.get(), FB1Extern, "FA");
+
+    // Declare FA in B (as an external)
+    Function *FAExtern = insertExternalReferenceToFunction(B.get(), FA);
+
+    // Define FB2 in B (with a call to FA)
+    FB2 = insertAccumulateFunction(B.get(), FAExtern, "FB2");
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA },
+  // Module C { Extern FB, Function FC which calls FB },
+  void createThreeModuleChainedCallsCase(OwningPtr<Module> &A,
+                             Function *&FA,
+                             OwningPtr<Module> &B,
+                             Function *&FB,
+                             OwningPtr<Module> &C,
+                             Function *&FC) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(), FAExtern_in_B);
+
+    C.reset(createEmptyModule("C"));
+    Function *FBExtern_in_C = insertExternalReferenceToFunction(C.get(), FB);
+    FC = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(C.get(), FBExtern_in_C);
+  }
+
+
+  // Module A { Function FA },
+  // Populates Modules A and B:
+  // Module B { Function FB }
+  void createTwoModuleCase(OwningPtr<Module> &A, Function *&FA,
+                           OwningPtr<Module> &B, Function *&FB) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    FB = insertAddFunction(B.get());
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA }
+  void createTwoModuleExternCase(OwningPtr<Module> &A, Function *&FA,
+                                 OwningPtr<Module> &B, Function *&FB) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(),
+                                                             FAExtern_in_B);
+  }
+
+  // Module A { Function FA },
+  // Module B { Extern FA, Function FB which calls FA },
+  // Module C { Extern FB, Function FC which calls FA },
+  void createThreeModuleCase(OwningPtr<Module> &A,
+                             Function *&FA,
+                             OwningPtr<Module> &B,
+                             Function *&FB,
+                             OwningPtr<Module> &C,
+                             Function *&FC) {
+    A.reset(createEmptyModule("A"));
+    FA = insertAddFunction(A.get());
+
+    B.reset(createEmptyModule("B"));
+    Function *FAExtern_in_B = insertExternalReferenceToFunction(B.get(), FA);
+    FB = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(B.get(), FAExtern_in_B);
+
+    C.reset(createEmptyModule("C"));
+    Function *FAExtern_in_C = insertExternalReferenceToFunction(C.get(), FA);
+    FC = insertSimpleCallFunction<int32_t(int32_t, int32_t)>(C.get(), FAExtern_in_C);
+  }
 };
 
+
 class MCJITTestBase : public MCJITTestAPICommon, public TrivialModuleBuilder {
 protected:
-  
+
   MCJITTestBase()
     : TrivialModuleBuilder(HostTriple)
     , OptLevel(CodeGenOpt::None)