789caabce0ff757afa98db6b15b2ed31ec861e04
[oota-llvm.git] / examples / Kaleidoscope / Chapter2 / toy.cpp
1 #include "llvm/ADT/STLExtras.h"
2 #include <cctype>
3 #include <cstdio>
4 #include <map>
5 #include <string>
6 #include <vector>
7
8 //===----------------------------------------------------------------------===//
9 // Lexer
10 //===----------------------------------------------------------------------===//
11
12 // The lexer returns tokens [0-255] if it is an unknown character, otherwise one
13 // of these for known things.
14 enum Token {
15   tok_eof = -1,
16
17   // commands
18   tok_def = -2, tok_extern = -3,
19
20   // primary
21   tok_identifier = -4, tok_number = -5
22 };
23
24 static std::string IdentifierStr;  // Filled in if tok_identifier
25 static double NumVal;              // Filled in if tok_number
26
27 /// gettok - Return the next token from standard input.
28 static int gettok() {
29   static int LastChar = ' ';
30
31   // Skip any whitespace.
32   while (isspace(LastChar))
33     LastChar = getchar();
34
35   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
36     IdentifierStr = LastChar;
37     while (isalnum((LastChar = getchar())))
38       IdentifierStr += LastChar;
39
40     if (IdentifierStr == "def") return tok_def;
41     if (IdentifierStr == "extern") return tok_extern;
42     return tok_identifier;
43   }
44
45   if (isdigit(LastChar) || LastChar == '.') {   // Number: [0-9.]+
46     std::string NumStr;
47     do {
48       NumStr += LastChar;
49       LastChar = getchar();
50     } while (isdigit(LastChar) || LastChar == '.');
51
52     NumVal = strtod(NumStr.c_str(), 0);
53     return tok_number;
54   }
55
56   if (LastChar == '#') {
57     // Comment until end of line.
58     do LastChar = getchar();
59     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
60     
61     if (LastChar != EOF)
62       return gettok();
63   }
64   
65   // Check for end of file.  Don't eat the EOF.
66   if (LastChar == EOF)
67     return tok_eof;
68
69   // Otherwise, just return the character as its ascii value.
70   int ThisChar = LastChar;
71   LastChar = getchar();
72   return ThisChar;
73 }
74
75 //===----------------------------------------------------------------------===//
76 // Abstract Syntax Tree (aka Parse Tree)
77 //===----------------------------------------------------------------------===//
78 namespace {
79 /// ExprAST - Base class for all expression nodes.
80 class ExprAST {
81 public:
82   virtual ~ExprAST() {}
83 };
84
85 /// NumberExprAST - Expression class for numeric literals like "1.0".
86 class NumberExprAST : public ExprAST {
87 public:
88   NumberExprAST(double Val) {}
89 };
90
91 /// VariableExprAST - Expression class for referencing a variable, like "a".
92 class VariableExprAST : public ExprAST {
93   std::string Name;
94 public:
95   VariableExprAST(const std::string &Name) : Name(Name) {}
96 };
97
98 /// BinaryExprAST - Expression class for a binary operator.
99 class BinaryExprAST : public ExprAST {
100 public:
101   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
102                 std::unique_ptr<ExprAST> RHS) {}
103 };
104
105 /// CallExprAST - Expression class for function calls.
106 class CallExprAST : public ExprAST {
107   std::string Callee;
108   std::vector<std::unique_ptr<ExprAST>> Args;
109 public:
110   CallExprAST(const std::string &Callee,
111               std::vector<std::unique_ptr<ExprAST>> Args)
112     : Callee(Callee), Args(std::move(Args)) {}
113 };
114
115 /// PrototypeAST - This class represents the "prototype" for a function,
116 /// which captures its name, and its argument names (thus implicitly the number
117 /// of arguments the function takes).
118 class PrototypeAST {
119   std::string Name;
120   std::vector<std::string> Args;
121 public:
122   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
123     : Name(Name), Args(std::move(Args)) {}
124   
125 };
126
127 /// FunctionAST - This class represents a function definition itself.
128 class FunctionAST {
129 public:
130   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
131               std::unique_ptr<ExprAST> Body) {}
132 };
133 } // end anonymous namespace
134
135 //===----------------------------------------------------------------------===//
136 // Parser
137 //===----------------------------------------------------------------------===//
138
139 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
140 /// token the parser is looking at.  getNextToken reads another token from the
141 /// lexer and updates CurTok with its results.
142 static int CurTok;
143 static int getNextToken() {
144   return CurTok = gettok();
145 }
146
147 /// BinopPrecedence - This holds the precedence for each binary operator that is
148 /// defined.
149 static std::map<char, int> BinopPrecedence;
150
151 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
152 static int GetTokPrecedence() {
153   if (!isascii(CurTok))
154     return -1;
155   
156   // Make sure it's a declared binop.
157   int TokPrec = BinopPrecedence[CurTok];
158   if (TokPrec <= 0) return -1;
159   return TokPrec;
160 }
161
162 /// Error* - These are little helper functions for error handling.
163 std::unique_ptr<ExprAST> Error(const char *Str) {
164   fprintf(stderr, "Error: %s\n", Str);
165   return nullptr;
166 }
167 std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
168   Error(Str);
169   return nullptr;
170 }
171
172 static std::unique_ptr<ExprAST> ParseExpression();
173
174 /// identifierexpr
175 ///   ::= identifier
176 ///   ::= identifier '(' expression* ')'
177 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
178   std::string IdName = IdentifierStr;
179   
180   getNextToken();  // eat identifier.
181   
182   if (CurTok != '(') // Simple variable ref.
183     return llvm::make_unique<VariableExprAST>(IdName);
184   
185   // Call.
186   getNextToken();  // eat (
187   std::vector<std::unique_ptr<ExprAST>> Args;
188   if (CurTok != ')') {
189     while (1) {
190       if (auto Arg = ParseExpression())
191         Args.push_back(std::move(Arg));
192       else
193         return nullptr;
194
195       if (CurTok == ')') break;
196
197       if (CurTok != ',')
198         return Error("Expected ')' or ',' in argument list");
199       getNextToken();
200     }
201   }
202
203   // Eat the ')'.
204   getNextToken();
205   
206   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
207 }
208
209 /// numberexpr ::= number
210 static std::unique_ptr<ExprAST> ParseNumberExpr() {
211   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
212   getNextToken(); // consume the number
213   return std::move(Result);
214 }
215
216 /// parenexpr ::= '(' expression ')'
217 static std::unique_ptr<ExprAST> ParseParenExpr() {
218   getNextToken();  // eat (.
219   auto V = ParseExpression();
220   if (!V)
221     return nullptr;
222   
223   if (CurTok != ')')
224     return Error("expected ')'");
225   getNextToken();  // eat ).
226   return V;
227 }
228
229 /// primary
230 ///   ::= identifierexpr
231 ///   ::= numberexpr
232 ///   ::= parenexpr
233 static std::unique_ptr<ExprAST> ParsePrimary() {
234   switch (CurTok) {
235   default: return Error("unknown token when expecting an expression");
236   case tok_identifier: return ParseIdentifierExpr();
237   case tok_number:     return ParseNumberExpr();
238   case '(':            return ParseParenExpr();
239   }
240 }
241
242 /// binoprhs
243 ///   ::= ('+' primary)*
244 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
245                                               std::unique_ptr<ExprAST> LHS) {
246   // If this is a binop, find its precedence.
247   while (1) {
248     int TokPrec = GetTokPrecedence();
249     
250     // If this is a binop that binds at least as tightly as the current binop,
251     // consume it, otherwise we are done.
252     if (TokPrec < ExprPrec)
253       return LHS;
254     
255     // Okay, we know this is a binop.
256     int BinOp = CurTok;
257     getNextToken();  // eat binop
258     
259     // Parse the primary expression after the binary operator.
260     auto RHS = ParsePrimary();
261     if (!RHS) return nullptr;
262     
263     // If BinOp binds less tightly with RHS than the operator after RHS, let
264     // the pending operator take RHS as its LHS.
265     int NextPrec = GetTokPrecedence();
266     if (TokPrec < NextPrec) {
267       RHS = ParseBinOpRHS(TokPrec+1, std::move(RHS));
268       if (!RHS) return nullptr;
269     }
270     
271     // Merge LHS/RHS.
272     LHS = llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS),
273                                            std::move(RHS));
274   }
275 }
276
277 /// expression
278 ///   ::= primary binoprhs
279 ///
280 static std::unique_ptr<ExprAST> ParseExpression() {
281   auto LHS = ParsePrimary();
282   if (!LHS) return nullptr;
283   
284   return ParseBinOpRHS(0, std::move(LHS));
285 }
286
287 /// prototype
288 ///   ::= id '(' id* ')'
289 static std::unique_ptr<PrototypeAST> ParsePrototype() {
290   if (CurTok != tok_identifier)
291     return ErrorP("Expected function name in prototype");
292
293   std::string FnName = IdentifierStr;
294   getNextToken();
295   
296   if (CurTok != '(')
297     return ErrorP("Expected '(' in prototype");
298   
299   std::vector<std::string> ArgNames;
300   while (getNextToken() == tok_identifier)
301     ArgNames.push_back(IdentifierStr);
302   if (CurTok != ')')
303     return ErrorP("Expected ')' in prototype");
304   
305   // success.
306   getNextToken();  // eat ')'.
307   
308   return llvm::make_unique<PrototypeAST>(std::move(FnName),
309                                          std::move(ArgNames));
310 }
311
312 /// definition ::= 'def' prototype expression
313 static std::unique_ptr<FunctionAST> ParseDefinition() {
314   getNextToken();  // eat def.
315   auto Proto = ParsePrototype();
316   if (!Proto) return nullptr;
317
318   if (auto E = ParseExpression())
319     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
320   return nullptr;
321 }
322
323 /// toplevelexpr ::= expression
324 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
325   if (auto E = ParseExpression()) {
326     // Make an anonymous proto.
327     auto Proto = llvm::make_unique<PrototypeAST>("",
328                                                  std::vector<std::string>());
329     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
330   }
331   return nullptr;
332 }
333
334 /// external ::= 'extern' prototype
335 static std::unique_ptr<PrototypeAST> ParseExtern() {
336   getNextToken();  // eat extern.
337   return ParsePrototype();
338 }
339
340 //===----------------------------------------------------------------------===//
341 // Top-Level parsing
342 //===----------------------------------------------------------------------===//
343
344 static void HandleDefinition() {
345   if (ParseDefinition()) {
346     fprintf(stderr, "Parsed a function definition.\n");
347   } else {
348     // Skip token for error recovery.
349     getNextToken();
350   }
351 }
352
353 static void HandleExtern() {
354   if (ParseExtern()) {
355     fprintf(stderr, "Parsed an extern\n");
356   } else {
357     // Skip token for error recovery.
358     getNextToken();
359   }
360 }
361
362 static void HandleTopLevelExpression() {
363   // Evaluate a top-level expression into an anonymous function.
364   if (ParseTopLevelExpr()) {
365     fprintf(stderr, "Parsed a top-level expr\n");
366   } else {
367     // Skip token for error recovery.
368     getNextToken();
369   }
370 }
371
372 /// top ::= definition | external | expression | ';'
373 static void MainLoop() {
374   while (1) {
375     fprintf(stderr, "ready> ");
376     switch (CurTok) {
377     case tok_eof:    return;
378     case ';':        getNextToken(); break;  // ignore top-level semicolons.
379     case tok_def:    HandleDefinition(); break;
380     case tok_extern: HandleExtern(); break;
381     default:         HandleTopLevelExpression(); break;
382     }
383   }
384 }
385
386 //===----------------------------------------------------------------------===//
387 // Main driver code.
388 //===----------------------------------------------------------------------===//
389
390 int main() {
391   // Install standard binary operators.
392   // 1 is lowest precedence.
393   BinopPrecedence['<'] = 10;
394   BinopPrecedence['+'] = 20;
395   BinopPrecedence['-'] = 20;
396   BinopPrecedence['*'] = 40;  // highest.
397
398   // Prime the first token.
399   fprintf(stderr, "ready> ");
400   getNextToken();
401
402   // Run the main "interpreter loop" now.
403   MainLoop();
404
405   return 0;
406 }