Taints upcoming store and adds bogus conditional branches elsewhere. Now as a separa...
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
26 #include "BrainF.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include <iostream>
32
33 using namespace llvm;
34
35 //Set the constants for naming
36 const char *BrainF::tapereg = "tape";
37 const char *BrainF::headreg = "head";
38 const char *BrainF::label   = "brainf";
39 const char *BrainF::testreg = "test";
40
41 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf,
42                       LLVMContext& Context) {
43   in       = in1;
44   memtotal = mem;
45   comflag  = cf;
46
47   header(Context);
48   readloop(nullptr, nullptr, nullptr, Context);
49   delete builder;
50   return module;
51 }
52
53 void BrainF::header(LLVMContext& C) {
54   module = new Module("BrainF", C);
55
56   //Function prototypes
57
58   //declare void @llvm.memset.p0i8.i32(i8 *, i8, i32, i32, i1)
59   Type *Tys[] = { Type::getInt8PtrTy(C), Type::getInt32Ty(C) };
60   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
61                                                     Tys);
62
63   //declare i32 @getchar()
64   getchar_func = cast<Function>(module->
65     getOrInsertFunction("getchar", IntegerType::getInt32Ty(C), NULL));
66
67   //declare i32 @putchar(i32)
68   putchar_func = cast<Function>(module->
69     getOrInsertFunction("putchar", IntegerType::getInt32Ty(C),
70                         IntegerType::getInt32Ty(C), NULL));
71
72   //Function header
73
74   //define void @brainf()
75   brainf_func = cast<Function>(module->
76     getOrInsertFunction("brainf", Type::getVoidTy(C), NULL));
77
78   builder = new IRBuilder<>(BasicBlock::Create(C, label, brainf_func));
79
80   //%arr = malloc i8, i32 %d
81   ConstantInt *val_mem = ConstantInt::get(C, APInt(32, memtotal));
82   BasicBlock* BB = builder->GetInsertBlock();
83   Type* IntPtrTy = IntegerType::getInt32Ty(C);
84   Type* Int8Ty = IntegerType::getInt8Ty(C);
85   Constant* allocsize = ConstantExpr::getSizeOf(Int8Ty);
86   allocsize = ConstantExpr::getTruncOrBitCast(allocsize, IntPtrTy);
87   ptr_arr = CallInst::CreateMalloc(BB, IntPtrTy, Int8Ty, allocsize, val_mem, 
88                                    nullptr, "arr");
89   BB->getInstList().push_back(cast<Instruction>(ptr_arr));
90
91   //call void @llvm.memset.p0i8.i32(i8 *%arr, i8 0, i32 %d, i32 1, i1 0)
92   {
93     Value *memset_params[] = {
94       ptr_arr,
95       ConstantInt::get(C, APInt(8, 0)),
96       val_mem,
97       ConstantInt::get(C, APInt(32, 1)),
98       ConstantInt::get(C, APInt(1, 0))
99     };
100
101     CallInst *memset_call = builder->
102       CreateCall(memset_func, memset_params);
103     memset_call->setTailCall(false);
104   }
105
106   //%arrmax = getelementptr i8 *%arr, i32 %d
107   if (comflag & flag_arraybounds) {
108     ptr_arrmax = builder->
109       CreateGEP(ptr_arr, ConstantInt::get(C, APInt(32, memtotal)), "arrmax");
110   }
111
112   //%head.%d = getelementptr i8 *%arr, i32 %d
113   curhead = builder->CreateGEP(ptr_arr,
114                                ConstantInt::get(C, APInt(32, memtotal/2)),
115                                headreg);
116
117   //Function footer
118
119   //brainf.end:
120   endbb = BasicBlock::Create(C, label, brainf_func);
121
122   //call free(i8 *%arr)
123   endbb->getInstList().push_back(CallInst::CreateFree(ptr_arr, endbb));
124
125   //ret void
126   ReturnInst::Create(C, endbb);
127
128   //Error block for array out of bounds
129   if (comflag & flag_arraybounds)
130   {
131     //@aberrormsg = internal constant [%d x i8] c"\00"
132     Constant *msg_0 =
133       ConstantDataArray::getString(C, "Error: The head has left the tape.",
134                                    true);
135
136     GlobalVariable *aberrormsg = new GlobalVariable(
137       *module,
138       msg_0->getType(),
139       true,
140       GlobalValue::InternalLinkage,
141       msg_0,
142       "aberrormsg");
143
144     //declare i32 @puts(i8 *)
145     Function *puts_func = cast<Function>(module->
146       getOrInsertFunction("puts", IntegerType::getInt32Ty(C),
147                       PointerType::getUnqual(IntegerType::getInt8Ty(C)), NULL));
148
149     //brainf.aberror:
150     aberrorbb = BasicBlock::Create(C, label, brainf_func);
151
152     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
153     {
154       Constant *zero_32 = Constant::getNullValue(IntegerType::getInt32Ty(C));
155
156       Constant *gep_params[] = {
157         zero_32,
158         zero_32
159       };
160
161       Constant *msgptr = ConstantExpr::
162         getGetElementPtr(aberrormsg->getValueType(), aberrormsg, gep_params);
163
164       Value *puts_params[] = {
165         msgptr
166       };
167
168       CallInst *puts_call =
169         CallInst::Create(puts_func,
170                          puts_params,
171                          "", aberrorbb);
172       puts_call->setTailCall(false);
173     }
174
175     //br label %brainf.end
176     BranchInst::Create(endbb, aberrorbb);
177   }
178 }
179
180 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb,
181                       LLVMContext &C) {
182   Symbol cursym = SYM_NONE;
183   int curvalue = 0;
184   Symbol nextsym = SYM_NONE;
185   int nextvalue = 0;
186   char c;
187   int loop;
188   int direction;
189
190   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
191     // Write out commands
192     switch(cursym) {
193       case SYM_NONE:
194         // Do nothing
195         break;
196
197       case SYM_READ:
198         {
199           //%tape.%d = call i32 @getchar()
200           CallInst *getchar_call =
201               builder->CreateCall(getchar_func, {}, tapereg);
202           getchar_call->setTailCall(false);
203           Value *tape_0 = getchar_call;
204
205           //%tape.%d = trunc i32 %tape.%d to i8
206           Value *tape_1 = builder->
207             CreateTrunc(tape_0, IntegerType::getInt8Ty(C), tapereg);
208
209           //store i8 %tape.%d, i8 *%head.%d
210           builder->CreateStore(tape_1, curhead);
211         }
212         break;
213
214       case SYM_WRITE:
215         {
216           //%tape.%d = load i8 *%head.%d
217           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
218
219           //%tape.%d = sext i8 %tape.%d to i32
220           Value *tape_1 = builder->
221             CreateSExt(tape_0, IntegerType::getInt32Ty(C), tapereg);
222
223           //call i32 @putchar(i32 %tape.%d)
224           Value *putchar_params[] = {
225             tape_1
226           };
227           CallInst *putchar_call = builder->
228             CreateCall(putchar_func,
229                        putchar_params);
230           putchar_call->setTailCall(false);
231         }
232         break;
233
234       case SYM_MOVE:
235         {
236           //%head.%d = getelementptr i8 *%head.%d, i32 %d
237           curhead = builder->
238             CreateGEP(curhead, ConstantInt::get(C, APInt(32, curvalue)),
239                       headreg);
240
241           //Error block for array out of bounds
242           if (comflag & flag_arraybounds)
243           {
244             //%test.%d = icmp uge i8 *%head.%d, %arrmax
245             Value *test_0 = builder->
246               CreateICmpUGE(curhead, ptr_arrmax, testreg);
247
248             //%test.%d = icmp ult i8 *%head.%d, %arr
249             Value *test_1 = builder->
250               CreateICmpULT(curhead, ptr_arr, testreg);
251
252             //%test.%d = or i1 %test.%d, %test.%d
253             Value *test_2 = builder->
254               CreateOr(test_0, test_1, testreg);
255
256             //br i1 %test.%d, label %main.%d, label %main.%d
257             BasicBlock *nextbb = BasicBlock::Create(C, label, brainf_func);
258             builder->CreateCondBr(test_2, aberrorbb, nextbb);
259
260             //main.%d:
261             builder->SetInsertPoint(nextbb);
262           }
263         }
264         break;
265
266       case SYM_CHANGE:
267         {
268           //%tape.%d = load i8 *%head.%d
269           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
270
271           //%tape.%d = add i8 %tape.%d, %d
272           Value *tape_1 = builder->
273             CreateAdd(tape_0, ConstantInt::get(C, APInt(8, curvalue)), tapereg);
274
275           //store i8 %tape.%d, i8 *%head.%d\n"
276           builder->CreateStore(tape_1, curhead);
277         }
278         break;
279
280       case SYM_LOOP:
281         {
282           //br label %main.%d
283           BasicBlock *testbb = BasicBlock::Create(C, label, brainf_func);
284           builder->CreateBr(testbb);
285
286           //main.%d:
287           BasicBlock *bb_0 = builder->GetInsertBlock();
288           BasicBlock *bb_1 = BasicBlock::Create(C, label, brainf_func);
289           builder->SetInsertPoint(bb_1);
290
291           // Make part of PHI instruction now, wait until end of loop to finish
292           PHINode *phi_0 =
293             PHINode::Create(PointerType::getUnqual(IntegerType::getInt8Ty(C)),
294                             2, headreg, testbb);
295           phi_0->addIncoming(curhead, bb_0);
296           curhead = phi_0;
297
298           readloop(phi_0, bb_1, testbb, C);
299         }
300         break;
301
302       default:
303         std::cerr << "Error: Unknown symbol.\n";
304         abort();
305         break;
306     }
307
308     cursym = nextsym;
309     curvalue = nextvalue;
310     nextsym = SYM_NONE;
311
312     // Reading stdin loop
313     loop = (cursym == SYM_NONE)
314         || (cursym == SYM_MOVE)
315         || (cursym == SYM_CHANGE);
316     while(loop) {
317       *in>>c;
318       if (in->eof()) {
319         if (cursym == SYM_NONE) {
320           cursym = SYM_EOF;
321         } else {
322           nextsym = SYM_EOF;
323         }
324         loop = 0;
325       } else {
326         direction = 1;
327         switch(c) {
328           case '-':
329             direction = -1;
330             // Fall through
331
332           case '+':
333             if (cursym == SYM_CHANGE) {
334               curvalue += direction;
335               // loop = 1
336             } else {
337               if (cursym == SYM_NONE) {
338                 cursym = SYM_CHANGE;
339                 curvalue = direction;
340                 // loop = 1
341               } else {
342                 nextsym = SYM_CHANGE;
343                 nextvalue = direction;
344                 loop = 0;
345               }
346             }
347             break;
348
349           case '<':
350             direction = -1;
351             // Fall through
352
353           case '>':
354             if (cursym == SYM_MOVE) {
355               curvalue += direction;
356               // loop = 1
357             } else {
358               if (cursym == SYM_NONE) {
359                 cursym = SYM_MOVE;
360                 curvalue = direction;
361                 // loop = 1
362               } else {
363                 nextsym = SYM_MOVE;
364                 nextvalue = direction;
365                 loop = 0;
366               }
367             }
368             break;
369
370           case ',':
371             if (cursym == SYM_NONE) {
372               cursym = SYM_READ;
373             } else {
374               nextsym = SYM_READ;
375             }
376             loop = 0;
377             break;
378
379           case '.':
380             if (cursym == SYM_NONE) {
381               cursym = SYM_WRITE;
382             } else {
383               nextsym = SYM_WRITE;
384             }
385             loop = 0;
386             break;
387
388           case '[':
389             if (cursym == SYM_NONE) {
390               cursym = SYM_LOOP;
391             } else {
392               nextsym = SYM_LOOP;
393             }
394             loop = 0;
395             break;
396
397           case ']':
398             if (cursym == SYM_NONE) {
399               cursym = SYM_ENDLOOP;
400             } else {
401               nextsym = SYM_ENDLOOP;
402             }
403             loop = 0;
404             break;
405
406           // Ignore other characters
407           default:
408             break;
409         }
410       }
411     }
412   }
413
414   if (cursym == SYM_ENDLOOP) {
415     if (!phi) {
416       std::cerr << "Error: Extra ']'\n";
417       abort();
418     }
419
420     // Write loop test
421     {
422       //br label %main.%d
423       builder->CreateBr(testbb);
424
425       //main.%d:
426
427       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
428       //Finish phi made at beginning of loop
429       phi->addIncoming(curhead, builder->GetInsertBlock());
430       Value *head_0 = phi;
431
432       //%tape.%d = load i8 *%head.%d
433       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
434
435       //%test.%d = icmp eq i8 %tape.%d, 0
436       ICmpInst *test_0 = new ICmpInst(*testbb, ICmpInst::ICMP_EQ, tape_0,
437                                     ConstantInt::get(C, APInt(8, 0)), testreg);
438
439       //br i1 %test.%d, label %main.%d, label %main.%d
440       BasicBlock *bb_0 = BasicBlock::Create(C, label, brainf_func);
441       BranchInst::Create(bb_0, oldbb, test_0, testbb);
442
443       //main.%d:
444       builder->SetInsertPoint(bb_0);
445
446       //%head.%d = phi i8 *[%head.%d, %main.%d]
447       PHINode *phi_1 = builder->
448         CreatePHI(PointerType::getUnqual(IntegerType::getInt8Ty(C)), 1,
449                   headreg);
450       phi_1->addIncoming(head_0, testbb);
451       curhead = phi_1;
452     }
453
454     return;
455   }
456
457   //End of the program, so go to return block
458   builder->CreateBr(endbb);
459
460   if (phi) {
461     std::cerr << "Error: Missing ']'\n";
462     abort();
463   }
464 }