Update CreateMalloc so that its callers specify the size to allocate:
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
26 #include "BrainF.h"
27 #include "llvm/Constants.h"
28 #include "llvm/Instructions.h"
29 #include "llvm/Intrinsics.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include <iostream>
32 using namespace llvm;
33
34 //Set the constants for naming
35 const char *BrainF::tapereg = "tape";
36 const char *BrainF::headreg = "head";
37 const char *BrainF::label   = "brainf";
38 const char *BrainF::testreg = "test";
39
40 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf,
41                       LLVMContext& Context) {
42   in       = in1;
43   memtotal = mem;
44   comflag  = cf;
45
46   header(Context);
47   readloop(0, 0, 0, Context);
48   delete builder;
49   return module;
50 }
51
52 void BrainF::header(LLVMContext& C) {
53   module = new Module("BrainF", C);
54
55   //Function prototypes
56
57   //declare void @llvm.memset.i32(i8 *, i8, i32, i32)
58   const Type *Tys[] = { Type::getInt32Ty(C) };
59   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
60                                                     Tys, 1);
61
62   //declare i32 @getchar()
63   getchar_func = cast<Function>(module->
64     getOrInsertFunction("getchar", IntegerType::getInt32Ty(C), NULL));
65
66   //declare i32 @putchar(i32)
67   putchar_func = cast<Function>(module->
68     getOrInsertFunction("putchar", IntegerType::getInt32Ty(C),
69                         IntegerType::getInt32Ty(C), NULL));
70
71
72   //Function header
73
74   //define void @brainf()
75   brainf_func = cast<Function>(module->
76     getOrInsertFunction("brainf", Type::getVoidTy(C), NULL));
77
78   builder = new IRBuilder<>(BasicBlock::Create(C, label, brainf_func));
79
80   //%arr = malloc i8, i32 %d
81   ConstantInt *val_mem = ConstantInt::get(C, APInt(32, memtotal));
82   BasicBlock* BB = builder->GetInsertBlock();
83   const Type* IntPtrTy = IntegerType::getInt32Ty(C);
84   const Type* Int8Ty = IntegerType::getInt8Ty(C);
85   Constant* allocsize = ConstantExpr::getSizeOf(Int8Ty);
86   allocsize = ConstantExpr::getTruncOrBitCast(allocsize, IntPtrTy);
87   ptr_arr = CallInst::CreateMalloc(BB, IntPtrTy, Int8Ty, allocsize, val_mem, 
88                                    NULL, "arr");
89   BB->getInstList().push_back(cast<Instruction>(ptr_arr));
90
91   //call void @llvm.memset.i32(i8 *%arr, i8 0, i32 %d, i32 1)
92   {
93     Value *memset_params[] = {
94       ptr_arr,
95       ConstantInt::get(C, APInt(8, 0)),
96       val_mem,
97       ConstantInt::get(C, APInt(32, 1))
98     };
99
100     CallInst *memset_call = builder->
101       CreateCall(memset_func, memset_params, array_endof(memset_params));
102     memset_call->setTailCall(false);
103   }
104
105   //%arrmax = getelementptr i8 *%arr, i32 %d
106   if (comflag & flag_arraybounds) {
107     ptr_arrmax = builder->
108       CreateGEP(ptr_arr, ConstantInt::get(C, APInt(32, memtotal)), "arrmax");
109   }
110
111   //%head.%d = getelementptr i8 *%arr, i32 %d
112   curhead = builder->CreateGEP(ptr_arr,
113                                ConstantInt::get(C, APInt(32, memtotal/2)),
114                                headreg);
115
116
117
118   //Function footer
119
120   //brainf.end:
121   endbb = BasicBlock::Create(C, label, brainf_func);
122
123   //call free(i8 *%arr)
124   endbb->getInstList().push_back(CallInst::CreateFree(ptr_arr, endbb));
125
126   //ret void
127   ReturnInst::Create(C, endbb);
128
129
130
131   //Error block for array out of bounds
132   if (comflag & flag_arraybounds)
133   {
134     //@aberrormsg = internal constant [%d x i8] c"\00"
135     Constant *msg_0 =
136       ConstantArray::get(C, "Error: The head has left the tape.", true);
137
138     GlobalVariable *aberrormsg = new GlobalVariable(
139       *module,
140       msg_0->getType(),
141       true,
142       GlobalValue::InternalLinkage,
143       msg_0,
144       "aberrormsg");
145
146     //declare i32 @puts(i8 *)
147     Function *puts_func = cast<Function>(module->
148       getOrInsertFunction("puts", IntegerType::getInt32Ty(C),
149                       PointerType::getUnqual(IntegerType::getInt8Ty(C)), NULL));
150
151     //brainf.aberror:
152     aberrorbb = BasicBlock::Create(C, label, brainf_func);
153
154     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
155     {
156       Constant *zero_32 = Constant::getNullValue(IntegerType::getInt32Ty(C));
157
158       Constant *gep_params[] = {
159         zero_32,
160         zero_32
161       };
162
163       Constant *msgptr = ConstantExpr::
164         getGetElementPtr(aberrormsg, gep_params,
165                          array_lengthof(gep_params));
166
167       Value *puts_params[] = {
168         msgptr
169       };
170
171       CallInst *puts_call =
172         CallInst::Create(puts_func,
173                          puts_params, array_endof(puts_params),
174                          "", aberrorbb);
175       puts_call->setTailCall(false);
176     }
177
178     //br label %brainf.end
179     BranchInst::Create(endbb, aberrorbb);
180   }
181 }
182
183 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb,
184                       LLVMContext &C) {
185   Symbol cursym = SYM_NONE;
186   int curvalue = 0;
187   Symbol nextsym = SYM_NONE;
188   int nextvalue = 0;
189   char c;
190   int loop;
191   int direction;
192
193   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
194     // Write out commands
195     switch(cursym) {
196       case SYM_NONE:
197         // Do nothing
198         break;
199
200       case SYM_READ:
201         {
202           //%tape.%d = call i32 @getchar()
203           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
204           getchar_call->setTailCall(false);
205           Value *tape_0 = getchar_call;
206
207           //%tape.%d = trunc i32 %tape.%d to i8
208           Value *tape_1 = builder->
209             CreateTrunc(tape_0, IntegerType::getInt8Ty(C), tapereg);
210
211           //store i8 %tape.%d, i8 *%head.%d
212           builder->CreateStore(tape_1, curhead);
213         }
214         break;
215
216       case SYM_WRITE:
217         {
218           //%tape.%d = load i8 *%head.%d
219           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
220
221           //%tape.%d = sext i8 %tape.%d to i32
222           Value *tape_1 = builder->
223             CreateSExt(tape_0, IntegerType::getInt32Ty(C), tapereg);
224
225           //call i32 @putchar(i32 %tape.%d)
226           Value *putchar_params[] = {
227             tape_1
228           };
229           CallInst *putchar_call = builder->
230             CreateCall(putchar_func,
231                        putchar_params, array_endof(putchar_params));
232           putchar_call->setTailCall(false);
233         }
234         break;
235
236       case SYM_MOVE:
237         {
238           //%head.%d = getelementptr i8 *%head.%d, i32 %d
239           curhead = builder->
240             CreateGEP(curhead, ConstantInt::get(C, APInt(32, curvalue)),
241                       headreg);
242
243           //Error block for array out of bounds
244           if (comflag & flag_arraybounds)
245           {
246             //%test.%d = icmp uge i8 *%head.%d, %arrmax
247             Value *test_0 = builder->
248               CreateICmpUGE(curhead, ptr_arrmax, testreg);
249
250             //%test.%d = icmp ult i8 *%head.%d, %arr
251             Value *test_1 = builder->
252               CreateICmpULT(curhead, ptr_arr, testreg);
253
254             //%test.%d = or i1 %test.%d, %test.%d
255             Value *test_2 = builder->
256               CreateOr(test_0, test_1, testreg);
257
258             //br i1 %test.%d, label %main.%d, label %main.%d
259             BasicBlock *nextbb = BasicBlock::Create(C, label, brainf_func);
260             builder->CreateCondBr(test_2, aberrorbb, nextbb);
261
262             //main.%d:
263             builder->SetInsertPoint(nextbb);
264           }
265         }
266         break;
267
268       case SYM_CHANGE:
269         {
270           //%tape.%d = load i8 *%head.%d
271           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
272
273           //%tape.%d = add i8 %tape.%d, %d
274           Value *tape_1 = builder->
275             CreateAdd(tape_0, ConstantInt::get(C, APInt(8, curvalue)), tapereg);
276
277           //store i8 %tape.%d, i8 *%head.%d\n"
278           builder->CreateStore(tape_1, curhead);
279         }
280         break;
281
282       case SYM_LOOP:
283         {
284           //br label %main.%d
285           BasicBlock *testbb = BasicBlock::Create(C, label, brainf_func);
286           builder->CreateBr(testbb);
287
288           //main.%d:
289           BasicBlock *bb_0 = builder->GetInsertBlock();
290           BasicBlock *bb_1 = BasicBlock::Create(C, label, brainf_func);
291           builder->SetInsertPoint(bb_1);
292
293           // Make part of PHI instruction now, wait until end of loop to finish
294           PHINode *phi_0 =
295             PHINode::Create(PointerType::getUnqual(IntegerType::getInt8Ty(C)),
296                             headreg, testbb);
297           phi_0->reserveOperandSpace(2);
298           phi_0->addIncoming(curhead, bb_0);
299           curhead = phi_0;
300
301           readloop(phi_0, bb_1, testbb, C);
302         }
303         break;
304
305       default:
306         std::cerr << "Error: Unknown symbol.\n";
307         abort();
308         break;
309     }
310
311     cursym = nextsym;
312     curvalue = nextvalue;
313     nextsym = SYM_NONE;
314
315     // Reading stdin loop
316     loop = (cursym == SYM_NONE)
317         || (cursym == SYM_MOVE)
318         || (cursym == SYM_CHANGE);
319     while(loop) {
320       *in>>c;
321       if (in->eof()) {
322         if (cursym == SYM_NONE) {
323           cursym = SYM_EOF;
324         } else {
325           nextsym = SYM_EOF;
326         }
327         loop = 0;
328       } else {
329         direction = 1;
330         switch(c) {
331           case '-':
332             direction = -1;
333             // Fall through
334
335           case '+':
336             if (cursym == SYM_CHANGE) {
337               curvalue += direction;
338               // loop = 1
339             } else {
340               if (cursym == SYM_NONE) {
341                 cursym = SYM_CHANGE;
342                 curvalue = direction;
343                 // loop = 1
344               } else {
345                 nextsym = SYM_CHANGE;
346                 nextvalue = direction;
347                 loop = 0;
348               }
349             }
350             break;
351
352           case '<':
353             direction = -1;
354             // Fall through
355
356           case '>':
357             if (cursym == SYM_MOVE) {
358               curvalue += direction;
359               // loop = 1
360             } else {
361               if (cursym == SYM_NONE) {
362                 cursym = SYM_MOVE;
363                 curvalue = direction;
364                 // loop = 1
365               } else {
366                 nextsym = SYM_MOVE;
367                 nextvalue = direction;
368                 loop = 0;
369               }
370             }
371             break;
372
373           case ',':
374             if (cursym == SYM_NONE) {
375               cursym = SYM_READ;
376             } else {
377               nextsym = SYM_READ;
378             }
379             loop = 0;
380             break;
381
382           case '.':
383             if (cursym == SYM_NONE) {
384               cursym = SYM_WRITE;
385             } else {
386               nextsym = SYM_WRITE;
387             }
388             loop = 0;
389             break;
390
391           case '[':
392             if (cursym == SYM_NONE) {
393               cursym = SYM_LOOP;
394             } else {
395               nextsym = SYM_LOOP;
396             }
397             loop = 0;
398             break;
399
400           case ']':
401             if (cursym == SYM_NONE) {
402               cursym = SYM_ENDLOOP;
403             } else {
404               nextsym = SYM_ENDLOOP;
405             }
406             loop = 0;
407             break;
408
409           // Ignore other characters
410           default:
411             break;
412         }
413       }
414     }
415   }
416
417   if (cursym == SYM_ENDLOOP) {
418     if (!phi) {
419       std::cerr << "Error: Extra ']'\n";
420       abort();
421     }
422
423     // Write loop test
424     {
425       //br label %main.%d
426       builder->CreateBr(testbb);
427
428       //main.%d:
429
430       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
431       //Finish phi made at beginning of loop
432       phi->addIncoming(curhead, builder->GetInsertBlock());
433       Value *head_0 = phi;
434
435       //%tape.%d = load i8 *%head.%d
436       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
437
438       //%test.%d = icmp eq i8 %tape.%d, 0
439       ICmpInst *test_0 = new ICmpInst(*testbb, ICmpInst::ICMP_EQ, tape_0,
440                                     ConstantInt::get(C, APInt(8, 0)), testreg);
441
442       //br i1 %test.%d, label %main.%d, label %main.%d
443       BasicBlock *bb_0 = BasicBlock::Create(C, label, brainf_func);
444       BranchInst::Create(bb_0, oldbb, test_0, testbb);
445
446       //main.%d:
447       builder->SetInsertPoint(bb_0);
448
449       //%head.%d = phi i8 *[%head.%d, %main.%d]
450       PHINode *phi_1 = builder->
451         CreatePHI(PointerType::getUnqual(IntegerType::getInt8Ty(C)), headreg);
452       phi_1->reserveOperandSpace(1);
453       phi_1->addIncoming(head_0, testbb);
454       curhead = phi_1;
455     }
456
457     return;
458   }
459
460   //End of the program, so go to return block
461   builder->CreateBr(endbb);
462
463   if (phi) {
464     std::cerr << "Error: Missing ']'\n";
465     abort();
466   }
467 }