a6d7dcf7b556c0d8fde244f03e0c1e3aec857c4e
[oota-llvm.git] / examples / ParallelJIT / ParallelJIT.cpp
1 //===-- examples/ParallelJIT/ParallelJIT.cpp - Exercise threaded-safe JIT -===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Parallel JIT
11 //
12 // This test program creates two LLVM functions then calls them from three
13 // separate threads.  It requires the pthreads library.
14 // The three threads are created and then block waiting on a condition variable.
15 // Once all threads are blocked on the conditional variable, the main thread
16 // wakes them up. This complicated work is performed so that all three threads
17 // call into the JIT at the same time (or the best possible approximation of the
18 // same time). This test had assertion errors until I got the locking right.
19
20 #include <pthread.h>
21 #include "llvm/Module.h"
22 #include "llvm/Constants.h"
23 #include "llvm/DerivedTypes.h"
24 #include "llvm/Instructions.h"
25 #include "llvm/ModuleProvider.h"
26 #include "llvm/ExecutionEngine/JIT.h"
27 #include "llvm/ExecutionEngine/Interpreter.h"
28 #include "llvm/ExecutionEngine/GenericValue.h"
29 #include "llvm/Target/TargetSelect.h"
30 #include <iostream>
31 using namespace llvm;
32
33 static Function* createAdd1(Module *M) {
34   // Create the add1 function entry and insert this entry into module M.  The
35   // function will have a return type of "int" and take an argument of "int".
36   // The '0' terminates the list of argument types.
37   Function *Add1F =
38     cast<Function>(M->getOrInsertFunction("add1", Type::Int32Ty, Type::Int32Ty,
39                                           (Type *)0));
40
41   // Add a basic block to the function. As before, it automatically inserts
42   // because of the last argument.
43   BasicBlock *BB = BasicBlock::Create("EntryBlock", Add1F);
44
45   // Get pointers to the constant `1'.
46   Value *One = ConstantInt::get(Type::Int32Ty, 1);
47
48   // Get pointers to the integer argument of the add1 function...
49   assert(Add1F->arg_begin() != Add1F->arg_end()); // Make sure there's an arg
50   Argument *ArgX = Add1F->arg_begin();  // Get the arg
51   ArgX->setName("AnArg");            // Give it a nice symbolic name for fun.
52
53   // Create the add instruction, inserting it into the end of BB.
54   Instruction *Add = BinaryOperator::CreateAdd(One, ArgX, "addresult", BB);
55
56   // Create the return instruction and add it to the basic block
57   ReturnInst::Create(Add, BB);
58
59   // Now, function add1 is ready.
60   return Add1F;
61 }
62
63 static Function *CreateFibFunction(Module *M) {
64   // Create the fib function and insert it into module M.  This function is said
65   // to return an int and take an int parameter.
66   Function *FibF = 
67     cast<Function>(M->getOrInsertFunction("fib", Type::Int32Ty, Type::Int32Ty,
68                                           (Type *)0));
69
70   // Add a basic block to the function.
71   BasicBlock *BB = BasicBlock::Create("EntryBlock", FibF);
72
73   // Get pointers to the constants.
74   Value *One = ConstantInt::get(Type::Int32Ty, 1);
75   Value *Two = ConstantInt::get(Type::Int32Ty, 2);
76
77   // Get pointer to the integer argument of the add1 function...
78   Argument *ArgX = FibF->arg_begin();   // Get the arg.
79   ArgX->setName("AnArg");            // Give it a nice symbolic name for fun.
80
81   // Create the true_block.
82   BasicBlock *RetBB = BasicBlock::Create("return", FibF);
83   // Create an exit block.
84   BasicBlock* RecurseBB = BasicBlock::Create("recurse", FibF);
85
86   // Create the "if (arg < 2) goto exitbb"
87   Value *CondInst = new ICmpInst(ICmpInst::ICMP_SLE, ArgX, Two, "cond", BB);
88   BranchInst::Create(RetBB, RecurseBB, CondInst, BB);
89
90   // Create: ret int 1
91   ReturnInst::Create(One, RetBB);
92
93   // create fib(x-1)
94   Value *Sub = BinaryOperator::CreateSub(ArgX, One, "arg", RecurseBB);
95   Value *CallFibX1 = CallInst::Create(FibF, Sub, "fibx1", RecurseBB);
96
97   // create fib(x-2)
98   Sub = BinaryOperator::CreateSub(ArgX, Two, "arg", RecurseBB);
99   Value *CallFibX2 = CallInst::Create(FibF, Sub, "fibx2", RecurseBB);
100
101   // fib(x-1)+fib(x-2)
102   Value *Sum =
103     BinaryOperator::CreateAdd(CallFibX1, CallFibX2, "addresult", RecurseBB);
104
105   // Create the return instruction and add it to the basic block
106   ReturnInst::Create(Sum, RecurseBB);
107
108   return FibF;
109 }
110
111 struct threadParams {
112   ExecutionEngine* EE;
113   Function* F;
114   int value;
115 };
116
117 // We block the subthreads just before they begin to execute:
118 // we want all of them to call into the JIT at the same time,
119 // to verify that the locking is working correctly.
120 class WaitForThreads
121 {
122 public:
123   WaitForThreads()
124   {
125     n = 0;
126     waitFor = 0;
127
128     int result = pthread_cond_init( &condition, NULL );
129     assert( result == 0 );
130
131     result = pthread_mutex_init( &mutex, NULL );
132     assert( result == 0 );
133   }
134
135   ~WaitForThreads()
136   {
137     int result = pthread_cond_destroy( &condition );
138     assert( result == 0 );
139
140     result = pthread_mutex_destroy( &mutex );
141     assert( result == 0 );
142   }
143
144   // All threads will stop here until another thread calls releaseThreads
145   void block()
146   {
147     int result = pthread_mutex_lock( &mutex );
148     assert( result == 0 );
149     n ++;
150     //~ std::cout << "block() n " << n << " waitFor " << waitFor << std::endl;
151
152     assert( waitFor == 0 || n <= waitFor );
153     if ( waitFor > 0 && n == waitFor )
154     {
155       // There are enough threads blocked that we can release all of them
156       std::cout << "Unblocking threads from block()" << std::endl;
157       unblockThreads();
158     }
159     else
160     {
161       // We just need to wait until someone unblocks us
162       result = pthread_cond_wait( &condition, &mutex );
163       assert( result == 0 );
164     }
165
166     // unlock the mutex before returning
167     result = pthread_mutex_unlock( &mutex );
168     assert( result == 0 );
169   }
170
171   // If there are num or more threads blocked, it will signal them all
172   // Otherwise, this thread blocks until there are enough OTHER threads
173   // blocked
174   void releaseThreads( size_t num )
175   {
176     int result = pthread_mutex_lock( &mutex );
177     assert( result == 0 );
178
179     if ( n >= num ) {
180       std::cout << "Unblocking threads from releaseThreads()" << std::endl;
181       unblockThreads();
182     }
183     else
184     {
185       waitFor = num;
186       pthread_cond_wait( &condition, &mutex );
187     }
188
189     // unlock the mutex before returning
190     result = pthread_mutex_unlock( &mutex );
191     assert( result == 0 );
192   }
193
194 private:
195   void unblockThreads()
196   {
197     // Reset the counters to zero: this way, if any new threads
198     // enter while threads are exiting, they will block instead
199     // of triggering a new release of threads
200     n = 0;
201
202     // Reset waitFor to zero: this way, if waitFor threads enter
203     // while threads are exiting, they will block instead of
204     // triggering a new release of threads
205     waitFor = 0;
206
207     int result = pthread_cond_broadcast( &condition );
208     assert(result == 0); result=result;
209   }
210
211   size_t n;
212   size_t waitFor;
213   pthread_cond_t condition;
214   pthread_mutex_t mutex;
215 };
216
217 static WaitForThreads synchronize;
218
219 void* callFunc( void* param )
220 {
221   struct threadParams* p = (struct threadParams*) param;
222
223   // Call the `foo' function with no arguments:
224   std::vector<GenericValue> Args(1);
225   Args[0].IntVal = APInt(32, p->value);
226
227   synchronize.block(); // wait until other threads are at this point
228   GenericValue gv = p->EE->runFunction(p->F, Args);
229
230   return (void*)(intptr_t)gv.IntVal.getZExtValue();
231 }
232
233 int main() {
234   InitializeNativeTarget();
235
236   // Create some module to put our function into it.
237   Module *M = new Module("test");
238
239   Function* add1F = createAdd1( M );
240   Function* fibF = CreateFibFunction( M );
241
242   // Now we create the JIT.
243   ExistingModuleProvider* MP = new ExistingModuleProvider(M);
244   ExecutionEngine* EE = ExecutionEngine::create(MP, false);
245
246   //~ std::cout << "We just constructed this LLVM module:\n\n" << *M;
247   //~ std::cout << "\n\nRunning foo: " << std::flush;
248
249   // Create one thread for add1 and two threads for fib
250   struct threadParams add1 = { EE, add1F, 1000 };
251   struct threadParams fib1 = { EE, fibF, 39 };
252   struct threadParams fib2 = { EE, fibF, 42 };
253
254   pthread_t add1Thread;
255   int result = pthread_create( &add1Thread, NULL, callFunc, &add1 );
256   if ( result != 0 ) {
257           std::cerr << "Could not create thread" << std::endl;
258           return 1;
259   }
260
261   pthread_t fibThread1;
262   result = pthread_create( &fibThread1, NULL, callFunc, &fib1 );
263   if ( result != 0 ) {
264           std::cerr << "Could not create thread" << std::endl;
265           return 1;
266   }
267
268   pthread_t fibThread2;
269   result = pthread_create( &fibThread2, NULL, callFunc, &fib2 );
270   if ( result != 0 ) {
271           std::cerr << "Could not create thread" << std::endl;
272           return 1;
273   }
274
275   synchronize.releaseThreads(3); // wait until other threads are at this point
276
277   void* returnValue;
278   result = pthread_join( add1Thread, &returnValue );
279   if ( result != 0 ) {
280           std::cerr << "Could not join thread" << std::endl;
281           return 1;
282   }
283   std::cout << "Add1 returned " << intptr_t(returnValue) << std::endl;
284
285   result = pthread_join( fibThread1, &returnValue );
286   if ( result != 0 ) {
287           std::cerr << "Could not join thread" << std::endl;
288           return 1;
289   }
290   std::cout << "Fib1 returned " << intptr_t(returnValue) << std::endl;
291
292   result = pthread_join( fibThread2, &returnValue );
293   if ( result != 0 ) {
294           std::cerr << "Could not join thread" << std::endl;
295           return 1;
296   }
297   std::cout << "Fib2 returned " << intptr_t(returnValue) << std::endl;
298
299   return 0;
300 }