[PM/AA] Rebuild LLVM's alias analysis infrastructure in a way compatible
[oota-llvm.git] / examples / Kaleidoscope / Chapter6 / toy.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/Analysis/Passes.h"
3 #include "llvm/IR/IRBuilder.h"
4 #include "llvm/IR/LLVMContext.h"
5 #include "llvm/IR/LegacyPassManager.h"
6 #include "llvm/IR/Module.h"
7 #include "llvm/IR/Verifier.h"
8 #include "llvm/Support/TargetSelect.h"
9 #include "llvm/Transforms/Scalar.h"
10 #include <cctype>
11 #include <cstdio>
12 #include <map>
13 #include <string>
14 #include <vector>
15 #include "../include/KaleidoscopeJIT.h"
16
17 using namespace llvm;
18 using namespace llvm::orc;
19
20 //===----------------------------------------------------------------------===//
21 // Lexer
22 //===----------------------------------------------------------------------===//
23
24 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
25 // of these for known things.
26 enum Token {
27   tok_eof = -1,
28
29   // commands
30   tok_def = -2,
31   tok_extern = -3,
32
33   // primary
34   tok_identifier = -4,
35   tok_number = -5,
36
37   // control
38   tok_if = -6,
39   tok_then = -7,
40   tok_else = -8,
41   tok_for = -9,
42   tok_in = -10,
43
44   // operators
45   tok_binary = -11,
46   tok_unary = -12
47 };
48
49 static std::string IdentifierStr; // Filled in if tok_identifier
50 static double NumVal;             // Filled in if tok_number
51
52 /// gettok - Return the next token from standard input.
53 static int gettok() {
54   static int LastChar = ' ';
55
56   // Skip any whitespace.
57   while (isspace(LastChar))
58     LastChar = getchar();
59
60   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
61     IdentifierStr = LastChar;
62     while (isalnum((LastChar = getchar())))
63       IdentifierStr += LastChar;
64
65     if (IdentifierStr == "def")
66       return tok_def;
67     if (IdentifierStr == "extern")
68       return tok_extern;
69     if (IdentifierStr == "if")
70       return tok_if;
71     if (IdentifierStr == "then")
72       return tok_then;
73     if (IdentifierStr == "else")
74       return tok_else;
75     if (IdentifierStr == "for")
76       return tok_for;
77     if (IdentifierStr == "in")
78       return tok_in;
79     if (IdentifierStr == "binary")
80       return tok_binary;
81     if (IdentifierStr == "unary")
82       return tok_unary;
83     return tok_identifier;
84   }
85
86   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
87     std::string NumStr;
88     do {
89       NumStr += LastChar;
90       LastChar = getchar();
91     } while (isdigit(LastChar) || LastChar == '.');
92
93     NumVal = strtod(NumStr.c_str(), 0);
94     return tok_number;
95   }
96
97   if (LastChar == '#') {
98     // Comment until end of line.
99     do
100       LastChar = getchar();
101     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
102
103     if (LastChar != EOF)
104       return gettok();
105   }
106
107   // Check for end of file.  Don't eat the EOF.
108   if (LastChar == EOF)
109     return tok_eof;
110
111   // Otherwise, just return the character as its ascii value.
112   int ThisChar = LastChar;
113   LastChar = getchar();
114   return ThisChar;
115 }
116
117 //===----------------------------------------------------------------------===//
118 // Abstract Syntax Tree (aka Parse Tree)
119 //===----------------------------------------------------------------------===//
120 namespace {
121 /// ExprAST - Base class for all expression nodes.
122 class ExprAST {
123 public:
124   virtual ~ExprAST() {}
125   virtual Value *codegen() = 0;
126 };
127
128 /// NumberExprAST - Expression class for numeric literals like "1.0".
129 class NumberExprAST : public ExprAST {
130   double Val;
131
132 public:
133   NumberExprAST(double Val) : Val(Val) {}
134   Value *codegen() override;
135 };
136
137 /// VariableExprAST - Expression class for referencing a variable, like "a".
138 class VariableExprAST : public ExprAST {
139   std::string Name;
140
141 public:
142   VariableExprAST(const std::string &Name) : Name(Name) {}
143   Value *codegen() override;
144 };
145
146 /// UnaryExprAST - Expression class for a unary operator.
147 class UnaryExprAST : public ExprAST {
148   char Opcode;
149   std::unique_ptr<ExprAST> Operand;
150
151 public:
152   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
153       : Opcode(Opcode), Operand(std::move(Operand)) {}
154   Value *codegen() override;
155 };
156
157 /// BinaryExprAST - Expression class for a binary operator.
158 class BinaryExprAST : public ExprAST {
159   char Op;
160   std::unique_ptr<ExprAST> LHS, RHS;
161
162 public:
163   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
164                 std::unique_ptr<ExprAST> RHS)
165       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
166   Value *codegen() override;
167 };
168
169 /// CallExprAST - Expression class for function calls.
170 class CallExprAST : public ExprAST {
171   std::string Callee;
172   std::vector<std::unique_ptr<ExprAST>> Args;
173
174 public:
175   CallExprAST(const std::string &Callee,
176               std::vector<std::unique_ptr<ExprAST>> Args)
177       : Callee(Callee), Args(std::move(Args)) {}
178   Value *codegen() override;
179 };
180
181 /// IfExprAST - Expression class for if/then/else.
182 class IfExprAST : public ExprAST {
183   std::unique_ptr<ExprAST> Cond, Then, Else;
184
185 public:
186   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
187             std::unique_ptr<ExprAST> Else)
188       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
189   Value *codegen() override;
190 };
191
192 /// ForExprAST - Expression class for for/in.
193 class ForExprAST : public ExprAST {
194   std::string VarName;
195   std::unique_ptr<ExprAST> Start, End, Step, Body;
196
197 public:
198   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
199              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
200              std::unique_ptr<ExprAST> Body)
201       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
202         Step(std::move(Step)), Body(std::move(Body)) {}
203   Value *codegen() override;
204 };
205
206 /// PrototypeAST - This class represents the "prototype" for a function,
207 /// which captures its name, and its argument names (thus implicitly the number
208 /// of arguments the function takes), as well as if it is an operator.
209 class PrototypeAST {
210   std::string Name;
211   std::vector<std::string> Args;
212   bool IsOperator;
213   unsigned Precedence; // Precedence if a binary op.
214
215 public:
216   PrototypeAST(const std::string &Name, std::vector<std::string> Args,
217                bool IsOperator = false, unsigned Prec = 0)
218       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
219         Precedence(Prec) {}
220   Function *codegen();
221   const std::string &getName() const { return Name; }
222
223   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
224   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
225
226   char getOperatorName() const {
227     assert(isUnaryOp() || isBinaryOp());
228     return Name[Name.size() - 1];
229   }
230
231   unsigned getBinaryPrecedence() const { return Precedence; }
232 };
233
234 /// FunctionAST - This class represents a function definition itself.
235 class FunctionAST {
236   std::unique_ptr<PrototypeAST> Proto;
237   std::unique_ptr<ExprAST> Body;
238
239 public:
240   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
241               std::unique_ptr<ExprAST> Body)
242       : Proto(std::move(Proto)), Body(std::move(Body)) {}
243   Function *codegen();
244 };
245 } // end anonymous namespace
246
247 //===----------------------------------------------------------------------===//
248 // Parser
249 //===----------------------------------------------------------------------===//
250
251 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
252 /// token the parser is looking at.  getNextToken reads another token from the
253 /// lexer and updates CurTok with its results.
254 static int CurTok;
255 static int getNextToken() { return CurTok = gettok(); }
256
257 /// BinopPrecedence - This holds the precedence for each binary operator that is
258 /// defined.
259 static std::map<char, int> BinopPrecedence;
260
261 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
262 static int GetTokPrecedence() {
263   if (!isascii(CurTok))
264     return -1;
265
266   // Make sure it's a declared binop.
267   int TokPrec = BinopPrecedence[CurTok];
268   if (TokPrec <= 0)
269     return -1;
270   return TokPrec;
271 }
272
273 /// Error* - These are little helper functions for error handling.
274 std::unique_ptr<ExprAST> Error(const char *Str) {
275   fprintf(stderr, "Error: %s\n", Str);
276   return nullptr;
277 }
278 std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
279   Error(Str);
280   return nullptr;
281 }
282
283 static std::unique_ptr<ExprAST> ParseExpression();
284
285 /// numberexpr ::= number
286 static std::unique_ptr<ExprAST> ParseNumberExpr() {
287   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
288   getNextToken(); // consume the number
289   return std::move(Result);
290 }
291
292 /// parenexpr ::= '(' expression ')'
293 static std::unique_ptr<ExprAST> ParseParenExpr() {
294   getNextToken(); // eat (.
295   auto V = ParseExpression();
296   if (!V)
297     return nullptr;
298
299   if (CurTok != ')')
300     return Error("expected ')'");
301   getNextToken(); // eat ).
302   return V;
303 }
304
305 /// identifierexpr
306 ///   ::= identifier
307 ///   ::= identifier '(' expression* ')'
308 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
309   std::string IdName = IdentifierStr;
310
311   getNextToken(); // eat identifier.
312
313   if (CurTok != '(') // Simple variable ref.
314     return llvm::make_unique<VariableExprAST>(IdName);
315
316   // Call.
317   getNextToken(); // eat (
318   std::vector<std::unique_ptr<ExprAST>> Args;
319   if (CurTok != ')') {
320     while (1) {
321       if (auto Arg = ParseExpression())
322         Args.push_back(std::move(Arg));
323       else
324         return nullptr;
325
326       if (CurTok == ')')
327         break;
328
329       if (CurTok != ',')
330         return Error("Expected ')' or ',' in argument list");
331       getNextToken();
332     }
333   }
334
335   // Eat the ')'.
336   getNextToken();
337
338   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
339 }
340
341 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
342 static std::unique_ptr<ExprAST> ParseIfExpr() {
343   getNextToken(); // eat the if.
344
345   // condition.
346   auto Cond = ParseExpression();
347   if (!Cond)
348     return nullptr;
349
350   if (CurTok != tok_then)
351     return Error("expected then");
352   getNextToken(); // eat the then
353
354   auto Then = ParseExpression();
355   if (!Then)
356     return nullptr;
357
358   if (CurTok != tok_else)
359     return Error("expected else");
360
361   getNextToken();
362
363   auto Else = ParseExpression();
364   if (!Else)
365     return nullptr;
366
367   return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
368                                       std::move(Else));
369 }
370
371 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
372 static std::unique_ptr<ExprAST> ParseForExpr() {
373   getNextToken(); // eat the for.
374
375   if (CurTok != tok_identifier)
376     return Error("expected identifier after for");
377
378   std::string IdName = IdentifierStr;
379   getNextToken(); // eat identifier.
380
381   if (CurTok != '=')
382     return Error("expected '=' after for");
383   getNextToken(); // eat '='.
384
385   auto Start = ParseExpression();
386   if (!Start)
387     return nullptr;
388   if (CurTok != ',')
389     return Error("expected ',' after for start value");
390   getNextToken();
391
392   auto End = ParseExpression();
393   if (!End)
394     return nullptr;
395
396   // The step value is optional.
397   std::unique_ptr<ExprAST> Step;
398   if (CurTok == ',') {
399     getNextToken();
400     Step = ParseExpression();
401     if (!Step)
402       return nullptr;
403   }
404
405   if (CurTok != tok_in)
406     return Error("expected 'in' after for");
407   getNextToken(); // eat 'in'.
408
409   auto Body = ParseExpression();
410   if (!Body)
411     return nullptr;
412
413   return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
414                                        std::move(Step), std::move(Body));
415 }
416
417 /// primary
418 ///   ::= identifierexpr
419 ///   ::= numberexpr
420 ///   ::= parenexpr
421 ///   ::= ifexpr
422 ///   ::= forexpr
423 static std::unique_ptr<ExprAST> ParsePrimary() {
424   switch (CurTok) {
425   default:
426     return Error("unknown token when expecting an expression");
427   case tok_identifier:
428     return ParseIdentifierExpr();
429   case tok_number:
430     return ParseNumberExpr();
431   case '(':
432     return ParseParenExpr();
433   case tok_if:
434     return ParseIfExpr();
435   case tok_for:
436     return ParseForExpr();
437   }
438 }
439
440 /// unary
441 ///   ::= primary
442 ///   ::= '!' unary
443 static std::unique_ptr<ExprAST> ParseUnary() {
444   // If the current token is not an operator, it must be a primary expr.
445   if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
446     return ParsePrimary();
447
448   // If this is a unary operator, read it.
449   int Opc = CurTok;
450   getNextToken();
451   if (auto Operand = ParseUnary())
452     return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
453   return nullptr;
454 }
455
456 /// binoprhs
457 ///   ::= ('+' unary)*
458 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
459                                               std::unique_ptr<ExprAST> LHS) {
460   // If this is a binop, find its precedence.
461   while (1) {
462     int TokPrec = GetTokPrecedence();
463
464     // If this is a binop that binds at least as tightly as the current binop,
465     // consume it, otherwise we are done.
466     if (TokPrec < ExprPrec)
467       return LHS;
468
469     // Okay, we know this is a binop.
470     int BinOp = CurTok;
471     getNextToken(); // eat binop
472
473     // Parse the unary expression after the binary operator.
474     auto RHS = ParseUnary();
475     if (!RHS)
476       return nullptr;
477
478     // If BinOp binds less tightly with RHS than the operator after RHS, let
479     // the pending operator take RHS as its LHS.
480     int NextPrec = GetTokPrecedence();
481     if (TokPrec < NextPrec) {
482       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
483       if (!RHS)
484         return nullptr;
485     }
486
487     // Merge LHS/RHS.
488     LHS =
489         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
490   }
491 }
492
493 /// expression
494 ///   ::= unary binoprhs
495 ///
496 static std::unique_ptr<ExprAST> ParseExpression() {
497   auto LHS = ParseUnary();
498   if (!LHS)
499     return nullptr;
500
501   return ParseBinOpRHS(0, std::move(LHS));
502 }
503
504 /// prototype
505 ///   ::= id '(' id* ')'
506 ///   ::= binary LETTER number? (id, id)
507 ///   ::= unary LETTER (id)
508 static std::unique_ptr<PrototypeAST> ParsePrototype() {
509   std::string FnName;
510
511   unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
512   unsigned BinaryPrecedence = 30;
513
514   switch (CurTok) {
515   default:
516     return ErrorP("Expected function name in prototype");
517   case tok_identifier:
518     FnName = IdentifierStr;
519     Kind = 0;
520     getNextToken();
521     break;
522   case tok_unary:
523     getNextToken();
524     if (!isascii(CurTok))
525       return ErrorP("Expected unary operator");
526     FnName = "unary";
527     FnName += (char)CurTok;
528     Kind = 1;
529     getNextToken();
530     break;
531   case tok_binary:
532     getNextToken();
533     if (!isascii(CurTok))
534       return ErrorP("Expected binary operator");
535     FnName = "binary";
536     FnName += (char)CurTok;
537     Kind = 2;
538     getNextToken();
539
540     // Read the precedence if present.
541     if (CurTok == tok_number) {
542       if (NumVal < 1 || NumVal > 100)
543         return ErrorP("Invalid precedecnce: must be 1..100");
544       BinaryPrecedence = (unsigned)NumVal;
545       getNextToken();
546     }
547     break;
548   }
549
550   if (CurTok != '(')
551     return ErrorP("Expected '(' in prototype");
552
553   std::vector<std::string> ArgNames;
554   while (getNextToken() == tok_identifier)
555     ArgNames.push_back(IdentifierStr);
556   if (CurTok != ')')
557     return ErrorP("Expected ')' in prototype");
558
559   // success.
560   getNextToken(); // eat ')'.
561
562   // Verify right number of names for operator.
563   if (Kind && ArgNames.size() != Kind)
564     return ErrorP("Invalid number of operands for operator");
565
566   return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
567                                          BinaryPrecedence);
568 }
569
570 /// definition ::= 'def' prototype expression
571 static std::unique_ptr<FunctionAST> ParseDefinition() {
572   getNextToken(); // eat def.
573   auto Proto = ParsePrototype();
574   if (!Proto)
575     return nullptr;
576
577   if (auto E = ParseExpression())
578     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
579   return nullptr;
580 }
581
582 /// toplevelexpr ::= expression
583 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
584   if (auto E = ParseExpression()) {
585     // Make an anonymous proto.
586     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
587                                                  std::vector<std::string>());
588     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
589   }
590   return nullptr;
591 }
592
593 /// external ::= 'extern' prototype
594 static std::unique_ptr<PrototypeAST> ParseExtern() {
595   getNextToken(); // eat extern.
596   return ParsePrototype();
597 }
598
599 //===----------------------------------------------------------------------===//
600 // Code Generation
601 //===----------------------------------------------------------------------===//
602
603 static std::unique_ptr<Module> TheModule;
604 static IRBuilder<> Builder(getGlobalContext());
605 static std::map<std::string, Value *> NamedValues;
606 static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
607 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
608 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
609
610 Value *ErrorV(const char *Str) {
611   Error(Str);
612   return nullptr;
613 }
614
615 Function *getFunction(std::string Name) {
616   // First, see if the function has already been added to the current module.
617   if (auto *F = TheModule->getFunction(Name))
618     return F;
619
620   // If not, check whether we can codegen the declaration from some existing
621   // prototype.
622   auto FI = FunctionProtos.find(Name);
623   if (FI != FunctionProtos.end())
624     return FI->second->codegen();
625
626   // If no existing prototype exists, return null.
627   return nullptr;
628 }
629
630 Value *NumberExprAST::codegen() {
631   return ConstantFP::get(getGlobalContext(), APFloat(Val));
632 }
633
634 Value *VariableExprAST::codegen() {
635   // Look this variable up in the function.
636   Value *V = NamedValues[Name];
637   if (!V)
638     return ErrorV("Unknown variable name");
639   return V;
640 }
641
642 Value *UnaryExprAST::codegen() {
643   Value *OperandV = Operand->codegen();
644   if (!OperandV)
645     return nullptr;
646
647   Function *F = getFunction(std::string("unary") + Opcode);
648   if (!F)
649     return ErrorV("Unknown unary operator");
650
651   return Builder.CreateCall(F, OperandV, "unop");
652 }
653
654 Value *BinaryExprAST::codegen() {
655   Value *L = LHS->codegen();
656   Value *R = RHS->codegen();
657   if (!L || !R)
658     return nullptr;
659
660   switch (Op) {
661   case '+':
662     return Builder.CreateFAdd(L, R, "addtmp");
663   case '-':
664     return Builder.CreateFSub(L, R, "subtmp");
665   case '*':
666     return Builder.CreateFMul(L, R, "multmp");
667   case '<':
668     L = Builder.CreateFCmpULT(L, R, "cmptmp");
669     // Convert bool 0/1 to double 0.0 or 1.0
670     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
671                                 "booltmp");
672   default:
673     break;
674   }
675
676   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
677   // a call to it.
678   Function *F = getFunction(std::string("binary") + Op);
679   assert(F && "binary operator not found!");
680
681   Value *Ops[] = {L, R};
682   return Builder.CreateCall(F, Ops, "binop");
683 }
684
685 Value *CallExprAST::codegen() {
686   // Look up the name in the global module table.
687   Function *CalleeF = getFunction(Callee);
688   if (!CalleeF)
689     return ErrorV("Unknown function referenced");
690
691   // If argument mismatch error.
692   if (CalleeF->arg_size() != Args.size())
693     return ErrorV("Incorrect # arguments passed");
694
695   std::vector<Value *> ArgsV;
696   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
697     ArgsV.push_back(Args[i]->codegen());
698     if (!ArgsV.back())
699       return nullptr;
700   }
701
702   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
703 }
704
705 Value *IfExprAST::codegen() {
706   Value *CondV = Cond->codegen();
707   if (!CondV)
708     return nullptr;
709
710   // Convert condition to a bool by comparing equal to 0.0.
711   CondV = Builder.CreateFCmpONE(
712       CondV, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "ifcond");
713
714   Function *TheFunction = Builder.GetInsertBlock()->getParent();
715
716   // Create blocks for the then and else cases.  Insert the 'then' block at the
717   // end of the function.
718   BasicBlock *ThenBB =
719       BasicBlock::Create(getGlobalContext(), "then", TheFunction);
720   BasicBlock *ElseBB = BasicBlock::Create(getGlobalContext(), "else");
721   BasicBlock *MergeBB = BasicBlock::Create(getGlobalContext(), "ifcont");
722
723   Builder.CreateCondBr(CondV, ThenBB, ElseBB);
724
725   // Emit then value.
726   Builder.SetInsertPoint(ThenBB);
727
728   Value *ThenV = Then->codegen();
729   if (!ThenV)
730     return nullptr;
731
732   Builder.CreateBr(MergeBB);
733   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
734   ThenBB = Builder.GetInsertBlock();
735
736   // Emit else block.
737   TheFunction->getBasicBlockList().push_back(ElseBB);
738   Builder.SetInsertPoint(ElseBB);
739
740   Value *ElseV = Else->codegen();
741   if (!ElseV)
742     return nullptr;
743
744   Builder.CreateBr(MergeBB);
745   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
746   ElseBB = Builder.GetInsertBlock();
747
748   // Emit merge block.
749   TheFunction->getBasicBlockList().push_back(MergeBB);
750   Builder.SetInsertPoint(MergeBB);
751   PHINode *PN =
752       Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()), 2, "iftmp");
753
754   PN->addIncoming(ThenV, ThenBB);
755   PN->addIncoming(ElseV, ElseBB);
756   return PN;
757 }
758
759 // Output for-loop as:
760 //   ...
761 //   start = startexpr
762 //   goto loop
763 // loop:
764 //   variable = phi [start, loopheader], [nextvariable, loopend]
765 //   ...
766 //   bodyexpr
767 //   ...
768 // loopend:
769 //   step = stepexpr
770 //   nextvariable = variable + step
771 //   endcond = endexpr
772 //   br endcond, loop, endloop
773 // outloop:
774 Value *ForExprAST::codegen() {
775   // Emit the start code first, without 'variable' in scope.
776   Value *StartVal = Start->codegen();
777   if (!StartVal)
778     return nullptr;
779
780   // Make the new basic block for the loop header, inserting after current
781   // block.
782   Function *TheFunction = Builder.GetInsertBlock()->getParent();
783   BasicBlock *PreheaderBB = Builder.GetInsertBlock();
784   BasicBlock *LoopBB =
785       BasicBlock::Create(getGlobalContext(), "loop", TheFunction);
786
787   // Insert an explicit fall through from the current block to the LoopBB.
788   Builder.CreateBr(LoopBB);
789
790   // Start insertion in LoopBB.
791   Builder.SetInsertPoint(LoopBB);
792
793   // Start the PHI node with an entry for Start.
794   PHINode *Variable = Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()),
795                                         2, VarName.c_str());
796   Variable->addIncoming(StartVal, PreheaderBB);
797
798   // Within the loop, the variable is defined equal to the PHI node.  If it
799   // shadows an existing variable, we have to restore it, so save it now.
800   Value *OldVal = NamedValues[VarName];
801   NamedValues[VarName] = Variable;
802
803   // Emit the body of the loop.  This, like any other expr, can change the
804   // current BB.  Note that we ignore the value computed by the body, but don't
805   // allow an error.
806   if (!Body->codegen())
807     return nullptr;
808
809   // Emit the step value.
810   Value *StepVal = nullptr;
811   if (Step) {
812     StepVal = Step->codegen();
813     if (!StepVal)
814       return nullptr;
815   } else {
816     // If not specified, use 1.0.
817     StepVal = ConstantFP::get(getGlobalContext(), APFloat(1.0));
818   }
819
820   Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
821
822   // Compute the end condition.
823   Value *EndCond = End->codegen();
824   if (!EndCond)
825     return nullptr;
826
827   // Convert condition to a bool by comparing equal to 0.0.
828   EndCond = Builder.CreateFCmpONE(
829       EndCond, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "loopcond");
830
831   // Create the "after loop" block and insert it.
832   BasicBlock *LoopEndBB = Builder.GetInsertBlock();
833   BasicBlock *AfterBB =
834       BasicBlock::Create(getGlobalContext(), "afterloop", TheFunction);
835
836   // Insert the conditional branch into the end of LoopEndBB.
837   Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
838
839   // Any new code will be inserted in AfterBB.
840   Builder.SetInsertPoint(AfterBB);
841
842   // Add a new entry to the PHI node for the backedge.
843   Variable->addIncoming(NextVar, LoopEndBB);
844
845   // Restore the unshadowed variable.
846   if (OldVal)
847     NamedValues[VarName] = OldVal;
848   else
849     NamedValues.erase(VarName);
850
851   // for expr always returns 0.0.
852   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
853 }
854
855 Function *PrototypeAST::codegen() {
856   // Make the function type:  double(double,double) etc.
857   std::vector<Type *> Doubles(Args.size(),
858                               Type::getDoubleTy(getGlobalContext()));
859   FunctionType *FT =
860       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
861
862   Function *F =
863       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
864
865   // Set names for all arguments.
866   unsigned Idx = 0;
867   for (auto &Arg : F->args())
868     Arg.setName(Args[Idx++]);
869
870   return F;
871 }
872
873 Function *FunctionAST::codegen() {
874   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
875   // reference to it for use below.
876   auto &P = *Proto;
877   FunctionProtos[Proto->getName()] = std::move(Proto);
878   Function *TheFunction = getFunction(P.getName());
879   if (!TheFunction)
880     return nullptr;
881
882   // If this is an operator, install it.
883   if (P.isBinaryOp())
884     BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
885
886   // Create a new basic block to start insertion into.
887   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
888   Builder.SetInsertPoint(BB);
889
890   // Record the function arguments in the NamedValues map.
891   NamedValues.clear();
892   for (auto &Arg : TheFunction->args())
893     NamedValues[Arg.getName()] = &Arg;
894
895   if (Value *RetVal = Body->codegen()) {
896     // Finish off the function.
897     Builder.CreateRet(RetVal);
898
899     // Validate the generated code, checking for consistency.
900     verifyFunction(*TheFunction);
901
902     // Run the optimizer on the function.
903     TheFPM->run(*TheFunction);
904
905     return TheFunction;
906   }
907
908   // Error reading body, remove function.
909   TheFunction->eraseFromParent();
910
911   if (P.isBinaryOp())
912     BinopPrecedence.erase(Proto->getOperatorName());
913   return nullptr;
914 }
915
916 //===----------------------------------------------------------------------===//
917 // Top-Level parsing and JIT Driver
918 //===----------------------------------------------------------------------===//
919
920 static void InitializeModuleAndPassManager() {
921   // Open a new module.
922   TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
923   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
924
925   // Create a new pass manager attached to it.
926   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
927
928   // Do simple "peephole" optimizations and bit-twiddling optzns.
929   TheFPM->add(createInstructionCombiningPass());
930   // Reassociate expressions.
931   TheFPM->add(createReassociatePass());
932   // Eliminate Common SubExpressions.
933   TheFPM->add(createGVNPass());
934   // Simplify the control flow graph (deleting unreachable blocks, etc).
935   TheFPM->add(createCFGSimplificationPass());
936
937   TheFPM->doInitialization();
938 }
939
940 static void HandleDefinition() {
941   if (auto FnAST = ParseDefinition()) {
942     if (auto *FnIR = FnAST->codegen()) {
943       fprintf(stderr, "Read function definition:");
944       FnIR->dump();
945       TheJIT->addModule(std::move(TheModule));
946       InitializeModuleAndPassManager();
947     }
948   } else {
949     // Skip token for error recovery.
950     getNextToken();
951   }
952 }
953
954 static void HandleExtern() {
955   if (auto ProtoAST = ParseExtern()) {
956     if (auto *FnIR = ProtoAST->codegen()) {
957       fprintf(stderr, "Read extern: ");
958       FnIR->dump();
959       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
960     }
961   } else {
962     // Skip token for error recovery.
963     getNextToken();
964   }
965 }
966
967 static void HandleTopLevelExpression() {
968   // Evaluate a top-level expression into an anonymous function.
969   if (auto FnAST = ParseTopLevelExpr()) {
970     if (FnAST->codegen()) {
971
972       // JIT the module containing the anonymous expression, keeping a handle so
973       // we can free it later.
974       auto H = TheJIT->addModule(std::move(TheModule));
975       InitializeModuleAndPassManager();
976
977       // Search the JIT for the __anon_expr symbol.
978       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
979       assert(ExprSymbol && "Function not found");
980
981       // Get the symbol's address and cast it to the right type (takes no
982       // arguments, returns a double) so we can call it as a native function.
983       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
984       fprintf(stderr, "Evaluated to %f\n", FP());
985
986       // Delete the anonymous expression module from the JIT.
987       TheJIT->removeModule(H);
988     }
989   } else {
990     // Skip token for error recovery.
991     getNextToken();
992   }
993 }
994
995 /// top ::= definition | external | expression | ';'
996 static void MainLoop() {
997   while (1) {
998     fprintf(stderr, "ready> ");
999     switch (CurTok) {
1000     case tok_eof:
1001       return;
1002     case ';': // ignore top-level semicolons.
1003       getNextToken();
1004       break;
1005     case tok_def:
1006       HandleDefinition();
1007       break;
1008     case tok_extern:
1009       HandleExtern();
1010       break;
1011     default:
1012       HandleTopLevelExpression();
1013       break;
1014     }
1015   }
1016 }
1017
1018 //===----------------------------------------------------------------------===//
1019 // "Library" functions that can be "extern'd" from user code.
1020 //===----------------------------------------------------------------------===//
1021
1022 /// putchard - putchar that takes a double and returns 0.
1023 extern "C" double putchard(double X) {
1024   fputc((char)X, stderr);
1025   return 0;
1026 }
1027
1028 /// printd - printf that takes a double prints it as "%f\n", returning 0.
1029 extern "C" double printd(double X) {
1030   fprintf(stderr, "%f\n", X);
1031   return 0;
1032 }
1033
1034 //===----------------------------------------------------------------------===//
1035 // Main driver code.
1036 //===----------------------------------------------------------------------===//
1037
1038 int main() {
1039   InitializeNativeTarget();
1040   InitializeNativeTargetAsmPrinter();
1041   InitializeNativeTargetAsmParser();
1042
1043   // Install standard binary operators.
1044   // 1 is lowest precedence.
1045   BinopPrecedence['<'] = 10;
1046   BinopPrecedence['+'] = 20;
1047   BinopPrecedence['-'] = 20;
1048   BinopPrecedence['*'] = 40; // highest.
1049
1050   // Prime the first token.
1051   fprintf(stderr, "ready> ");
1052   getNextToken();
1053
1054   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1055
1056   InitializeModuleAndPassManager();
1057
1058   // Run the main "interpreter loop" now.
1059   MainLoop();
1060
1061   return 0;
1062 }