b75b51943c24525ef1db1d40d3ec9394a66bea8e
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Sterling Stein and is distributed under the
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
26 #include "BrainF.h"
27 #include "llvm/Constants.h"
28 #include "llvm/ADT/STLExtras.h"
29
30 using namespace llvm;
31
32 //Set the constants for naming
33 const char *BrainF::tapereg = "tape";
34 const char *BrainF::headreg = "head";
35 const char *BrainF::label   = "brainf";
36 const char *BrainF::testreg = "test";
37
38 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf) {
39   in       = in1;
40   memtotal = mem;
41   comflag  = cf;
42
43   header();
44   readloop(0, 0, 0);
45   delete builder;
46   return module;
47 }
48
49 void BrainF::header() {
50   module = new Module("BrainF");
51
52   //Function prototypes
53
54   //declare void @llvm.memset.i32(i8 *, i8, i32, i32)
55   Function *memset_func = cast<Function>(module->
56     getOrInsertFunction("llvm.memset.i32", Type::VoidTy,
57                         PointerType::getUnqual(IntegerType::Int8Ty),
58                         IntegerType::Int8Ty, IntegerType::Int32Ty,
59                         IntegerType::Int32Ty, NULL));
60
61   //declare i32 @getchar()
62   getchar_func = cast<Function>(module->
63     getOrInsertFunction("getchar", IntegerType::Int32Ty, NULL));
64
65   //declare i32 @putchar(i32)
66   putchar_func = cast<Function>(module->
67     getOrInsertFunction("putchar", IntegerType::Int32Ty,
68                         IntegerType::Int32Ty, NULL));
69
70
71   //Function header
72
73   //define void @brainf()
74   brainf_func = cast<Function>(module->
75     getOrInsertFunction("brainf", Type::VoidTy, NULL));
76
77   builder = new LLVMBuilder(new BasicBlock(label, brainf_func));
78
79   //%arr = malloc i8, i32 %d
80   ConstantInt *val_mem = ConstantInt::get(APInt(32, memtotal));
81   ptr_arr = builder->CreateMalloc(IntegerType::Int8Ty, val_mem, "arr");
82
83   //call void @llvm.memset.i32(i8 *%arr, i8 0, i32 %d, i32 1)
84   {
85     Value *memset_params[] = {
86       ptr_arr,
87       ConstantInt::get(APInt(8, 0)),
88       val_mem,
89       ConstantInt::get(APInt(32, 1))
90     };
91
92     CallInst *memset_call = builder->
93       CreateCall(memset_func, memset_params, array_endof(memset_params));
94     memset_call->setTailCall(false);
95   }
96
97   //%arrmax = getelementptr i8 *%arr, i32 %d
98   if (comflag & flag_arraybounds) {
99     ptr_arrmax = builder->
100       CreateGEP(ptr_arr, ConstantInt::get(APInt(32, memtotal)), "arrmax");
101   }
102
103   //%head.%d = getelementptr i8 *%arr, i32 %d
104   curhead = builder->CreateGEP(ptr_arr,
105                                ConstantInt::get(APInt(32, memtotal/2)),
106                                headreg);
107
108
109
110   //Function footer
111
112   //brainf.end:
113   endbb = new BasicBlock(label, brainf_func);
114
115   //free i8 *%arr
116   new FreeInst(ptr_arr, endbb);
117
118   //ret void
119   new ReturnInst(endbb);
120
121
122
123   //Error block for array out of bounds
124   if (comflag & flag_arraybounds)
125   {
126     //@aberrormsg = internal constant [%d x i8] c"\00"
127     Constant *msg_0 = ConstantArray::
128       get("Error: The head has left the tape.", true);
129
130     GlobalVariable *aberrormsg = new GlobalVariable(
131       msg_0->getType(),
132       true,
133       GlobalValue::InternalLinkage,
134       msg_0,
135       "aberrormsg",
136       module);
137
138     //declare i32 @puts(i8 *)
139     Function *puts_func = cast<Function>(module->
140       getOrInsertFunction("puts", IntegerType::Int32Ty,
141                           PointerType::getUnqual(IntegerType::Int8Ty), NULL));
142
143     //brainf.aberror:
144     aberrorbb = new BasicBlock(label, brainf_func);
145
146     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
147     {
148       Constant *zero_32 = Constant::getNullValue(IntegerType::Int32Ty);
149
150       Constant *gep_params[] = {
151         zero_32,
152         zero_32
153       };
154
155       Constant *msgptr = ConstantExpr::
156         getGetElementPtr(aberrormsg, gep_params,
157                          array_lengthof(gep_params));
158
159       Value *puts_params[] = {
160         msgptr
161       };
162
163       CallInst *puts_call =
164         new CallInst(puts_func,
165                      puts_params, array_endof(puts_params),
166                      "", aberrorbb);
167       puts_call->setTailCall(false);
168     }
169
170     //br label %brainf.end
171     new BranchInst(endbb, aberrorbb);
172   }
173 }
174
175 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb) {
176   Symbol cursym = SYM_NONE;
177   int curvalue = 0;
178   Symbol nextsym = SYM_NONE;
179   int nextvalue = 0;
180   char c;
181   int loop;
182   int direction;
183
184   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
185     // Write out commands
186     switch(cursym) {
187       case SYM_NONE:
188         // Do nothing
189         break;
190
191       case SYM_READ:
192         {
193           //%tape.%d = call i32 @getchar()
194           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
195           getchar_call->setTailCall(false);
196           Value *tape_0 = getchar_call;
197
198           //%tape.%d = trunc i32 %tape.%d to i8
199           TruncInst *tape_1 = builder->
200             CreateTrunc(tape_0, IntegerType::Int8Ty, tapereg);
201
202           //store i8 %tape.%d, i8 *%head.%d
203           builder->CreateStore(tape_1, curhead);
204         }
205         break;
206
207       case SYM_WRITE:
208         {
209           //%tape.%d = load i8 *%head.%d
210           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
211
212           //%tape.%d = sext i8 %tape.%d to i32
213           SExtInst *tape_1 = builder->
214             CreateSExt(tape_0, IntegerType::Int32Ty, tapereg);
215
216           //call i32 @putchar(i32 %tape.%d)
217           Value *putchar_params[] = {
218             tape_1
219           };
220           CallInst *putchar_call = builder->
221             CreateCall(putchar_func,
222                        putchar_params, array_endof(putchar_params));
223           putchar_call->setTailCall(false);
224         }
225         break;
226
227       case SYM_MOVE:
228         {
229           //%head.%d = getelementptr i8 *%head.%d, i32 %d
230           curhead = builder->
231             CreateGEP(curhead, ConstantInt::get(APInt(32, curvalue)),
232                       headreg);
233
234           //Error block for array out of bounds
235           if (comflag & flag_arraybounds)
236           {
237             //%test.%d = icmp uge i8 *%head.%d, %arrmax
238             ICmpInst *test_0 = builder->
239               CreateICmpUGE(curhead, ptr_arrmax, testreg);
240
241             //%test.%d = icmp ult i8 *%head.%d, %arr
242             ICmpInst *test_1 = builder->
243               CreateICmpULT(curhead, ptr_arr, testreg);
244
245             //%test.%d = or i1 %test.%d, %test.%d
246             BinaryOperator *test_2 = builder->
247               CreateOr(test_0, test_1, testreg);
248
249             //br i1 %test.%d, label %main.%d, label %main.%d
250             BasicBlock *nextbb = new BasicBlock(label, brainf_func);
251             builder->CreateCondBr(test_2, aberrorbb, nextbb);
252
253             //main.%d:
254             builder->SetInsertPoint(nextbb);
255           }
256         }
257         break;
258
259       case SYM_CHANGE:
260         {
261           //%tape.%d = load i8 *%head.%d
262           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
263
264           //%tape.%d = add i8 %tape.%d, %d
265           BinaryOperator *tape_1 = builder->
266             CreateAdd(tape_0, ConstantInt::get(APInt(8, curvalue)), tapereg);
267
268           //store i8 %tape.%d, i8 *%head.%d\n"
269           builder->CreateStore(tape_1, curhead);
270         }
271         break;
272
273       case SYM_LOOP:
274         {
275           //br label %main.%d
276           BasicBlock *testbb = new BasicBlock(label, brainf_func);
277           builder->CreateBr(testbb);
278
279           //main.%d:
280           BasicBlock *bb_0 = builder->GetInsertBlock();
281           BasicBlock *bb_1 = new BasicBlock(label, brainf_func);
282           builder->SetInsertPoint(bb_1);
283
284           //Make part of PHI instruction now, wait until end of loop to finish
285           PHINode *phi_0 = new PHINode(PointerType::getUnqual(IntegerType::Int8Ty),
286                                        headreg, testbb);
287           phi_0->reserveOperandSpace(2);
288           phi_0->addIncoming(curhead, bb_0);
289           curhead = phi_0;
290
291           readloop(phi_0, bb_1, testbb);
292         }
293         break;
294
295       default:
296         cerr<<"Error: Unknown symbol.\n";
297         abort();
298         break;
299     }
300
301     cursym = nextsym;
302     curvalue = nextvalue;
303     nextsym = SYM_NONE;
304
305     // Reading stdin loop
306     loop = (cursym == SYM_NONE)
307         || (cursym == SYM_MOVE)
308         || (cursym == SYM_CHANGE);
309     while(loop) {
310       *in>>c;
311       if (in->eof()) {
312         if (cursym == SYM_NONE) {
313           cursym = SYM_EOF;
314         } else {
315           nextsym = SYM_EOF;
316         }
317         loop = 0;
318       } else {
319         direction = 1;
320         switch(c) {
321           case '-':
322             direction = -1;
323             // Fall through
324
325           case '+':
326             if (cursym == SYM_CHANGE) {
327               curvalue += direction;
328               // loop = 1
329             } else {
330               if (cursym == SYM_NONE) {
331                 cursym = SYM_CHANGE;
332                 curvalue = direction;
333                 // loop = 1
334               } else {
335                 nextsym = SYM_CHANGE;
336                 nextvalue = direction;
337                 loop = 0;
338               }
339             }
340             break;
341
342           case '<':
343             direction = -1;
344             // Fall through
345
346           case '>':
347             if (cursym == SYM_MOVE) {
348               curvalue += direction;
349               // loop = 1
350             } else {
351               if (cursym == SYM_NONE) {
352                 cursym = SYM_MOVE;
353                 curvalue = direction;
354                 // loop = 1
355               } else {
356                 nextsym = SYM_MOVE;
357                 nextvalue = direction;
358                 loop = 0;
359               }
360             }
361             break;
362
363           case ',':
364             if (cursym == SYM_NONE) {
365               cursym = SYM_READ;
366             } else {
367               nextsym = SYM_READ;
368             }
369             loop = 0;
370             break;
371
372           case '.':
373             if (cursym == SYM_NONE) {
374               cursym = SYM_WRITE;
375             } else {
376               nextsym = SYM_WRITE;
377             }
378             loop = 0;
379             break;
380
381           case '[':
382             if (cursym == SYM_NONE) {
383               cursym = SYM_LOOP;
384             } else {
385               nextsym = SYM_LOOP;
386             }
387             loop = 0;
388             break;
389
390           case ']':
391             if (cursym == SYM_NONE) {
392               cursym = SYM_ENDLOOP;
393             } else {
394               nextsym = SYM_ENDLOOP;
395             }
396             loop = 0;
397             break;
398
399           // Ignore other characters
400           default:
401             break;
402         }
403       }
404     }
405   }
406
407   if (cursym == SYM_ENDLOOP) {
408     if (!phi) {
409       cerr<<"Error: Extra ']'\n";
410       abort();
411     }
412
413     // Write loop test
414     {
415       //br label %main.%d
416       builder->CreateBr(testbb);
417
418       //main.%d:
419
420       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
421       //Finish phi made at beginning of loop
422       phi->addIncoming(curhead, builder->GetInsertBlock());
423       Value *head_0 = phi;
424
425       //%tape.%d = load i8 *%head.%d
426       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
427
428       //%test.%d = icmp eq i8 %tape.%d, 0
429       ICmpInst *test_0 = new ICmpInst(ICmpInst::ICMP_EQ, tape_0,
430                                       ConstantInt::get(APInt(8, 0)), testreg,
431                                       testbb);
432
433       //br i1 %test.%d, label %main.%d, label %main.%d
434       BasicBlock *bb_0 = new BasicBlock(label, brainf_func);
435       new BranchInst(bb_0, oldbb, test_0, testbb);
436
437       //main.%d:
438       builder->SetInsertPoint(bb_0);
439
440       //%head.%d = phi i8 *[%head.%d, %main.%d]
441       PHINode *phi_1 = builder->
442         CreatePHI(PointerType::getUnqual(IntegerType::Int8Ty), headreg);
443       phi_1->reserveOperandSpace(1);
444       phi_1->addIncoming(head_0, testbb);
445       curhead = phi_1;
446     }
447
448     return;
449   }
450
451   //End of the program, so go to return block
452   builder->CreateBr(endbb);
453
454   if (phi) {
455     cerr<<"Error: Missing ']'\n";
456     abort();
457   }
458 }