32a14c4d5328bd4260009aa349fc11e75edc32ae
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
26 #include "BrainF.h"
27 #include "llvm/Constants.h"
28 #include "llvm/Intrinsics.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include <iostream>
31 using namespace llvm;
32
33 //Set the constants for naming
34 const char *BrainF::tapereg = "tape";
35 const char *BrainF::headreg = "head";
36 const char *BrainF::label   = "brainf";
37 const char *BrainF::testreg = "test";
38
39 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf) {
40   in       = in1;
41   memtotal = mem;
42   comflag  = cf;
43
44   header();
45   readloop(0, 0, 0);
46   delete builder;
47   return module;
48 }
49
50 void BrainF::header() {
51   module = new Module("BrainF");
52
53   //Function prototypes
54
55   //declare void @llvm.memset.i32(i8 *, i8, i32, i32)
56   const Type *Tys[] = { Type::Int32Ty };
57   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
58                                                     Tys, 1);
59
60   //declare i32 @getchar()
61   getchar_func = cast<Function>(module->
62     getOrInsertFunction("getchar", IntegerType::Int32Ty, NULL));
63
64   //declare i32 @putchar(i32)
65   putchar_func = cast<Function>(module->
66     getOrInsertFunction("putchar", IntegerType::Int32Ty,
67                         IntegerType::Int32Ty, NULL));
68
69
70   //Function header
71
72   //define void @brainf()
73   brainf_func = cast<Function>(module->
74     getOrInsertFunction("brainf", Type::VoidTy, NULL));
75
76   builder = new IRBuilder<>(BasicBlock::Create(label, brainf_func));
77
78   //%arr = malloc i8, i32 %d
79   ConstantInt *val_mem = ConstantInt::get(APInt(32, memtotal));
80   ptr_arr = builder->CreateMalloc(IntegerType::Int8Ty, val_mem, "arr");
81
82   //call void @llvm.memset.i32(i8 *%arr, i8 0, i32 %d, i32 1)
83   {
84     Value *memset_params[] = {
85       ptr_arr,
86       ConstantInt::get(APInt(8, 0)),
87       val_mem,
88       ConstantInt::get(APInt(32, 1))
89     };
90
91     CallInst *memset_call = builder->
92       CreateCall(memset_func, memset_params, array_endof(memset_params));
93     memset_call->setTailCall(false);
94   }
95
96   //%arrmax = getelementptr i8 *%arr, i32 %d
97   if (comflag & flag_arraybounds) {
98     ptr_arrmax = builder->
99       CreateGEP(ptr_arr, ConstantInt::get(APInt(32, memtotal)), "arrmax");
100   }
101
102   //%head.%d = getelementptr i8 *%arr, i32 %d
103   curhead = builder->CreateGEP(ptr_arr,
104                                ConstantInt::get(APInt(32, memtotal/2)),
105                                headreg);
106
107
108
109   //Function footer
110
111   //brainf.end:
112   endbb = BasicBlock::Create(label, brainf_func);
113
114   //free i8 *%arr
115   new FreeInst(ptr_arr, endbb);
116
117   //ret void
118   ReturnInst::Create(endbb);
119
120
121
122   //Error block for array out of bounds
123   if (comflag & flag_arraybounds)
124   {
125     //@aberrormsg = internal constant [%d x i8] c"\00"
126     Constant *msg_0 = ConstantArray::
127       get("Error: The head has left the tape.", true);
128
129     GlobalVariable *aberrormsg = new GlobalVariable(
130       msg_0->getType(),
131       true,
132       GlobalValue::InternalLinkage,
133       msg_0,
134       "aberrormsg",
135       module);
136
137     //declare i32 @puts(i8 *)
138     Function *puts_func = cast<Function>(module->
139       getOrInsertFunction("puts", IntegerType::Int32Ty,
140                           PointerType::getUnqual(IntegerType::Int8Ty), NULL));
141
142     //brainf.aberror:
143     aberrorbb = BasicBlock::Create(label, brainf_func);
144
145     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
146     {
147       Constant *zero_32 = Constant::getNullValue(IntegerType::Int32Ty);
148
149       Constant *gep_params[] = {
150         zero_32,
151         zero_32
152       };
153
154       Constant *msgptr = ConstantExpr::
155         getGetElementPtr(aberrormsg, gep_params,
156                          array_lengthof(gep_params));
157
158       Value *puts_params[] = {
159         msgptr
160       };
161
162       CallInst *puts_call =
163         CallInst::Create(puts_func,
164                          puts_params, array_endof(puts_params),
165                          "", aberrorbb);
166       puts_call->setTailCall(false);
167     }
168
169     //br label %brainf.end
170     BranchInst::Create(endbb, aberrorbb);
171   }
172 }
173
174 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb) {
175   Symbol cursym = SYM_NONE;
176   int curvalue = 0;
177   Symbol nextsym = SYM_NONE;
178   int nextvalue = 0;
179   char c;
180   int loop;
181   int direction;
182
183   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
184     // Write out commands
185     switch(cursym) {
186       case SYM_NONE:
187         // Do nothing
188         break;
189
190       case SYM_READ:
191         {
192           //%tape.%d = call i32 @getchar()
193           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
194           getchar_call->setTailCall(false);
195           Value *tape_0 = getchar_call;
196
197           //%tape.%d = trunc i32 %tape.%d to i8
198           Value *tape_1 = builder->
199             CreateTrunc(tape_0, IntegerType::Int8Ty, tapereg);
200
201           //store i8 %tape.%d, i8 *%head.%d
202           builder->CreateStore(tape_1, curhead);
203         }
204         break;
205
206       case SYM_WRITE:
207         {
208           //%tape.%d = load i8 *%head.%d
209           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
210
211           //%tape.%d = sext i8 %tape.%d to i32
212           Value *tape_1 = builder->
213             CreateSExt(tape_0, IntegerType::Int32Ty, tapereg);
214
215           //call i32 @putchar(i32 %tape.%d)
216           Value *putchar_params[] = {
217             tape_1
218           };
219           CallInst *putchar_call = builder->
220             CreateCall(putchar_func,
221                        putchar_params, array_endof(putchar_params));
222           putchar_call->setTailCall(false);
223         }
224         break;
225
226       case SYM_MOVE:
227         {
228           //%head.%d = getelementptr i8 *%head.%d, i32 %d
229           curhead = builder->
230             CreateGEP(curhead, ConstantInt::get(APInt(32, curvalue)),
231                       headreg);
232
233           //Error block for array out of bounds
234           if (comflag & flag_arraybounds)
235           {
236             //%test.%d = icmp uge i8 *%head.%d, %arrmax
237             Value *test_0 = builder->
238               CreateICmpUGE(curhead, ptr_arrmax, testreg);
239
240             //%test.%d = icmp ult i8 *%head.%d, %arr
241             Value *test_1 = builder->
242               CreateICmpULT(curhead, ptr_arr, testreg);
243
244             //%test.%d = or i1 %test.%d, %test.%d
245             Value *test_2 = builder->
246               CreateOr(test_0, test_1, testreg);
247
248             //br i1 %test.%d, label %main.%d, label %main.%d
249             BasicBlock *nextbb = BasicBlock::Create(label, brainf_func);
250             builder->CreateCondBr(test_2, aberrorbb, nextbb);
251
252             //main.%d:
253             builder->SetInsertPoint(nextbb);
254           }
255         }
256         break;
257
258       case SYM_CHANGE:
259         {
260           //%tape.%d = load i8 *%head.%d
261           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
262
263           //%tape.%d = add i8 %tape.%d, %d
264           Value *tape_1 = builder->
265             CreateAdd(tape_0, ConstantInt::get(APInt(8, curvalue)), tapereg);
266
267           //store i8 %tape.%d, i8 *%head.%d\n"
268           builder->CreateStore(tape_1, curhead);
269         }
270         break;
271
272       case SYM_LOOP:
273         {
274           //br label %main.%d
275           BasicBlock *testbb = BasicBlock::Create(label, brainf_func);
276           builder->CreateBr(testbb);
277
278           //main.%d:
279           BasicBlock *bb_0 = builder->GetInsertBlock();
280           BasicBlock *bb_1 = BasicBlock::Create(label, brainf_func);
281           builder->SetInsertPoint(bb_1);
282
283           // Make part of PHI instruction now, wait until end of loop to finish
284           PHINode *phi_0 =
285             PHINode::Create(PointerType::getUnqual(IntegerType::Int8Ty),
286                             headreg, testbb);
287           phi_0->reserveOperandSpace(2);
288           phi_0->addIncoming(curhead, bb_0);
289           curhead = phi_0;
290
291           readloop(phi_0, bb_1, testbb);
292         }
293         break;
294
295       default:
296         std::cerr << "Error: Unknown symbol.\n";
297         abort();
298         break;
299     }
300
301     cursym = nextsym;
302     curvalue = nextvalue;
303     nextsym = SYM_NONE;
304
305     // Reading stdin loop
306     loop = (cursym == SYM_NONE)
307         || (cursym == SYM_MOVE)
308         || (cursym == SYM_CHANGE);
309     while(loop) {
310       *in>>c;
311       if (in->eof()) {
312         if (cursym == SYM_NONE) {
313           cursym = SYM_EOF;
314         } else {
315           nextsym = SYM_EOF;
316         }
317         loop = 0;
318       } else {
319         direction = 1;
320         switch(c) {
321           case '-':
322             direction = -1;
323             // Fall through
324
325           case '+':
326             if (cursym == SYM_CHANGE) {
327               curvalue += direction;
328               // loop = 1
329             } else {
330               if (cursym == SYM_NONE) {
331                 cursym = SYM_CHANGE;
332                 curvalue = direction;
333                 // loop = 1
334               } else {
335                 nextsym = SYM_CHANGE;
336                 nextvalue = direction;
337                 loop = 0;
338               }
339             }
340             break;
341
342           case '<':
343             direction = -1;
344             // Fall through
345
346           case '>':
347             if (cursym == SYM_MOVE) {
348               curvalue += direction;
349               // loop = 1
350             } else {
351               if (cursym == SYM_NONE) {
352                 cursym = SYM_MOVE;
353                 curvalue = direction;
354                 // loop = 1
355               } else {
356                 nextsym = SYM_MOVE;
357                 nextvalue = direction;
358                 loop = 0;
359               }
360             }
361             break;
362
363           case ',':
364             if (cursym == SYM_NONE) {
365               cursym = SYM_READ;
366             } else {
367               nextsym = SYM_READ;
368             }
369             loop = 0;
370             break;
371
372           case '.':
373             if (cursym == SYM_NONE) {
374               cursym = SYM_WRITE;
375             } else {
376               nextsym = SYM_WRITE;
377             }
378             loop = 0;
379             break;
380
381           case '[':
382             if (cursym == SYM_NONE) {
383               cursym = SYM_LOOP;
384             } else {
385               nextsym = SYM_LOOP;
386             }
387             loop = 0;
388             break;
389
390           case ']':
391             if (cursym == SYM_NONE) {
392               cursym = SYM_ENDLOOP;
393             } else {
394               nextsym = SYM_ENDLOOP;
395             }
396             loop = 0;
397             break;
398
399           // Ignore other characters
400           default:
401             break;
402         }
403       }
404     }
405   }
406
407   if (cursym == SYM_ENDLOOP) {
408     if (!phi) {
409       std::cerr << "Error: Extra ']'\n";
410       abort();
411     }
412
413     // Write loop test
414     {
415       //br label %main.%d
416       builder->CreateBr(testbb);
417
418       //main.%d:
419
420       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
421       //Finish phi made at beginning of loop
422       phi->addIncoming(curhead, builder->GetInsertBlock());
423       Value *head_0 = phi;
424
425       //%tape.%d = load i8 *%head.%d
426       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
427
428       //%test.%d = icmp eq i8 %tape.%d, 0
429       ICmpInst *test_0 = new ICmpInst(ICmpInst::ICMP_EQ, tape_0,
430                                       ConstantInt::get(APInt(8, 0)), testreg,
431                                       testbb);
432
433       //br i1 %test.%d, label %main.%d, label %main.%d
434       BasicBlock *bb_0 = BasicBlock::Create(label, brainf_func);
435       BranchInst::Create(bb_0, oldbb, test_0, testbb);
436
437       //main.%d:
438       builder->SetInsertPoint(bb_0);
439
440       //%head.%d = phi i8 *[%head.%d, %main.%d]
441       PHINode *phi_1 = builder->
442         CreatePHI(PointerType::getUnqual(IntegerType::Int8Ty), headreg);
443       phi_1->reserveOperandSpace(1);
444       phi_1->addIncoming(head_0, testbb);
445       curhead = phi_1;
446     }
447
448     return;
449   }
450
451   //End of the program, so go to return block
452   builder->CreateBr(endbb);
453
454   if (phi) {
455     std::cerr << "Error: Missing ']'\n";
456     abort();
457   }
458 }