use proper namespace qualifications
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
#include "BrainF.h"
#include "llvm/Constants.h"
#include "llvm/Intrinsics.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdlib>
#include <iostream>
using namespace llvm;
32
33 //Set the constants for naming
34 const char *BrainF::tapereg = "tape";
35 const char *BrainF::headreg = "head";
36 const char *BrainF::label   = "brainf";
37 const char *BrainF::testreg = "test";
38
39 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf) {
40   in       = in1;
41   memtotal = mem;
42   comflag  = cf;
43
44   header();
45   readloop(0, 0, 0);
46   delete builder;
47   return module;
48 }
49
50 void BrainF::header() {
51   module = new Module("BrainF");
52
53   //Function prototypes
54
55   //declare void @llvm.memset.i32(i8 *, i8, i32, i32)
56   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset_i32);
57
58   //declare i32 @getchar()
59   getchar_func = cast<Function>(module->
60     getOrInsertFunction("getchar", IntegerType::Int32Ty, NULL));
61
62   //declare i32 @putchar(i32)
63   putchar_func = cast<Function>(module->
64     getOrInsertFunction("putchar", IntegerType::Int32Ty,
65                         IntegerType::Int32Ty, NULL));
66
67
68   //Function header
69
70   //define void @brainf()
71   brainf_func = cast<Function>(module->
72     getOrInsertFunction("brainf", Type::VoidTy, NULL));
73
74   builder = new IRBuilder<>(BasicBlock::Create(label, brainf_func));
75
76   //%arr = malloc i8, i32 %d
77   ConstantInt *val_mem = ConstantInt::get(APInt(32, memtotal));
78   ptr_arr = builder->CreateMalloc(IntegerType::Int8Ty, val_mem, "arr");
79
80   //call void @llvm.memset.i32(i8 *%arr, i8 0, i32 %d, i32 1)
81   {
82     Value *memset_params[] = {
83       ptr_arr,
84       ConstantInt::get(APInt(8, 0)),
85       val_mem,
86       ConstantInt::get(APInt(32, 1))
87     };
88
89     CallInst *memset_call = builder->
90       CreateCall(memset_func, memset_params, array_endof(memset_params));
91     memset_call->setTailCall(false);
92   }
93
94   //%arrmax = getelementptr i8 *%arr, i32 %d
95   if (comflag & flag_arraybounds) {
96     ptr_arrmax = builder->
97       CreateGEP(ptr_arr, ConstantInt::get(APInt(32, memtotal)), "arrmax");
98   }
99
100   //%head.%d = getelementptr i8 *%arr, i32 %d
101   curhead = builder->CreateGEP(ptr_arr,
102                                ConstantInt::get(APInt(32, memtotal/2)),
103                                headreg);
104
105
106
107   //Function footer
108
109   //brainf.end:
110   endbb = BasicBlock::Create(label, brainf_func);
111
112   //free i8 *%arr
113   new FreeInst(ptr_arr, endbb);
114
115   //ret void
116   ReturnInst::Create(endbb);
117
118
119
120   //Error block for array out of bounds
121   if (comflag & flag_arraybounds)
122   {
123     //@aberrormsg = internal constant [%d x i8] c"\00"
124     Constant *msg_0 = ConstantArray::
125       get("Error: The head has left the tape.", true);
126
127     GlobalVariable *aberrormsg = new GlobalVariable(
128       msg_0->getType(),
129       true,
130       GlobalValue::InternalLinkage,
131       msg_0,
132       "aberrormsg",
133       module);
134
135     //declare i32 @puts(i8 *)
136     Function *puts_func = cast<Function>(module->
137       getOrInsertFunction("puts", IntegerType::Int32Ty,
138                           PointerType::getUnqual(IntegerType::Int8Ty), NULL));
139
140     //brainf.aberror:
141     aberrorbb = BasicBlock::Create(label, brainf_func);
142
143     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
144     {
145       Constant *zero_32 = Constant::getNullValue(IntegerType::Int32Ty);
146
147       Constant *gep_params[] = {
148         zero_32,
149         zero_32
150       };
151
152       Constant *msgptr = ConstantExpr::
153         getGetElementPtr(aberrormsg, gep_params,
154                          array_lengthof(gep_params));
155
156       Value *puts_params[] = {
157         msgptr
158       };
159
160       CallInst *puts_call =
161         CallInst::Create(puts_func,
162                          puts_params, array_endof(puts_params),
163                          "", aberrorbb);
164       puts_call->setTailCall(false);
165     }
166
167     //br label %brainf.end
168     BranchInst::Create(endbb, aberrorbb);
169   }
170 }
171
172 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb) {
173   Symbol cursym = SYM_NONE;
174   int curvalue = 0;
175   Symbol nextsym = SYM_NONE;
176   int nextvalue = 0;
177   char c;
178   int loop;
179   int direction;
180
181   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
182     // Write out commands
183     switch(cursym) {
184       case SYM_NONE:
185         // Do nothing
186         break;
187
188       case SYM_READ:
189         {
190           //%tape.%d = call i32 @getchar()
191           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
192           getchar_call->setTailCall(false);
193           Value *tape_0 = getchar_call;
194
195           //%tape.%d = trunc i32 %tape.%d to i8
196           Value *tape_1 = builder->
197             CreateTrunc(tape_0, IntegerType::Int8Ty, tapereg);
198
199           //store i8 %tape.%d, i8 *%head.%d
200           builder->CreateStore(tape_1, curhead);
201         }
202         break;
203
204       case SYM_WRITE:
205         {
206           //%tape.%d = load i8 *%head.%d
207           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
208
209           //%tape.%d = sext i8 %tape.%d to i32
210           Value *tape_1 = builder->
211             CreateSExt(tape_0, IntegerType::Int32Ty, tapereg);
212
213           //call i32 @putchar(i32 %tape.%d)
214           Value *putchar_params[] = {
215             tape_1
216           };
217           CallInst *putchar_call = builder->
218             CreateCall(putchar_func,
219                        putchar_params, array_endof(putchar_params));
220           putchar_call->setTailCall(false);
221         }
222         break;
223
224       case SYM_MOVE:
225         {
226           //%head.%d = getelementptr i8 *%head.%d, i32 %d
227           curhead = builder->
228             CreateGEP(curhead, ConstantInt::get(APInt(32, curvalue)),
229                       headreg);
230
231           //Error block for array out of bounds
232           if (comflag & flag_arraybounds)
233           {
234             //%test.%d = icmp uge i8 *%head.%d, %arrmax
235             Value *test_0 = builder->
236               CreateICmpUGE(curhead, ptr_arrmax, testreg);
237
238             //%test.%d = icmp ult i8 *%head.%d, %arr
239             Value *test_1 = builder->
240               CreateICmpULT(curhead, ptr_arr, testreg);
241
242             //%test.%d = or i1 %test.%d, %test.%d
243             Value *test_2 = builder->
244               CreateOr(test_0, test_1, testreg);
245
246             //br i1 %test.%d, label %main.%d, label %main.%d
247             BasicBlock *nextbb = BasicBlock::Create(label, brainf_func);
248             builder->CreateCondBr(test_2, aberrorbb, nextbb);
249
250             //main.%d:
251             builder->SetInsertPoint(nextbb);
252           }
253         }
254         break;
255
256       case SYM_CHANGE:
257         {
258           //%tape.%d = load i8 *%head.%d
259           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
260
261           //%tape.%d = add i8 %tape.%d, %d
262           Value *tape_1 = builder->
263             CreateAdd(tape_0, ConstantInt::get(APInt(8, curvalue)), tapereg);
264
265           //store i8 %tape.%d, i8 *%head.%d\n"
266           builder->CreateStore(tape_1, curhead);
267         }
268         break;
269
270       case SYM_LOOP:
271         {
272           //br label %main.%d
273           BasicBlock *testbb = BasicBlock::Create(label, brainf_func);
274           builder->CreateBr(testbb);
275
276           //main.%d:
277           BasicBlock *bb_0 = builder->GetInsertBlock();
278           BasicBlock *bb_1 = BasicBlock::Create(label, brainf_func);
279           builder->SetInsertPoint(bb_1);
280
281           // Make part of PHI instruction now, wait until end of loop to finish
282           PHINode *phi_0 =
283             PHINode::Create(PointerType::getUnqual(IntegerType::Int8Ty),
284                             headreg, testbb);
285           phi_0->reserveOperandSpace(2);
286           phi_0->addIncoming(curhead, bb_0);
287           curhead = phi_0;
288
289           readloop(phi_0, bb_1, testbb);
290         }
291         break;
292
293       default:
294         std::cerr << "Error: Unknown symbol.\n";
295         abort();
296         break;
297     }
298
299     cursym = nextsym;
300     curvalue = nextvalue;
301     nextsym = SYM_NONE;
302
303     // Reading stdin loop
304     loop = (cursym == SYM_NONE)
305         || (cursym == SYM_MOVE)
306         || (cursym == SYM_CHANGE);
307     while(loop) {
308       *in>>c;
309       if (in->eof()) {
310         if (cursym == SYM_NONE) {
311           cursym = SYM_EOF;
312         } else {
313           nextsym = SYM_EOF;
314         }
315         loop = 0;
316       } else {
317         direction = 1;
318         switch(c) {
319           case '-':
320             direction = -1;
321             // Fall through
322
323           case '+':
324             if (cursym == SYM_CHANGE) {
325               curvalue += direction;
326               // loop = 1
327             } else {
328               if (cursym == SYM_NONE) {
329                 cursym = SYM_CHANGE;
330                 curvalue = direction;
331                 // loop = 1
332               } else {
333                 nextsym = SYM_CHANGE;
334                 nextvalue = direction;
335                 loop = 0;
336               }
337             }
338             break;
339
340           case '<':
341             direction = -1;
342             // Fall through
343
344           case '>':
345             if (cursym == SYM_MOVE) {
346               curvalue += direction;
347               // loop = 1
348             } else {
349               if (cursym == SYM_NONE) {
350                 cursym = SYM_MOVE;
351                 curvalue = direction;
352                 // loop = 1
353               } else {
354                 nextsym = SYM_MOVE;
355                 nextvalue = direction;
356                 loop = 0;
357               }
358             }
359             break;
360
361           case ',':
362             if (cursym == SYM_NONE) {
363               cursym = SYM_READ;
364             } else {
365               nextsym = SYM_READ;
366             }
367             loop = 0;
368             break;
369
370           case '.':
371             if (cursym == SYM_NONE) {
372               cursym = SYM_WRITE;
373             } else {
374               nextsym = SYM_WRITE;
375             }
376             loop = 0;
377             break;
378
379           case '[':
380             if (cursym == SYM_NONE) {
381               cursym = SYM_LOOP;
382             } else {
383               nextsym = SYM_LOOP;
384             }
385             loop = 0;
386             break;
387
388           case ']':
389             if (cursym == SYM_NONE) {
390               cursym = SYM_ENDLOOP;
391             } else {
392               nextsym = SYM_ENDLOOP;
393             }
394             loop = 0;
395             break;
396
397           // Ignore other characters
398           default:
399             break;
400         }
401       }
402     }
403   }
404
405   if (cursym == SYM_ENDLOOP) {
406     if (!phi) {
407       std::cerr << "Error: Extra ']'\n";
408       abort();
409     }
410
411     // Write loop test
412     {
413       //br label %main.%d
414       builder->CreateBr(testbb);
415
416       //main.%d:
417
418       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
419       //Finish phi made at beginning of loop
420       phi->addIncoming(curhead, builder->GetInsertBlock());
421       Value *head_0 = phi;
422
423       //%tape.%d = load i8 *%head.%d
424       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
425
426       //%test.%d = icmp eq i8 %tape.%d, 0
427       ICmpInst *test_0 = new ICmpInst(ICmpInst::ICMP_EQ, tape_0,
428                                       ConstantInt::get(APInt(8, 0)), testreg,
429                                       testbb);
430
431       //br i1 %test.%d, label %main.%d, label %main.%d
432       BasicBlock *bb_0 = BasicBlock::Create(label, brainf_func);
433       BranchInst::Create(bb_0, oldbb, test_0, testbb);
434
435       //main.%d:
436       builder->SetInsertPoint(bb_0);
437
438       //%head.%d = phi i8 *[%head.%d, %main.%d]
439       PHINode *phi_1 = builder->
440         CreatePHI(PointerType::getUnqual(IntegerType::Int8Ty), headreg);
441       phi_1->reserveOperandSpace(1);
442       phi_1->addIncoming(head_0, testbb);
443       curhead = phi_1;
444     }
445
446     return;
447   }
448
449   //End of the program, so go to return block
450   builder->CreateBr(endbb);
451
452   if (phi) {
453     std::cerr << "Error: Missing ']'\n";
454     abort();
455   }
456 }