Switch GlobalVariable ctors to a sane API, where *either* a context or a module is...
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
26 #include "BrainF.h"
27 #include "llvm/Constants.h"
28 #include "llvm/Intrinsics.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include <iostream>
31 using namespace llvm;
32
33 //Set the constants for naming
34 const char *BrainF::tapereg = "tape";
35 const char *BrainF::headreg = "head";
36 const char *BrainF::label   = "brainf";
37 const char *BrainF::testreg = "test";
38
39 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf,
40                       LLVMContext& Context) {
41   in       = in1;
42   memtotal = mem;
43   comflag  = cf;
44
45   header(Context);
46   readloop(0, 0, 0);
47   delete builder;
48   return module;
49 }
50
51 void BrainF::header(LLVMContext& C) {
52   module = new Module("BrainF", C);
53
54   //Function prototypes
55
56   //declare void @llvm.memset.i32(i8 *, i8, i32, i32)
57   const Type *Tys[] = { Type::Int32Ty };
58   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
59                                                     Tys, 1);
60
61   //declare i32 @getchar()
62   getchar_func = cast<Function>(module->
63     getOrInsertFunction("getchar", IntegerType::Int32Ty, NULL));
64
65   //declare i32 @putchar(i32)
66   putchar_func = cast<Function>(module->
67     getOrInsertFunction("putchar", IntegerType::Int32Ty,
68                         IntegerType::Int32Ty, NULL));
69
70
71   //Function header
72
73   //define void @brainf()
74   brainf_func = cast<Function>(module->
75     getOrInsertFunction("brainf", Type::VoidTy, NULL));
76
77   builder = new IRBuilder<>(BasicBlock::Create(label, brainf_func));
78
79   //%arr = malloc i8, i32 %d
80   ConstantInt *val_mem = ConstantInt::get(APInt(32, memtotal));
81   ptr_arr = builder->CreateMalloc(IntegerType::Int8Ty, val_mem, "arr");
82
83   //call void @llvm.memset.i32(i8 *%arr, i8 0, i32 %d, i32 1)
84   {
85     Value *memset_params[] = {
86       ptr_arr,
87       ConstantInt::get(APInt(8, 0)),
88       val_mem,
89       ConstantInt::get(APInt(32, 1))
90     };
91
92     CallInst *memset_call = builder->
93       CreateCall(memset_func, memset_params, array_endof(memset_params));
94     memset_call->setTailCall(false);
95   }
96
97   //%arrmax = getelementptr i8 *%arr, i32 %d
98   if (comflag & flag_arraybounds) {
99     ptr_arrmax = builder->
100       CreateGEP(ptr_arr, ConstantInt::get(APInt(32, memtotal)), "arrmax");
101   }
102
103   //%head.%d = getelementptr i8 *%arr, i32 %d
104   curhead = builder->CreateGEP(ptr_arr,
105                                ConstantInt::get(APInt(32, memtotal/2)),
106                                headreg);
107
108
109
110   //Function footer
111
112   //brainf.end:
113   endbb = BasicBlock::Create(label, brainf_func);
114
115   //free i8 *%arr
116   new FreeInst(ptr_arr, endbb);
117
118   //ret void
119   ReturnInst::Create(endbb);
120
121
122
123   //Error block for array out of bounds
124   if (comflag & flag_arraybounds)
125   {
126     //@aberrormsg = internal constant [%d x i8] c"\00"
127     Constant *msg_0 = ConstantArray::
128       get("Error: The head has left the tape.", true);
129
130     GlobalVariable *aberrormsg = new GlobalVariable(
131       *module,
132       msg_0->getType(),
133       true,
134       GlobalValue::InternalLinkage,
135       msg_0,
136       "aberrormsg");
137
138     //declare i32 @puts(i8 *)
139     Function *puts_func = cast<Function>(module->
140       getOrInsertFunction("puts", IntegerType::Int32Ty,
141                           PointerType::getUnqual(IntegerType::Int8Ty), NULL));
142
143     //brainf.aberror:
144     aberrorbb = BasicBlock::Create(label, brainf_func);
145
146     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
147     {
148       Constant *zero_32 = Constant::getNullValue(IntegerType::Int32Ty);
149
150       Constant *gep_params[] = {
151         zero_32,
152         zero_32
153       };
154
155       Constant *msgptr = ConstantExpr::
156         getGetElementPtr(aberrormsg, gep_params,
157                          array_lengthof(gep_params));
158
159       Value *puts_params[] = {
160         msgptr
161       };
162
163       CallInst *puts_call =
164         CallInst::Create(puts_func,
165                          puts_params, array_endof(puts_params),
166                          "", aberrorbb);
167       puts_call->setTailCall(false);
168     }
169
170     //br label %brainf.end
171     BranchInst::Create(endbb, aberrorbb);
172   }
173 }
174
175 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb) {
176   Symbol cursym = SYM_NONE;
177   int curvalue = 0;
178   Symbol nextsym = SYM_NONE;
179   int nextvalue = 0;
180   char c;
181   int loop;
182   int direction;
183
184   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
185     // Write out commands
186     switch(cursym) {
187       case SYM_NONE:
188         // Do nothing
189         break;
190
191       case SYM_READ:
192         {
193           //%tape.%d = call i32 @getchar()
194           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
195           getchar_call->setTailCall(false);
196           Value *tape_0 = getchar_call;
197
198           //%tape.%d = trunc i32 %tape.%d to i8
199           Value *tape_1 = builder->
200             CreateTrunc(tape_0, IntegerType::Int8Ty, tapereg);
201
202           //store i8 %tape.%d, i8 *%head.%d
203           builder->CreateStore(tape_1, curhead);
204         }
205         break;
206
207       case SYM_WRITE:
208         {
209           //%tape.%d = load i8 *%head.%d
210           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
211
212           //%tape.%d = sext i8 %tape.%d to i32
213           Value *tape_1 = builder->
214             CreateSExt(tape_0, IntegerType::Int32Ty, tapereg);
215
216           //call i32 @putchar(i32 %tape.%d)
217           Value *putchar_params[] = {
218             tape_1
219           };
220           CallInst *putchar_call = builder->
221             CreateCall(putchar_func,
222                        putchar_params, array_endof(putchar_params));
223           putchar_call->setTailCall(false);
224         }
225         break;
226
227       case SYM_MOVE:
228         {
229           //%head.%d = getelementptr i8 *%head.%d, i32 %d
230           curhead = builder->
231             CreateGEP(curhead, ConstantInt::get(APInt(32, curvalue)),
232                       headreg);
233
234           //Error block for array out of bounds
235           if (comflag & flag_arraybounds)
236           {
237             //%test.%d = icmp uge i8 *%head.%d, %arrmax
238             Value *test_0 = builder->
239               CreateICmpUGE(curhead, ptr_arrmax, testreg);
240
241             //%test.%d = icmp ult i8 *%head.%d, %arr
242             Value *test_1 = builder->
243               CreateICmpULT(curhead, ptr_arr, testreg);
244
245             //%test.%d = or i1 %test.%d, %test.%d
246             Value *test_2 = builder->
247               CreateOr(test_0, test_1, testreg);
248
249             //br i1 %test.%d, label %main.%d, label %main.%d
250             BasicBlock *nextbb = BasicBlock::Create(label, brainf_func);
251             builder->CreateCondBr(test_2, aberrorbb, nextbb);
252
253             //main.%d:
254             builder->SetInsertPoint(nextbb);
255           }
256         }
257         break;
258
259       case SYM_CHANGE:
260         {
261           //%tape.%d = load i8 *%head.%d
262           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
263
264           //%tape.%d = add i8 %tape.%d, %d
265           Value *tape_1 = builder->
266             CreateAdd(tape_0, ConstantInt::get(APInt(8, curvalue)), tapereg);
267
268           //store i8 %tape.%d, i8 *%head.%d\n"
269           builder->CreateStore(tape_1, curhead);
270         }
271         break;
272
273       case SYM_LOOP:
274         {
275           //br label %main.%d
276           BasicBlock *testbb = BasicBlock::Create(label, brainf_func);
277           builder->CreateBr(testbb);
278
279           //main.%d:
280           BasicBlock *bb_0 = builder->GetInsertBlock();
281           BasicBlock *bb_1 = BasicBlock::Create(label, brainf_func);
282           builder->SetInsertPoint(bb_1);
283
284           // Make part of PHI instruction now, wait until end of loop to finish
285           PHINode *phi_0 =
286             PHINode::Create(PointerType::getUnqual(IntegerType::Int8Ty),
287                             headreg, testbb);
288           phi_0->reserveOperandSpace(2);
289           phi_0->addIncoming(curhead, bb_0);
290           curhead = phi_0;
291
292           readloop(phi_0, bb_1, testbb);
293         }
294         break;
295
296       default:
297         std::cerr << "Error: Unknown symbol.\n";
298         abort();
299         break;
300     }
301
302     cursym = nextsym;
303     curvalue = nextvalue;
304     nextsym = SYM_NONE;
305
306     // Reading stdin loop
307     loop = (cursym == SYM_NONE)
308         || (cursym == SYM_MOVE)
309         || (cursym == SYM_CHANGE);
310     while(loop) {
311       *in>>c;
312       if (in->eof()) {
313         if (cursym == SYM_NONE) {
314           cursym = SYM_EOF;
315         } else {
316           nextsym = SYM_EOF;
317         }
318         loop = 0;
319       } else {
320         direction = 1;
321         switch(c) {
322           case '-':
323             direction = -1;
324             // Fall through
325
326           case '+':
327             if (cursym == SYM_CHANGE) {
328               curvalue += direction;
329               // loop = 1
330             } else {
331               if (cursym == SYM_NONE) {
332                 cursym = SYM_CHANGE;
333                 curvalue = direction;
334                 // loop = 1
335               } else {
336                 nextsym = SYM_CHANGE;
337                 nextvalue = direction;
338                 loop = 0;
339               }
340             }
341             break;
342
343           case '<':
344             direction = -1;
345             // Fall through
346
347           case '>':
348             if (cursym == SYM_MOVE) {
349               curvalue += direction;
350               // loop = 1
351             } else {
352               if (cursym == SYM_NONE) {
353                 cursym = SYM_MOVE;
354                 curvalue = direction;
355                 // loop = 1
356               } else {
357                 nextsym = SYM_MOVE;
358                 nextvalue = direction;
359                 loop = 0;
360               }
361             }
362             break;
363
364           case ',':
365             if (cursym == SYM_NONE) {
366               cursym = SYM_READ;
367             } else {
368               nextsym = SYM_READ;
369             }
370             loop = 0;
371             break;
372
373           case '.':
374             if (cursym == SYM_NONE) {
375               cursym = SYM_WRITE;
376             } else {
377               nextsym = SYM_WRITE;
378             }
379             loop = 0;
380             break;
381
382           case '[':
383             if (cursym == SYM_NONE) {
384               cursym = SYM_LOOP;
385             } else {
386               nextsym = SYM_LOOP;
387             }
388             loop = 0;
389             break;
390
391           case ']':
392             if (cursym == SYM_NONE) {
393               cursym = SYM_ENDLOOP;
394             } else {
395               nextsym = SYM_ENDLOOP;
396             }
397             loop = 0;
398             break;
399
400           // Ignore other characters
401           default:
402             break;
403         }
404       }
405     }
406   }
407
408   if (cursym == SYM_ENDLOOP) {
409     if (!phi) {
410       std::cerr << "Error: Extra ']'\n";
411       abort();
412     }
413
414     // Write loop test
415     {
416       //br label %main.%d
417       builder->CreateBr(testbb);
418
419       //main.%d:
420
421       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
422       //Finish phi made at beginning of loop
423       phi->addIncoming(curhead, builder->GetInsertBlock());
424       Value *head_0 = phi;
425
426       //%tape.%d = load i8 *%head.%d
427       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
428
429       //%test.%d = icmp eq i8 %tape.%d, 0
430       ICmpInst *test_0 = new ICmpInst(ICmpInst::ICMP_EQ, tape_0,
431                                       ConstantInt::get(APInt(8, 0)), testreg,
432                                       testbb);
433
434       //br i1 %test.%d, label %main.%d, label %main.%d
435       BasicBlock *bb_0 = BasicBlock::Create(label, brainf_func);
436       BranchInst::Create(bb_0, oldbb, test_0, testbb);
437
438       //main.%d:
439       builder->SetInsertPoint(bb_0);
440
441       //%head.%d = phi i8 *[%head.%d, %main.%d]
442       PHINode *phi_1 = builder->
443         CreatePHI(PointerType::getUnqual(IntegerType::Int8Ty), headreg);
444       phi_1->reserveOperandSpace(1);
445       phi_1->addIncoming(head_0, testbb);
446       curhead = phi_1;
447     }
448
449     return;
450   }
451
452   //End of the program, so go to return block
453   builder->CreateBr(endbb);
454
455   if (phi) {
456     std::cerr << "Error: Missing ']'\n";
457     abort();
458   }
459 }