07707c91cd99eec7f13b67b0ed68f84963f84349
[oota-llvm.git] / unittests / ExecutionEngine / Orc / OrcCAPITest.cpp
1 //===--------------- OrcCAPITest.cpp - Unit tests Orc C API ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9
10 #include "OrcTestCommon.h"
11 #include "gtest/gtest.h"
12 #include "llvm-c/OrcBindings.h"
13 #include "llvm-c/Target.h"
14 #include "llvm-c/TargetMachine.h"
15
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19
20 namespace llvm {
21
22 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(TargetMachine, LLVMTargetMachineRef)
23
24 class OrcCAPIExecutionTest : public testing::Test, public OrcExecutionTest {
25 protected:
26   std::unique_ptr<Module> createTestModule(const Triple &TT) {
27     ModuleBuilder MB(getGlobalContext(), TT.str(), "");
28     Function *TestFunc = MB.createFunctionDecl<int()>("testFunc");
29     Function *Main = MB.createFunctionDecl<int(int, char*[])>("main");
30
31     Main->getBasicBlockList().push_back(BasicBlock::Create(getGlobalContext()));
32     IRBuilder<> B(&Main->back());
33     Value* Result = B.CreateCall(TestFunc);
34     B.CreateRet(Result);
35
36     return MB.takeModule();
37   }
38
39   typedef int (*MainFnTy)();
40
41   static int myTestFuncImpl() {
42     return 42;
43   }
44
45   static char *testFuncName;
46
47   static uint64_t myResolver(const char *Name, void *Ctx) {
48     if (!strncmp(Name, testFuncName, 8))
49       return (uint64_t)&myTestFuncImpl;
50     return 0;
51   }
52
53   struct CompileContext {
54     CompileContext() : Compiled(false) { }
55
56     OrcCAPIExecutionTest* APIExecTest;
57     std::unique_ptr<Module> M;
58     LLVMOrcModuleHandle H;
59     bool Compiled;
60   };
61
62   static LLVMOrcTargetAddress myCompileCallback(LLVMOrcJITStackRef JITStack,
63                                                 void *Ctx) {
64     CompileContext *CCtx = static_cast<CompileContext*>(Ctx);
65     auto *ET = CCtx->APIExecTest;
66     CCtx->M = ET->createTestModule(ET->TM->getTargetTriple());
67     CCtx->H = LLVMOrcAddEagerlyCompiledIR(JITStack, wrap(CCtx->M.get()),
68                                           myResolver, nullptr);
69     CCtx->Compiled = true;
70     LLVMOrcTargetAddress MainAddr = LLVMOrcGetSymbolAddress(JITStack, "main");
71     LLVMOrcSetIndirectStubPointer(JITStack, "foo", MainAddr);
72     return MainAddr;
73   }
74 };
75
76 char *OrcCAPIExecutionTest::testFuncName = nullptr;
77
78 TEST_F(OrcCAPIExecutionTest, TestEagerIRCompilation) {
79   if (!TM)
80     return;
81
82   LLVMOrcJITStackRef JIT =
83     LLVMOrcCreateInstance(wrap(TM.get()));
84
85   std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
86
87   LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
88
89   LLVMOrcModuleHandle H =
90     LLVMOrcAddEagerlyCompiledIR(JIT, wrap(M.get()), myResolver, nullptr);
91   MainFnTy MainFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "main");
92   int Result = MainFn();
93   EXPECT_EQ(Result, 42)
94     << "Eagerly JIT'd code did not return expected result";
95
96   LLVMOrcRemoveModule(JIT, H);
97
98   LLVMOrcDisposeMangledSymbol(testFuncName);
99   LLVMOrcDisposeInstance(JIT);
100 }
101
102 TEST_F(OrcCAPIExecutionTest, TestLazyIRCompilation) {
103   if (!TM)
104     return;
105
106   LLVMOrcJITStackRef JIT =
107     LLVMOrcCreateInstance(wrap(TM.get()));
108
109   std::unique_ptr<Module> M = createTestModule(TM->getTargetTriple());
110
111   LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
112
113   LLVMOrcModuleHandle H =
114     LLVMOrcAddLazilyCompiledIR(JIT, wrap(M.get()), myResolver, nullptr);
115   MainFnTy MainFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "main");
116   int Result = MainFn();
117   EXPECT_EQ(Result, 42)
118     << "Lazily JIT'd code did not return expected result";
119
120   LLVMOrcRemoveModule(JIT, H);
121
122   LLVMOrcDisposeMangledSymbol(testFuncName);
123   LLVMOrcDisposeInstance(JIT);
124 }
125
126 TEST_F(OrcCAPIExecutionTest, TestDirectCallbacksAPI) {
127   if (!TM)
128     return;
129
130   LLVMOrcJITStackRef JIT =
131     LLVMOrcCreateInstance(wrap(TM.get()));
132
133   LLVMOrcGetMangledSymbol(JIT, &testFuncName, "testFunc");
134
135   CompileContext C;
136   C.APIExecTest = this;
137   LLVMOrcCreateIndirectStub(JIT, "foo",
138                             LLVMOrcCreateLazyCompileCallback(JIT,
139                                                              myCompileCallback,
140                                                              &C));
141   MainFnTy FooFn = (MainFnTy)LLVMOrcGetSymbolAddress(JIT, "foo");
142   int Result = FooFn();
143   EXPECT_TRUE(C.Compiled)
144     << "Function wasn't lazily compiled";
145   EXPECT_EQ(Result, 42)
146     << "Direct-callback JIT'd code did not return expected result";
147
148   C.Compiled = false;
149   FooFn();
150   EXPECT_FALSE(C.Compiled)
151     << "Direct-callback JIT'd code was JIT'd twice";
152
153   LLVMOrcRemoveModule(JIT, C.H);
154
155   LLVMOrcDisposeMangledSymbol(testFuncName);
156   LLVMOrcDisposeInstance(JIT);
157 }
158
159 } // namespace llvm