Recommit r246175 - Add Kaleidoscope regression tests, with a fix to make sure
[oota-llvm.git] / examples / Kaleidoscope / Chapter6 / toy.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/Analysis/BasicAliasAnalysis.h"
3 #include "llvm/Analysis/Passes.h"
4 #include "llvm/IR/IRBuilder.h"
5 #include "llvm/IR/LLVMContext.h"
6 #include "llvm/IR/LegacyPassManager.h"
7 #include "llvm/IR/Module.h"
8 #include "llvm/IR/Verifier.h"
9 #include "llvm/Support/TargetSelect.h"
10 #include "llvm/Transforms/Scalar.h"
11 #include <cctype>
12 #include <cstdio>
13 #include <map>
14 #include <string>
15 #include <vector>
16 #include "../include/KaleidoscopeJIT.h"
17
18 using namespace llvm;
19 using namespace llvm::orc;
20
21 //===----------------------------------------------------------------------===//
22 // Lexer
23 //===----------------------------------------------------------------------===//
24
25 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
26 // of these for known things.
27 enum Token {
28   tok_eof = -1,
29
30   // commands
31   tok_def = -2,
32   tok_extern = -3,
33
34   // primary
35   tok_identifier = -4,
36   tok_number = -5,
37
38   // control
39   tok_if = -6,
40   tok_then = -7,
41   tok_else = -8,
42   tok_for = -9,
43   tok_in = -10,
44
45   // operators
46   tok_binary = -11,
47   tok_unary = -12
48 };
49
50 static std::string IdentifierStr; // Filled in if tok_identifier
51 static double NumVal;             // Filled in if tok_number
52
53 /// gettok - Return the next token from standard input.
54 static int gettok() {
55   static int LastChar = ' ';
56
57   // Skip any whitespace.
58   while (isspace(LastChar))
59     LastChar = getchar();
60
61   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
62     IdentifierStr = LastChar;
63     while (isalnum((LastChar = getchar())))
64       IdentifierStr += LastChar;
65
66     if (IdentifierStr == "def")
67       return tok_def;
68     if (IdentifierStr == "extern")
69       return tok_extern;
70     if (IdentifierStr == "if")
71       return tok_if;
72     if (IdentifierStr == "then")
73       return tok_then;
74     if (IdentifierStr == "else")
75       return tok_else;
76     if (IdentifierStr == "for")
77       return tok_for;
78     if (IdentifierStr == "in")
79       return tok_in;
80     if (IdentifierStr == "binary")
81       return tok_binary;
82     if (IdentifierStr == "unary")
83       return tok_unary;
84     return tok_identifier;
85   }
86
87   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
88     std::string NumStr;
89     do {
90       NumStr += LastChar;
91       LastChar = getchar();
92     } while (isdigit(LastChar) || LastChar == '.');
93
94     NumVal = strtod(NumStr.c_str(), 0);
95     return tok_number;
96   }
97
98   if (LastChar == '#') {
99     // Comment until end of line.
100     do
101       LastChar = getchar();
102     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
103
104     if (LastChar != EOF)
105       return gettok();
106   }
107
108   // Check for end of file.  Don't eat the EOF.
109   if (LastChar == EOF)
110     return tok_eof;
111
112   // Otherwise, just return the character as its ascii value.
113   int ThisChar = LastChar;
114   LastChar = getchar();
115   return ThisChar;
116 }
117
118 //===----------------------------------------------------------------------===//
119 // Abstract Syntax Tree (aka Parse Tree)
120 //===----------------------------------------------------------------------===//
121 namespace {
122 /// ExprAST - Base class for all expression nodes.
123 class ExprAST {
124 public:
125   virtual ~ExprAST() {}
126   virtual Value *codegen() = 0;
127 };
128
129 /// NumberExprAST - Expression class for numeric literals like "1.0".
130 class NumberExprAST : public ExprAST {
131   double Val;
132
133 public:
134   NumberExprAST(double Val) : Val(Val) {}
135   Value *codegen() override;
136 };
137
138 /// VariableExprAST - Expression class for referencing a variable, like "a".
139 class VariableExprAST : public ExprAST {
140   std::string Name;
141
142 public:
143   VariableExprAST(const std::string &Name) : Name(Name) {}
144   Value *codegen() override;
145 };
146
147 /// UnaryExprAST - Expression class for a unary operator.
148 class UnaryExprAST : public ExprAST {
149   char Opcode;
150   std::unique_ptr<ExprAST> Operand;
151
152 public:
153   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
154       : Opcode(Opcode), Operand(std::move(Operand)) {}
155   Value *codegen() override;
156 };
157
158 /// BinaryExprAST - Expression class for a binary operator.
159 class BinaryExprAST : public ExprAST {
160   char Op;
161   std::unique_ptr<ExprAST> LHS, RHS;
162
163 public:
164   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
165                 std::unique_ptr<ExprAST> RHS)
166       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
167   Value *codegen() override;
168 };
169
170 /// CallExprAST - Expression class for function calls.
171 class CallExprAST : public ExprAST {
172   std::string Callee;
173   std::vector<std::unique_ptr<ExprAST>> Args;
174
175 public:
176   CallExprAST(const std::string &Callee,
177               std::vector<std::unique_ptr<ExprAST>> Args)
178       : Callee(Callee), Args(std::move(Args)) {}
179   Value *codegen() override;
180 };
181
182 /// IfExprAST - Expression class for if/then/else.
183 class IfExprAST : public ExprAST {
184   std::unique_ptr<ExprAST> Cond, Then, Else;
185
186 public:
187   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
188             std::unique_ptr<ExprAST> Else)
189       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
190   Value *codegen() override;
191 };
192
193 /// ForExprAST - Expression class for for/in.
194 class ForExprAST : public ExprAST {
195   std::string VarName;
196   std::unique_ptr<ExprAST> Start, End, Step, Body;
197
198 public:
199   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
200              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
201              std::unique_ptr<ExprAST> Body)
202       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
203         Step(std::move(Step)), Body(std::move(Body)) {}
204   Value *codegen() override;
205 };
206
207 /// PrototypeAST - This class represents the "prototype" for a function,
208 /// which captures its name, and its argument names (thus implicitly the number
209 /// of arguments the function takes), as well as if it is an operator.
210 class PrototypeAST {
211   std::string Name;
212   std::vector<std::string> Args;
213   bool IsOperator;
214   unsigned Precedence; // Precedence if a binary op.
215
216 public:
217   PrototypeAST(const std::string &Name, std::vector<std::string> Args,
218                bool IsOperator = false, unsigned Prec = 0)
219       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
220         Precedence(Prec) {}
221   Function *codegen();
222   const std::string &getName() const { return Name; }
223
224   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
225   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
226
227   char getOperatorName() const {
228     assert(isUnaryOp() || isBinaryOp());
229     return Name[Name.size() - 1];
230   }
231
232   unsigned getBinaryPrecedence() const { return Precedence; }
233 };
234
235 /// FunctionAST - This class represents a function definition itself.
236 class FunctionAST {
237   std::unique_ptr<PrototypeAST> Proto;
238   std::unique_ptr<ExprAST> Body;
239
240 public:
241   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
242               std::unique_ptr<ExprAST> Body)
243       : Proto(std::move(Proto)), Body(std::move(Body)) {}
244   Function *codegen();
245 };
246 } // end anonymous namespace
247
248 //===----------------------------------------------------------------------===//
249 // Parser
250 //===----------------------------------------------------------------------===//
251
252 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
253 /// token the parser is looking at.  getNextToken reads another token from the
254 /// lexer and updates CurTok with its results.
255 static int CurTok;
256 static int getNextToken() { return CurTok = gettok(); }
257
258 /// BinopPrecedence - This holds the precedence for each binary operator that is
259 /// defined.
260 static std::map<char, int> BinopPrecedence;
261
262 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
263 static int GetTokPrecedence() {
264   if (!isascii(CurTok))
265     return -1;
266
267   // Make sure it's a declared binop.
268   int TokPrec = BinopPrecedence[CurTok];
269   if (TokPrec <= 0)
270     return -1;
271   return TokPrec;
272 }
273
274 /// Error* - These are little helper functions for error handling.
275 std::unique_ptr<ExprAST> Error(const char *Str) {
276   fprintf(stderr, "Error: %s\n", Str);
277   return nullptr;
278 }
279 std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
280   Error(Str);
281   return nullptr;
282 }
283
284 static std::unique_ptr<ExprAST> ParseExpression();
285
286 /// numberexpr ::= number
287 static std::unique_ptr<ExprAST> ParseNumberExpr() {
288   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
289   getNextToken(); // consume the number
290   return std::move(Result);
291 }
292
293 /// parenexpr ::= '(' expression ')'
294 static std::unique_ptr<ExprAST> ParseParenExpr() {
295   getNextToken(); // eat (.
296   auto V = ParseExpression();
297   if (!V)
298     return nullptr;
299
300   if (CurTok != ')')
301     return Error("expected ')'");
302   getNextToken(); // eat ).
303   return V;
304 }
305
306 /// identifierexpr
307 ///   ::= identifier
308 ///   ::= identifier '(' expression* ')'
309 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
310   std::string IdName = IdentifierStr;
311
312   getNextToken(); // eat identifier.
313
314   if (CurTok != '(') // Simple variable ref.
315     return llvm::make_unique<VariableExprAST>(IdName);
316
317   // Call.
318   getNextToken(); // eat (
319   std::vector<std::unique_ptr<ExprAST>> Args;
320   if (CurTok != ')') {
321     while (1) {
322       if (auto Arg = ParseExpression())
323         Args.push_back(std::move(Arg));
324       else
325         return nullptr;
326
327       if (CurTok == ')')
328         break;
329
330       if (CurTok != ',')
331         return Error("Expected ')' or ',' in argument list");
332       getNextToken();
333     }
334   }
335
336   // Eat the ')'.
337   getNextToken();
338
339   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
340 }
341
342 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
343 static std::unique_ptr<ExprAST> ParseIfExpr() {
344   getNextToken(); // eat the if.
345
346   // condition.
347   auto Cond = ParseExpression();
348   if (!Cond)
349     return nullptr;
350
351   if (CurTok != tok_then)
352     return Error("expected then");
353   getNextToken(); // eat the then
354
355   auto Then = ParseExpression();
356   if (!Then)
357     return nullptr;
358
359   if (CurTok != tok_else)
360     return Error("expected else");
361
362   getNextToken();
363
364   auto Else = ParseExpression();
365   if (!Else)
366     return nullptr;
367
368   return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
369                                       std::move(Else));
370 }
371
372 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
373 static std::unique_ptr<ExprAST> ParseForExpr() {
374   getNextToken(); // eat the for.
375
376   if (CurTok != tok_identifier)
377     return Error("expected identifier after for");
378
379   std::string IdName = IdentifierStr;
380   getNextToken(); // eat identifier.
381
382   if (CurTok != '=')
383     return Error("expected '=' after for");
384   getNextToken(); // eat '='.
385
386   auto Start = ParseExpression();
387   if (!Start)
388     return nullptr;
389   if (CurTok != ',')
390     return Error("expected ',' after for start value");
391   getNextToken();
392
393   auto End = ParseExpression();
394   if (!End)
395     return nullptr;
396
397   // The step value is optional.
398   std::unique_ptr<ExprAST> Step;
399   if (CurTok == ',') {
400     getNextToken();
401     Step = ParseExpression();
402     if (!Step)
403       return nullptr;
404   }
405
406   if (CurTok != tok_in)
407     return Error("expected 'in' after for");
408   getNextToken(); // eat 'in'.
409
410   auto Body = ParseExpression();
411   if (!Body)
412     return nullptr;
413
414   return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
415                                        std::move(Step), std::move(Body));
416 }
417
418 /// primary
419 ///   ::= identifierexpr
420 ///   ::= numberexpr
421 ///   ::= parenexpr
422 ///   ::= ifexpr
423 ///   ::= forexpr
424 static std::unique_ptr<ExprAST> ParsePrimary() {
425   switch (CurTok) {
426   default:
427     return Error("unknown token when expecting an expression");
428   case tok_identifier:
429     return ParseIdentifierExpr();
430   case tok_number:
431     return ParseNumberExpr();
432   case '(':
433     return ParseParenExpr();
434   case tok_if:
435     return ParseIfExpr();
436   case tok_for:
437     return ParseForExpr();
438   }
439 }
440
441 /// unary
442 ///   ::= primary
443 ///   ::= '!' unary
444 static std::unique_ptr<ExprAST> ParseUnary() {
445   // If the current token is not an operator, it must be a primary expr.
446   if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
447     return ParsePrimary();
448
449   // If this is a unary operator, read it.
450   int Opc = CurTok;
451   getNextToken();
452   if (auto Operand = ParseUnary())
453     return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
454   return nullptr;
455 }
456
457 /// binoprhs
458 ///   ::= ('+' unary)*
459 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
460                                               std::unique_ptr<ExprAST> LHS) {
461   // If this is a binop, find its precedence.
462   while (1) {
463     int TokPrec = GetTokPrecedence();
464
465     // If this is a binop that binds at least as tightly as the current binop,
466     // consume it, otherwise we are done.
467     if (TokPrec < ExprPrec)
468       return LHS;
469
470     // Okay, we know this is a binop.
471     int BinOp = CurTok;
472     getNextToken(); // eat binop
473
474     // Parse the unary expression after the binary operator.
475     auto RHS = ParseUnary();
476     if (!RHS)
477       return nullptr;
478
479     // If BinOp binds less tightly with RHS than the operator after RHS, let
480     // the pending operator take RHS as its LHS.
481     int NextPrec = GetTokPrecedence();
482     if (TokPrec < NextPrec) {
483       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
484       if (!RHS)
485         return nullptr;
486     }
487
488     // Merge LHS/RHS.
489     LHS =
490         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
491   }
492 }
493
494 /// expression
495 ///   ::= unary binoprhs
496 ///
497 static std::unique_ptr<ExprAST> ParseExpression() {
498   auto LHS = ParseUnary();
499   if (!LHS)
500     return nullptr;
501
502   return ParseBinOpRHS(0, std::move(LHS));
503 }
504
505 /// prototype
506 ///   ::= id '(' id* ')'
507 ///   ::= binary LETTER number? (id, id)
508 ///   ::= unary LETTER (id)
509 static std::unique_ptr<PrototypeAST> ParsePrototype() {
510   std::string FnName;
511
512   unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
513   unsigned BinaryPrecedence = 30;
514
515   switch (CurTok) {
516   default:
517     return ErrorP("Expected function name in prototype");
518   case tok_identifier:
519     FnName = IdentifierStr;
520     Kind = 0;
521     getNextToken();
522     break;
523   case tok_unary:
524     getNextToken();
525     if (!isascii(CurTok))
526       return ErrorP("Expected unary operator");
527     FnName = "unary";
528     FnName += (char)CurTok;
529     Kind = 1;
530     getNextToken();
531     break;
532   case tok_binary:
533     getNextToken();
534     if (!isascii(CurTok))
535       return ErrorP("Expected binary operator");
536     FnName = "binary";
537     FnName += (char)CurTok;
538     Kind = 2;
539     getNextToken();
540
541     // Read the precedence if present.
542     if (CurTok == tok_number) {
543       if (NumVal < 1 || NumVal > 100)
544         return ErrorP("Invalid precedecnce: must be 1..100");
545       BinaryPrecedence = (unsigned)NumVal;
546       getNextToken();
547     }
548     break;
549   }
550
551   if (CurTok != '(')
552     return ErrorP("Expected '(' in prototype");
553
554   std::vector<std::string> ArgNames;
555   while (getNextToken() == tok_identifier)
556     ArgNames.push_back(IdentifierStr);
557   if (CurTok != ')')
558     return ErrorP("Expected ')' in prototype");
559
560   // success.
561   getNextToken(); // eat ')'.
562
563   // Verify right number of names for operator.
564   if (Kind && ArgNames.size() != Kind)
565     return ErrorP("Invalid number of operands for operator");
566
567   return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
568                                          BinaryPrecedence);
569 }
570
571 /// definition ::= 'def' prototype expression
572 static std::unique_ptr<FunctionAST> ParseDefinition() {
573   getNextToken(); // eat def.
574   auto Proto = ParsePrototype();
575   if (!Proto)
576     return nullptr;
577
578   if (auto E = ParseExpression())
579     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
580   return nullptr;
581 }
582
583 /// toplevelexpr ::= expression
584 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
585   if (auto E = ParseExpression()) {
586     // Make an anonymous proto.
587     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
588                                                  std::vector<std::string>());
589     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
590   }
591   return nullptr;
592 }
593
594 /// external ::= 'extern' prototype
595 static std::unique_ptr<PrototypeAST> ParseExtern() {
596   getNextToken(); // eat extern.
597   return ParsePrototype();
598 }
599
600 //===----------------------------------------------------------------------===//
601 // Code Generation
602 //===----------------------------------------------------------------------===//
603
604 static std::unique_ptr<Module> TheModule;
605 static IRBuilder<> Builder(getGlobalContext());
606 static std::map<std::string, Value *> NamedValues;
607 static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
608 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
609 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
610
611 Value *ErrorV(const char *Str) {
612   Error(Str);
613   return nullptr;
614 }
615
616 Function *getFunction(std::string Name) {
617   // First, see if the function has already been added to the current module.
618   if (auto *F = TheModule->getFunction(Name))
619     return F;
620
621   // If not, check whether we can codegen the declaration from some existing
622   // prototype.
623   auto FI = FunctionProtos.find(Name);
624   if (FI != FunctionProtos.end())
625     return FI->second->codegen();
626
627   // If no existing prototype exists, return null.
628   return nullptr;
629 }
630
631 Value *NumberExprAST::codegen() {
632   return ConstantFP::get(getGlobalContext(), APFloat(Val));
633 }
634
635 Value *VariableExprAST::codegen() {
636   // Look this variable up in the function.
637   Value *V = NamedValues[Name];
638   if (!V)
639     return ErrorV("Unknown variable name");
640   return V;
641 }
642
643 Value *UnaryExprAST::codegen() {
644   Value *OperandV = Operand->codegen();
645   if (!OperandV)
646     return nullptr;
647
648   Function *F = getFunction(std::string("unary") + Opcode);
649   if (!F)
650     return ErrorV("Unknown unary operator");
651
652   return Builder.CreateCall(F, OperandV, "unop");
653 }
654
655 Value *BinaryExprAST::codegen() {
656   Value *L = LHS->codegen();
657   Value *R = RHS->codegen();
658   if (!L || !R)
659     return nullptr;
660
661   switch (Op) {
662   case '+':
663     return Builder.CreateFAdd(L, R, "addtmp");
664   case '-':
665     return Builder.CreateFSub(L, R, "subtmp");
666   case '*':
667     return Builder.CreateFMul(L, R, "multmp");
668   case '<':
669     L = Builder.CreateFCmpULT(L, R, "cmptmp");
670     // Convert bool 0/1 to double 0.0 or 1.0
671     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
672                                 "booltmp");
673   default:
674     break;
675   }
676
677   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
678   // a call to it.
679   Function *F = getFunction(std::string("binary") + Op);
680   assert(F && "binary operator not found!");
681
682   Value *Ops[] = {L, R};
683   return Builder.CreateCall(F, Ops, "binop");
684 }
685
686 Value *CallExprAST::codegen() {
687   // Look up the name in the global module table.
688   Function *CalleeF = getFunction(Callee);
689   if (!CalleeF)
690     return ErrorV("Unknown function referenced");
691
692   // If argument mismatch error.
693   if (CalleeF->arg_size() != Args.size())
694     return ErrorV("Incorrect # arguments passed");
695
696   std::vector<Value *> ArgsV;
697   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
698     ArgsV.push_back(Args[i]->codegen());
699     if (!ArgsV.back())
700       return nullptr;
701   }
702
703   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
704 }
705
706 Value *IfExprAST::codegen() {
707   Value *CondV = Cond->codegen();
708   if (!CondV)
709     return nullptr;
710
711   // Convert condition to a bool by comparing equal to 0.0.
712   CondV = Builder.CreateFCmpONE(
713       CondV, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "ifcond");
714
715   Function *TheFunction = Builder.GetInsertBlock()->getParent();
716
717   // Create blocks for the then and else cases.  Insert the 'then' block at the
718   // end of the function.
719   BasicBlock *ThenBB =
720       BasicBlock::Create(getGlobalContext(), "then", TheFunction);
721   BasicBlock *ElseBB = BasicBlock::Create(getGlobalContext(), "else");
722   BasicBlock *MergeBB = BasicBlock::Create(getGlobalContext(), "ifcont");
723
724   Builder.CreateCondBr(CondV, ThenBB, ElseBB);
725
726   // Emit then value.
727   Builder.SetInsertPoint(ThenBB);
728
729   Value *ThenV = Then->codegen();
730   if (!ThenV)
731     return nullptr;
732
733   Builder.CreateBr(MergeBB);
734   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
735   ThenBB = Builder.GetInsertBlock();
736
737   // Emit else block.
738   TheFunction->getBasicBlockList().push_back(ElseBB);
739   Builder.SetInsertPoint(ElseBB);
740
741   Value *ElseV = Else->codegen();
742   if (!ElseV)
743     return nullptr;
744
745   Builder.CreateBr(MergeBB);
746   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
747   ElseBB = Builder.GetInsertBlock();
748
749   // Emit merge block.
750   TheFunction->getBasicBlockList().push_back(MergeBB);
751   Builder.SetInsertPoint(MergeBB);
752   PHINode *PN =
753       Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()), 2, "iftmp");
754
755   PN->addIncoming(ThenV, ThenBB);
756   PN->addIncoming(ElseV, ElseBB);
757   return PN;
758 }
759
760 // Output for-loop as:
761 //   ...
762 //   start = startexpr
763 //   goto loop
764 // loop:
765 //   variable = phi [start, loopheader], [nextvariable, loopend]
766 //   ...
767 //   bodyexpr
768 //   ...
769 // loopend:
770 //   step = stepexpr
771 //   nextvariable = variable + step
772 //   endcond = endexpr
773 //   br endcond, loop, endloop
774 // outloop:
775 Value *ForExprAST::codegen() {
776   // Emit the start code first, without 'variable' in scope.
777   Value *StartVal = Start->codegen();
778   if (!StartVal)
779     return nullptr;
780
781   // Make the new basic block for the loop header, inserting after current
782   // block.
783   Function *TheFunction = Builder.GetInsertBlock()->getParent();
784   BasicBlock *PreheaderBB = Builder.GetInsertBlock();
785   BasicBlock *LoopBB =
786       BasicBlock::Create(getGlobalContext(), "loop", TheFunction);
787
788   // Insert an explicit fall through from the current block to the LoopBB.
789   Builder.CreateBr(LoopBB);
790
791   // Start insertion in LoopBB.
792   Builder.SetInsertPoint(LoopBB);
793
794   // Start the PHI node with an entry for Start.
795   PHINode *Variable = Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()),
796                                         2, VarName.c_str());
797   Variable->addIncoming(StartVal, PreheaderBB);
798
799   // Within the loop, the variable is defined equal to the PHI node.  If it
800   // shadows an existing variable, we have to restore it, so save it now.
801   Value *OldVal = NamedValues[VarName];
802   NamedValues[VarName] = Variable;
803
804   // Emit the body of the loop.  This, like any other expr, can change the
805   // current BB.  Note that we ignore the value computed by the body, but don't
806   // allow an error.
807   if (!Body->codegen())
808     return nullptr;
809
810   // Emit the step value.
811   Value *StepVal = nullptr;
812   if (Step) {
813     StepVal = Step->codegen();
814     if (!StepVal)
815       return nullptr;
816   } else {
817     // If not specified, use 1.0.
818     StepVal = ConstantFP::get(getGlobalContext(), APFloat(1.0));
819   }
820
821   Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
822
823   // Compute the end condition.
824   Value *EndCond = End->codegen();
825   if (!EndCond)
826     return nullptr;
827
828   // Convert condition to a bool by comparing equal to 0.0.
829   EndCond = Builder.CreateFCmpONE(
830       EndCond, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "loopcond");
831
832   // Create the "after loop" block and insert it.
833   BasicBlock *LoopEndBB = Builder.GetInsertBlock();
834   BasicBlock *AfterBB =
835       BasicBlock::Create(getGlobalContext(), "afterloop", TheFunction);
836
837   // Insert the conditional branch into the end of LoopEndBB.
838   Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
839
840   // Any new code will be inserted in AfterBB.
841   Builder.SetInsertPoint(AfterBB);
842
843   // Add a new entry to the PHI node for the backedge.
844   Variable->addIncoming(NextVar, LoopEndBB);
845
846   // Restore the unshadowed variable.
847   if (OldVal)
848     NamedValues[VarName] = OldVal;
849   else
850     NamedValues.erase(VarName);
851
852   // for expr always returns 0.0.
853   return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
854 }
855
856 Function *PrototypeAST::codegen() {
857   // Make the function type:  double(double,double) etc.
858   std::vector<Type *> Doubles(Args.size(),
859                               Type::getDoubleTy(getGlobalContext()));
860   FunctionType *FT =
861       FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
862
863   Function *F =
864       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
865
866   // Set names for all arguments.
867   unsigned Idx = 0;
868   for (auto &Arg : F->args())
869     Arg.setName(Args[Idx++]);
870
871   return F;
872 }
873
874 Function *FunctionAST::codegen() {
875   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
876   // reference to it for use below.
877   auto &P = *Proto;
878   FunctionProtos[Proto->getName()] = std::move(Proto);
879   Function *TheFunction = getFunction(P.getName());
880   if (!TheFunction)
881     return nullptr;
882
883   // If this is an operator, install it.
884   if (P.isBinaryOp())
885     BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
886
887   // Create a new basic block to start insertion into.
888   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
889   Builder.SetInsertPoint(BB);
890
891   // Record the function arguments in the NamedValues map.
892   NamedValues.clear();
893   for (auto &Arg : TheFunction->args())
894     NamedValues[Arg.getName()] = &Arg;
895
896   if (Value *RetVal = Body->codegen()) {
897     // Finish off the function.
898     Builder.CreateRet(RetVal);
899
900     // Validate the generated code, checking for consistency.
901     verifyFunction(*TheFunction);
902
903     // Run the optimizer on the function.
904     TheFPM->run(*TheFunction);
905
906     return TheFunction;
907   }
908
909   // Error reading body, remove function.
910   TheFunction->eraseFromParent();
911
912   if (P.isBinaryOp())
913     BinopPrecedence.erase(Proto->getOperatorName());
914   return nullptr;
915 }
916
917 //===----------------------------------------------------------------------===//
918 // Top-Level parsing and JIT Driver
919 //===----------------------------------------------------------------------===//
920
921 static void InitializeModuleAndPassManager() {
922   // Open a new module.
923   TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
924   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
925
926   // Create a new pass manager attached to it.
927   TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
928
929   // Provide basic AliasAnalysis support for GVN.
930   TheFPM->add(createBasicAliasAnalysisPass());
931   // Do simple "peephole" optimizations and bit-twiddling optzns.
932   TheFPM->add(createInstructionCombiningPass());
933   // Reassociate expressions.
934   TheFPM->add(createReassociatePass());
935   // Eliminate Common SubExpressions.
936   TheFPM->add(createGVNPass());
937   // Simplify the control flow graph (deleting unreachable blocks, etc).
938   TheFPM->add(createCFGSimplificationPass());
939
940   TheFPM->doInitialization();
941 }
942
943 static void HandleDefinition() {
944   if (auto FnAST = ParseDefinition()) {
945     if (auto *FnIR = FnAST->codegen()) {
946       fprintf(stderr, "Read function definition:");
947       FnIR->dump();
948       TheJIT->addModule(std::move(TheModule));
949       InitializeModuleAndPassManager();
950     }
951   } else {
952     // Skip token for error recovery.
953     getNextToken();
954   }
955 }
956
957 static void HandleExtern() {
958   if (auto ProtoAST = ParseExtern()) {
959     if (auto *FnIR = ProtoAST->codegen()) {
960       fprintf(stderr, "Read extern: ");
961       FnIR->dump();
962       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
963     }
964   } else {
965     // Skip token for error recovery.
966     getNextToken();
967   }
968 }
969
970 static void HandleTopLevelExpression() {
971   // Evaluate a top-level expression into an anonymous function.
972   if (auto FnAST = ParseTopLevelExpr()) {
973     if (FnAST->codegen()) {
974
975       // JIT the module containing the anonymous expression, keeping a handle so
976       // we can free it later.
977       auto H = TheJIT->addModule(std::move(TheModule));
978       InitializeModuleAndPassManager();
979
980       // Search the JIT for the __anon_expr symbol.
981       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
982       assert(ExprSymbol && "Function not found");
983
984       // Get the symbol's address and cast it to the right type (takes no
985       // arguments, returns a double) so we can call it as a native function.
986       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
987       fprintf(stderr, "Evaluated to %f\n", FP());
988
989       // Delete the anonymous expression module from the JIT.
990       TheJIT->removeModule(H);
991     }
992   } else {
993     // Skip token for error recovery.
994     getNextToken();
995   }
996 }
997
998 /// top ::= definition | external | expression | ';'
999 static void MainLoop() {
1000   while (1) {
1001     fprintf(stderr, "ready> ");
1002     switch (CurTok) {
1003     case tok_eof:
1004       return;
1005     case ';': // ignore top-level semicolons.
1006       getNextToken();
1007       break;
1008     case tok_def:
1009       HandleDefinition();
1010       break;
1011     case tok_extern:
1012       HandleExtern();
1013       break;
1014     default:
1015       HandleTopLevelExpression();
1016       break;
1017     }
1018   }
1019 }
1020
1021 //===----------------------------------------------------------------------===//
1022 // "Library" functions that can be "extern'd" from user code.
1023 //===----------------------------------------------------------------------===//
1024
1025 /// putchard - putchar that takes a double and returns 0.
1026 __attribute__((used)) extern "C" double putchard(double X) {
1027   fputc((char)X, stderr);
1028   return 0;
1029 }
1030
1031 /// printd - printf that takes a double prints it as "%f\n", returning 0.
1032 __attribute__((used)) extern "C" double printd(double X) {
1033   fprintf(stderr, "%f\n", X);
1034   return 0;
1035 }
1036
1037 //===----------------------------------------------------------------------===//
1038 // Main driver code.
1039 //===----------------------------------------------------------------------===//
1040
1041 int main() {
1042   InitializeNativeTarget();
1043   InitializeNativeTargetAsmPrinter();
1044   InitializeNativeTargetAsmParser();
1045
1046   // Install standard binary operators.
1047   // 1 is lowest precedence.
1048   BinopPrecedence['<'] = 10;
1049   BinopPrecedence['+'] = 20;
1050   BinopPrecedence['-'] = 20;
1051   BinopPrecedence['*'] = 40; // highest.
1052
1053   // Prime the first token.
1054   fprintf(stderr, "ready> ");
1055   getNextToken();
1056
1057   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1058
1059   InitializeModuleAndPassManager();
1060
1061   // Run the main "interpreter loop" now.
1062   MainLoop();
1063
1064   return 0;
1065 }