Convert ConstantExpr::getGetElementPtr and
[oota-llvm.git] / examples / BrainF / BrainF.cpp
1 //===-- BrainF.cpp - BrainF compiler example ----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===--------------------------------------------------------------------===//
9 //
10 // This class compiles the BrainF language into LLVM assembly.
11 //
12 // The BrainF language has 8 commands:
13 // Command   Equivalent C    Action
14 // -------   ------------    ------
15 // ,         *h=getchar();   Read a character from stdin, 255 on EOF
16 // .         putchar(*h);    Write a character to stdout
17 // -         --*h;           Decrement tape
18 // +         ++*h;           Increment tape
19 // <         --h;            Move head left
20 // >         ++h;            Move head right
21 // [         while(*h) {     Start loop
22 // ]         }               End loop
23 //
24 //===--------------------------------------------------------------------===//
25
#include "BrainF.h"
#include "llvm/Constants.h"
#include "llvm/Instructions.h"
#include "llvm/Intrinsics.h"
#include "llvm/ADT/STLExtras.h"
#include <cstdlib>
#include <iostream>
using namespace llvm;
33
34 //Set the constants for naming
35 const char *BrainF::tapereg = "tape";
36 const char *BrainF::headreg = "head";
37 const char *BrainF::label   = "brainf";
38 const char *BrainF::testreg = "test";
39
40 Module *BrainF::parse(std::istream *in1, int mem, CompileFlags cf,
41                       LLVMContext& Context) {
42   in       = in1;
43   memtotal = mem;
44   comflag  = cf;
45
46   header(Context);
47   readloop(0, 0, 0, Context);
48   delete builder;
49   return module;
50 }
51
52 void BrainF::header(LLVMContext& C) {
53   module = new Module("BrainF", C);
54
55   //Function prototypes
56
57   //declare void @llvm.memset.p0i8.i32(i8 *, i8, i32, i32, i1)
58   Type *Tys[] = { Type::getInt8PtrTy(C), Type::getInt32Ty(C) };
59   Function *memset_func = Intrinsic::getDeclaration(module, Intrinsic::memset,
60                                                     Tys);
61
62   //declare i32 @getchar()
63   getchar_func = cast<Function>(module->
64     getOrInsertFunction("getchar", IntegerType::getInt32Ty(C), NULL));
65
66   //declare i32 @putchar(i32)
67   putchar_func = cast<Function>(module->
68     getOrInsertFunction("putchar", IntegerType::getInt32Ty(C),
69                         IntegerType::getInt32Ty(C), NULL));
70
71
72   //Function header
73
74   //define void @brainf()
75   brainf_func = cast<Function>(module->
76     getOrInsertFunction("brainf", Type::getVoidTy(C), NULL));
77
78   builder = new IRBuilder<>(BasicBlock::Create(C, label, brainf_func));
79
80   //%arr = malloc i8, i32 %d
81   ConstantInt *val_mem = ConstantInt::get(C, APInt(32, memtotal));
82   BasicBlock* BB = builder->GetInsertBlock();
83   Type* IntPtrTy = IntegerType::getInt32Ty(C);
84   Type* Int8Ty = IntegerType::getInt8Ty(C);
85   Constant* allocsize = ConstantExpr::getSizeOf(Int8Ty);
86   allocsize = ConstantExpr::getTruncOrBitCast(allocsize, IntPtrTy);
87   ptr_arr = CallInst::CreateMalloc(BB, IntPtrTy, Int8Ty, allocsize, val_mem, 
88                                    NULL, "arr");
89   BB->getInstList().push_back(cast<Instruction>(ptr_arr));
90
91   //call void @llvm.memset.p0i8.i32(i8 *%arr, i8 0, i32 %d, i32 1, i1 0)
92   {
93     Value *memset_params[] = {
94       ptr_arr,
95       ConstantInt::get(C, APInt(8, 0)),
96       val_mem,
97       ConstantInt::get(C, APInt(32, 1)),
98       ConstantInt::get(C, APInt(1, 0))
99     };
100
101     CallInst *memset_call = builder->
102       CreateCall(memset_func, memset_params);
103     memset_call->setTailCall(false);
104   }
105
106   //%arrmax = getelementptr i8 *%arr, i32 %d
107   if (comflag & flag_arraybounds) {
108     ptr_arrmax = builder->
109       CreateGEP(ptr_arr, ConstantInt::get(C, APInt(32, memtotal)), "arrmax");
110   }
111
112   //%head.%d = getelementptr i8 *%arr, i32 %d
113   curhead = builder->CreateGEP(ptr_arr,
114                                ConstantInt::get(C, APInt(32, memtotal/2)),
115                                headreg);
116
117
118
119   //Function footer
120
121   //brainf.end:
122   endbb = BasicBlock::Create(C, label, brainf_func);
123
124   //call free(i8 *%arr)
125   endbb->getInstList().push_back(CallInst::CreateFree(ptr_arr, endbb));
126
127   //ret void
128   ReturnInst::Create(C, endbb);
129
130
131
132   //Error block for array out of bounds
133   if (comflag & flag_arraybounds)
134   {
135     //@aberrormsg = internal constant [%d x i8] c"\00"
136     Constant *msg_0 =
137       ConstantArray::get(C, "Error: The head has left the tape.", true);
138
139     GlobalVariable *aberrormsg = new GlobalVariable(
140       *module,
141       msg_0->getType(),
142       true,
143       GlobalValue::InternalLinkage,
144       msg_0,
145       "aberrormsg");
146
147     //declare i32 @puts(i8 *)
148     Function *puts_func = cast<Function>(module->
149       getOrInsertFunction("puts", IntegerType::getInt32Ty(C),
150                       PointerType::getUnqual(IntegerType::getInt8Ty(C)), NULL));
151
152     //brainf.aberror:
153     aberrorbb = BasicBlock::Create(C, label, brainf_func);
154
155     //call i32 @puts(i8 *getelementptr([%d x i8] *@aberrormsg, i32 0, i32 0))
156     {
157       Constant *zero_32 = Constant::getNullValue(IntegerType::getInt32Ty(C));
158
159       Constant *gep_params[] = {
160         zero_32,
161         zero_32
162       };
163
164       Constant *msgptr = ConstantExpr::
165         getGetElementPtr(aberrormsg, gep_params);
166
167       Value *puts_params[] = {
168         msgptr
169       };
170
171       CallInst *puts_call =
172         CallInst::Create(puts_func,
173                          puts_params,
174                          "", aberrorbb);
175       puts_call->setTailCall(false);
176     }
177
178     //br label %brainf.end
179     BranchInst::Create(endbb, aberrorbb);
180   }
181 }
182
183 void BrainF::readloop(PHINode *phi, BasicBlock *oldbb, BasicBlock *testbb,
184                       LLVMContext &C) {
185   Symbol cursym = SYM_NONE;
186   int curvalue = 0;
187   Symbol nextsym = SYM_NONE;
188   int nextvalue = 0;
189   char c;
190   int loop;
191   int direction;
192
193   while(cursym != SYM_EOF && cursym != SYM_ENDLOOP) {
194     // Write out commands
195     switch(cursym) {
196       case SYM_NONE:
197         // Do nothing
198         break;
199
200       case SYM_READ:
201         {
202           //%tape.%d = call i32 @getchar()
203           CallInst *getchar_call = builder->CreateCall(getchar_func, tapereg);
204           getchar_call->setTailCall(false);
205           Value *tape_0 = getchar_call;
206
207           //%tape.%d = trunc i32 %tape.%d to i8
208           Value *tape_1 = builder->
209             CreateTrunc(tape_0, IntegerType::getInt8Ty(C), tapereg);
210
211           //store i8 %tape.%d, i8 *%head.%d
212           builder->CreateStore(tape_1, curhead);
213         }
214         break;
215
216       case SYM_WRITE:
217         {
218           //%tape.%d = load i8 *%head.%d
219           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
220
221           //%tape.%d = sext i8 %tape.%d to i32
222           Value *tape_1 = builder->
223             CreateSExt(tape_0, IntegerType::getInt32Ty(C), tapereg);
224
225           //call i32 @putchar(i32 %tape.%d)
226           Value *putchar_params[] = {
227             tape_1
228           };
229           CallInst *putchar_call = builder->
230             CreateCall(putchar_func,
231                        putchar_params);
232           putchar_call->setTailCall(false);
233         }
234         break;
235
236       case SYM_MOVE:
237         {
238           //%head.%d = getelementptr i8 *%head.%d, i32 %d
239           curhead = builder->
240             CreateGEP(curhead, ConstantInt::get(C, APInt(32, curvalue)),
241                       headreg);
242
243           //Error block for array out of bounds
244           if (comflag & flag_arraybounds)
245           {
246             //%test.%d = icmp uge i8 *%head.%d, %arrmax
247             Value *test_0 = builder->
248               CreateICmpUGE(curhead, ptr_arrmax, testreg);
249
250             //%test.%d = icmp ult i8 *%head.%d, %arr
251             Value *test_1 = builder->
252               CreateICmpULT(curhead, ptr_arr, testreg);
253
254             //%test.%d = or i1 %test.%d, %test.%d
255             Value *test_2 = builder->
256               CreateOr(test_0, test_1, testreg);
257
258             //br i1 %test.%d, label %main.%d, label %main.%d
259             BasicBlock *nextbb = BasicBlock::Create(C, label, brainf_func);
260             builder->CreateCondBr(test_2, aberrorbb, nextbb);
261
262             //main.%d:
263             builder->SetInsertPoint(nextbb);
264           }
265         }
266         break;
267
268       case SYM_CHANGE:
269         {
270           //%tape.%d = load i8 *%head.%d
271           LoadInst *tape_0 = builder->CreateLoad(curhead, tapereg);
272
273           //%tape.%d = add i8 %tape.%d, %d
274           Value *tape_1 = builder->
275             CreateAdd(tape_0, ConstantInt::get(C, APInt(8, curvalue)), tapereg);
276
277           //store i8 %tape.%d, i8 *%head.%d\n"
278           builder->CreateStore(tape_1, curhead);
279         }
280         break;
281
282       case SYM_LOOP:
283         {
284           //br label %main.%d
285           BasicBlock *testbb = BasicBlock::Create(C, label, brainf_func);
286           builder->CreateBr(testbb);
287
288           //main.%d:
289           BasicBlock *bb_0 = builder->GetInsertBlock();
290           BasicBlock *bb_1 = BasicBlock::Create(C, label, brainf_func);
291           builder->SetInsertPoint(bb_1);
292
293           // Make part of PHI instruction now, wait until end of loop to finish
294           PHINode *phi_0 =
295             PHINode::Create(PointerType::getUnqual(IntegerType::getInt8Ty(C)),
296                             2, headreg, testbb);
297           phi_0->addIncoming(curhead, bb_0);
298           curhead = phi_0;
299
300           readloop(phi_0, bb_1, testbb, C);
301         }
302         break;
303
304       default:
305         std::cerr << "Error: Unknown symbol.\n";
306         abort();
307         break;
308     }
309
310     cursym = nextsym;
311     curvalue = nextvalue;
312     nextsym = SYM_NONE;
313
314     // Reading stdin loop
315     loop = (cursym == SYM_NONE)
316         || (cursym == SYM_MOVE)
317         || (cursym == SYM_CHANGE);
318     while(loop) {
319       *in>>c;
320       if (in->eof()) {
321         if (cursym == SYM_NONE) {
322           cursym = SYM_EOF;
323         } else {
324           nextsym = SYM_EOF;
325         }
326         loop = 0;
327       } else {
328         direction = 1;
329         switch(c) {
330           case '-':
331             direction = -1;
332             // Fall through
333
334           case '+':
335             if (cursym == SYM_CHANGE) {
336               curvalue += direction;
337               // loop = 1
338             } else {
339               if (cursym == SYM_NONE) {
340                 cursym = SYM_CHANGE;
341                 curvalue = direction;
342                 // loop = 1
343               } else {
344                 nextsym = SYM_CHANGE;
345                 nextvalue = direction;
346                 loop = 0;
347               }
348             }
349             break;
350
351           case '<':
352             direction = -1;
353             // Fall through
354
355           case '>':
356             if (cursym == SYM_MOVE) {
357               curvalue += direction;
358               // loop = 1
359             } else {
360               if (cursym == SYM_NONE) {
361                 cursym = SYM_MOVE;
362                 curvalue = direction;
363                 // loop = 1
364               } else {
365                 nextsym = SYM_MOVE;
366                 nextvalue = direction;
367                 loop = 0;
368               }
369             }
370             break;
371
372           case ',':
373             if (cursym == SYM_NONE) {
374               cursym = SYM_READ;
375             } else {
376               nextsym = SYM_READ;
377             }
378             loop = 0;
379             break;
380
381           case '.':
382             if (cursym == SYM_NONE) {
383               cursym = SYM_WRITE;
384             } else {
385               nextsym = SYM_WRITE;
386             }
387             loop = 0;
388             break;
389
390           case '[':
391             if (cursym == SYM_NONE) {
392               cursym = SYM_LOOP;
393             } else {
394               nextsym = SYM_LOOP;
395             }
396             loop = 0;
397             break;
398
399           case ']':
400             if (cursym == SYM_NONE) {
401               cursym = SYM_ENDLOOP;
402             } else {
403               nextsym = SYM_ENDLOOP;
404             }
405             loop = 0;
406             break;
407
408           // Ignore other characters
409           default:
410             break;
411         }
412       }
413     }
414   }
415
416   if (cursym == SYM_ENDLOOP) {
417     if (!phi) {
418       std::cerr << "Error: Extra ']'\n";
419       abort();
420     }
421
422     // Write loop test
423     {
424       //br label %main.%d
425       builder->CreateBr(testbb);
426
427       //main.%d:
428
429       //%head.%d = phi i8 *[%head.%d, %main.%d], [%head.%d, %main.%d]
430       //Finish phi made at beginning of loop
431       phi->addIncoming(curhead, builder->GetInsertBlock());
432       Value *head_0 = phi;
433
434       //%tape.%d = load i8 *%head.%d
435       LoadInst *tape_0 = new LoadInst(head_0, tapereg, testbb);
436
437       //%test.%d = icmp eq i8 %tape.%d, 0
438       ICmpInst *test_0 = new ICmpInst(*testbb, ICmpInst::ICMP_EQ, tape_0,
439                                     ConstantInt::get(C, APInt(8, 0)), testreg);
440
441       //br i1 %test.%d, label %main.%d, label %main.%d
442       BasicBlock *bb_0 = BasicBlock::Create(C, label, brainf_func);
443       BranchInst::Create(bb_0, oldbb, test_0, testbb);
444
445       //main.%d:
446       builder->SetInsertPoint(bb_0);
447
448       //%head.%d = phi i8 *[%head.%d, %main.%d]
449       PHINode *phi_1 = builder->
450         CreatePHI(PointerType::getUnqual(IntegerType::getInt8Ty(C)), 1,
451                   headreg);
452       phi_1->addIncoming(head_0, testbb);
453       curhead = phi_1;
454     }
455
456     return;
457   }
458
459   //End of the program, so go to return block
460   builder->CreateBr(endbb);
461
462   if (phi) {
463     std::cerr << "Error: Missing ']'\n";
464     abort();
465   }
466 }