Run DCE after a LoopFlatten test to reduce spurious output [nfc]
[llvm-project.git] / llvm / examples / Kaleidoscope / Chapter3 / toy.cpp
blob03563006685addb2db7a31704075b72334036645
1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/Module.h"
10 #include "llvm/IR/Type.h"
11 #include "llvm/IR/Verifier.h"
12 #include <algorithm>
13 #include <cctype>
14 #include <cstdio>
15 #include <cstdlib>
16 #include <map>
17 #include <memory>
18 #include <string>
19 #include <vector>
21 using namespace llvm;
23 //===----------------------------------------------------------------------===//
24 // Lexer
25 //===----------------------------------------------------------------------===//
// Tokens produced by gettok(). A character the lexer does not recognise is
// returned as its own value in [0-255]; the recognised token kinds below are
// all negative, so they can never be confused with a character.
enum Token {
  tok_eof = -1, // end of input

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5,
};
// Token payloads that gettok() writes alongside the token it returns.
static std::string IdentifierStr; // spelling of the last word lexed (identifiers and the keywords def/extern)
static double NumVal = 0.0;       // value of the last tok_number
44 /// gettok - Return the next token from standard input.
45 static int gettok() {
46 static int LastChar = ' ';
48 // Skip any whitespace.
49 while (isspace(LastChar))
50 LastChar = getchar();
52 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
53 IdentifierStr = LastChar;
54 while (isalnum((LastChar = getchar())))
55 IdentifierStr += LastChar;
57 if (IdentifierStr == "def")
58 return tok_def;
59 if (IdentifierStr == "extern")
60 return tok_extern;
61 return tok_identifier;
64 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
65 std::string NumStr;
66 do {
67 NumStr += LastChar;
68 LastChar = getchar();
69 } while (isdigit(LastChar) || LastChar == '.');
71 NumVal = strtod(NumStr.c_str(), nullptr);
72 return tok_number;
75 if (LastChar == '#') {
76 // Comment until end of line.
78 LastChar = getchar();
79 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
81 if (LastChar != EOF)
82 return gettok();
85 // Check for end of file. Don't eat the EOF.
86 if (LastChar == EOF)
87 return tok_eof;
89 // Otherwise, just return the character as its ascii value.
90 int ThisChar = LastChar;
91 LastChar = getchar();
92 return ThisChar;
95 //===----------------------------------------------------------------------===//
96 // Abstract Syntax Tree (aka Parse Tree)
97 //===----------------------------------------------------------------------===//
99 namespace {
101 /// ExprAST - Base class for all expression nodes.
102 class ExprAST {
103 public:
104 virtual ~ExprAST() = default;
106 virtual Value *codegen() = 0;
109 /// NumberExprAST - Expression class for numeric literals like "1.0".
110 class NumberExprAST : public ExprAST {
111 double Val;
113 public:
114 NumberExprAST(double Val) : Val(Val) {}
116 Value *codegen() override;
119 /// VariableExprAST - Expression class for referencing a variable, like "a".
120 class VariableExprAST : public ExprAST {
121 std::string Name;
123 public:
124 VariableExprAST(const std::string &Name) : Name(Name) {}
126 Value *codegen() override;
129 /// BinaryExprAST - Expression class for a binary operator.
130 class BinaryExprAST : public ExprAST {
131 char Op;
132 std::unique_ptr<ExprAST> LHS, RHS;
134 public:
135 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
136 std::unique_ptr<ExprAST> RHS)
137 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
139 Value *codegen() override;
142 /// CallExprAST - Expression class for function calls.
143 class CallExprAST : public ExprAST {
144 std::string Callee;
145 std::vector<std::unique_ptr<ExprAST>> Args;
147 public:
148 CallExprAST(const std::string &Callee,
149 std::vector<std::unique_ptr<ExprAST>> Args)
150 : Callee(Callee), Args(std::move(Args)) {}
152 Value *codegen() override;
155 /// PrototypeAST - This class represents the "prototype" for a function,
156 /// which captures its name, and its argument names (thus implicitly the number
157 /// of arguments the function takes).
158 class PrototypeAST {
159 std::string Name;
160 std::vector<std::string> Args;
162 public:
163 PrototypeAST(const std::string &Name, std::vector<std::string> Args)
164 : Name(Name), Args(std::move(Args)) {}
166 Function *codegen();
167 const std::string &getName() const { return Name; }
170 /// FunctionAST - This class represents a function definition itself.
171 class FunctionAST {
172 std::unique_ptr<PrototypeAST> Proto;
173 std::unique_ptr<ExprAST> Body;
175 public:
176 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
177 std::unique_ptr<ExprAST> Body)
178 : Proto(std::move(Proto)), Body(std::move(Body)) {}
180 Function *codegen();
183 } // end anonymous namespace
185 //===----------------------------------------------------------------------===//
186 // Parser
187 //===----------------------------------------------------------------------===//
189 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
190 /// token the parser is looking at. getNextToken reads another token from the
191 /// lexer and updates CurTok with its results.
192 static int CurTok;
193 static int getNextToken() { return CurTok = gettok(); }
195 /// BinopPrecedence - This holds the precedence for each binary operator that is
196 /// defined.
197 static std::map<char, int> BinopPrecedence;
199 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
200 static int GetTokPrecedence() {
201 if (!isascii(CurTok))
202 return -1;
204 // Make sure it's a declared binop.
205 int TokPrec = BinopPrecedence[CurTok];
206 if (TokPrec <= 0)
207 return -1;
208 return TokPrec;
211 /// LogError* - These are little helper functions for error handling.
212 std::unique_ptr<ExprAST> LogError(const char *Str) {
213 fprintf(stderr, "Error: %s\n", Str);
214 return nullptr;
217 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
218 LogError(Str);
219 return nullptr;
222 static std::unique_ptr<ExprAST> ParseExpression();
224 /// numberexpr ::= number
225 static std::unique_ptr<ExprAST> ParseNumberExpr() {
226 auto Result = std::make_unique<NumberExprAST>(NumVal);
227 getNextToken(); // consume the number
228 return std::move(Result);
231 /// parenexpr ::= '(' expression ')'
232 static std::unique_ptr<ExprAST> ParseParenExpr() {
233 getNextToken(); // eat (.
234 auto V = ParseExpression();
235 if (!V)
236 return nullptr;
238 if (CurTok != ')')
239 return LogError("expected ')'");
240 getNextToken(); // eat ).
241 return V;
244 /// identifierexpr
245 /// ::= identifier
246 /// ::= identifier '(' expression* ')'
247 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
248 std::string IdName = IdentifierStr;
250 getNextToken(); // eat identifier.
252 if (CurTok != '(') // Simple variable ref.
253 return std::make_unique<VariableExprAST>(IdName);
255 // Call.
256 getNextToken(); // eat (
257 std::vector<std::unique_ptr<ExprAST>> Args;
258 if (CurTok != ')') {
259 while (true) {
260 if (auto Arg = ParseExpression())
261 Args.push_back(std::move(Arg));
262 else
263 return nullptr;
265 if (CurTok == ')')
266 break;
268 if (CurTok != ',')
269 return LogError("Expected ')' or ',' in argument list");
270 getNextToken();
274 // Eat the ')'.
275 getNextToken();
277 return std::make_unique<CallExprAST>(IdName, std::move(Args));
280 /// primary
281 /// ::= identifierexpr
282 /// ::= numberexpr
283 /// ::= parenexpr
284 static std::unique_ptr<ExprAST> ParsePrimary() {
285 switch (CurTok) {
286 default:
287 return LogError("unknown token when expecting an expression");
288 case tok_identifier:
289 return ParseIdentifierExpr();
290 case tok_number:
291 return ParseNumberExpr();
292 case '(':
293 return ParseParenExpr();
297 /// binoprhs
298 /// ::= ('+' primary)*
299 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
300 std::unique_ptr<ExprAST> LHS) {
301 // If this is a binop, find its precedence.
302 while (true) {
303 int TokPrec = GetTokPrecedence();
305 // If this is a binop that binds at least as tightly as the current binop,
306 // consume it, otherwise we are done.
307 if (TokPrec < ExprPrec)
308 return LHS;
310 // Okay, we know this is a binop.
311 int BinOp = CurTok;
312 getNextToken(); // eat binop
314 // Parse the primary expression after the binary operator.
315 auto RHS = ParsePrimary();
316 if (!RHS)
317 return nullptr;
319 // If BinOp binds less tightly with RHS than the operator after RHS, let
320 // the pending operator take RHS as its LHS.
321 int NextPrec = GetTokPrecedence();
322 if (TokPrec < NextPrec) {
323 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
324 if (!RHS)
325 return nullptr;
328 // Merge LHS/RHS.
329 LHS =
330 std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
334 /// expression
335 /// ::= primary binoprhs
337 static std::unique_ptr<ExprAST> ParseExpression() {
338 auto LHS = ParsePrimary();
339 if (!LHS)
340 return nullptr;
342 return ParseBinOpRHS(0, std::move(LHS));
345 /// prototype
346 /// ::= id '(' id* ')'
347 static std::unique_ptr<PrototypeAST> ParsePrototype() {
348 if (CurTok != tok_identifier)
349 return LogErrorP("Expected function name in prototype");
351 std::string FnName = IdentifierStr;
352 getNextToken();
354 if (CurTok != '(')
355 return LogErrorP("Expected '(' in prototype");
357 std::vector<std::string> ArgNames;
358 while (getNextToken() == tok_identifier)
359 ArgNames.push_back(IdentifierStr);
360 if (CurTok != ')')
361 return LogErrorP("Expected ')' in prototype");
363 // success.
364 getNextToken(); // eat ')'.
366 return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
369 /// definition ::= 'def' prototype expression
370 static std::unique_ptr<FunctionAST> ParseDefinition() {
371 getNextToken(); // eat def.
372 auto Proto = ParsePrototype();
373 if (!Proto)
374 return nullptr;
376 if (auto E = ParseExpression())
377 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
378 return nullptr;
381 /// toplevelexpr ::= expression
382 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
383 if (auto E = ParseExpression()) {
384 // Make an anonymous proto.
385 auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
386 std::vector<std::string>());
387 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
389 return nullptr;
392 /// external ::= 'extern' prototype
393 static std::unique_ptr<PrototypeAST> ParseExtern() {
394 getNextToken(); // eat extern.
395 return ParsePrototype();
398 //===----------------------------------------------------------------------===//
399 // Code Generation
400 //===----------------------------------------------------------------------===//
402 static std::unique_ptr<LLVMContext> TheContext;
403 static std::unique_ptr<Module> TheModule;
404 static std::unique_ptr<IRBuilder<>> Builder;
405 static std::map<std::string, Value *> NamedValues;
407 Value *LogErrorV(const char *Str) {
408 LogError(Str);
409 return nullptr;
412 Value *NumberExprAST::codegen() {
413 return ConstantFP::get(*TheContext, APFloat(Val));
416 Value *VariableExprAST::codegen() {
417 // Look this variable up in the function.
418 Value *V = NamedValues[Name];
419 if (!V)
420 return LogErrorV("Unknown variable name");
421 return V;
424 Value *BinaryExprAST::codegen() {
425 Value *L = LHS->codegen();
426 Value *R = RHS->codegen();
427 if (!L || !R)
428 return nullptr;
430 switch (Op) {
431 case '+':
432 return Builder->CreateFAdd(L, R, "addtmp");
433 case '-':
434 return Builder->CreateFSub(L, R, "subtmp");
435 case '*':
436 return Builder->CreateFMul(L, R, "multmp");
437 case '<':
438 L = Builder->CreateFCmpULT(L, R, "cmptmp");
439 // Convert bool 0/1 to double 0.0 or 1.0
440 return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
441 default:
442 return LogErrorV("invalid binary operator");
446 Value *CallExprAST::codegen() {
447 // Look up the name in the global module table.
448 Function *CalleeF = TheModule->getFunction(Callee);
449 if (!CalleeF)
450 return LogErrorV("Unknown function referenced");
452 // If argument mismatch error.
453 if (CalleeF->arg_size() != Args.size())
454 return LogErrorV("Incorrect # arguments passed");
456 std::vector<Value *> ArgsV;
457 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
458 ArgsV.push_back(Args[i]->codegen());
459 if (!ArgsV.back())
460 return nullptr;
463 return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
466 Function *PrototypeAST::codegen() {
467 // Make the function type: double(double,double) etc.
468 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
469 FunctionType *FT =
470 FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
472 Function *F =
473 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
475 // Set names for all arguments.
476 unsigned Idx = 0;
477 for (auto &Arg : F->args())
478 Arg.setName(Args[Idx++]);
480 return F;
483 Function *FunctionAST::codegen() {
484 // First, check for an existing function from a previous 'extern' declaration.
485 Function *TheFunction = TheModule->getFunction(Proto->getName());
487 if (!TheFunction)
488 TheFunction = Proto->codegen();
490 if (!TheFunction)
491 return nullptr;
493 // Create a new basic block to start insertion into.
494 BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
495 Builder->SetInsertPoint(BB);
497 // Record the function arguments in the NamedValues map.
498 NamedValues.clear();
499 for (auto &Arg : TheFunction->args())
500 NamedValues[std::string(Arg.getName())] = &Arg;
502 if (Value *RetVal = Body->codegen()) {
503 // Finish off the function.
504 Builder->CreateRet(RetVal);
506 // Validate the generated code, checking for consistency.
507 verifyFunction(*TheFunction);
509 return TheFunction;
512 // Error reading body, remove function.
513 TheFunction->eraseFromParent();
514 return nullptr;
517 //===----------------------------------------------------------------------===//
518 // Top-Level parsing and JIT Driver
519 //===----------------------------------------------------------------------===//
521 static void InitializeModule() {
522 // Open a new context and module.
523 TheContext = std::make_unique<LLVMContext>();
524 TheModule = std::make_unique<Module>("my cool jit", *TheContext);
526 // Create a new builder for the module.
527 Builder = std::make_unique<IRBuilder<>>(*TheContext);
530 static void HandleDefinition() {
531 if (auto FnAST = ParseDefinition()) {
532 if (auto *FnIR = FnAST->codegen()) {
533 fprintf(stderr, "Read function definition:");
534 FnIR->print(errs());
535 fprintf(stderr, "\n");
537 } else {
538 // Skip token for error recovery.
539 getNextToken();
543 static void HandleExtern() {
544 if (auto ProtoAST = ParseExtern()) {
545 if (auto *FnIR = ProtoAST->codegen()) {
546 fprintf(stderr, "Read extern: ");
547 FnIR->print(errs());
548 fprintf(stderr, "\n");
550 } else {
551 // Skip token for error recovery.
552 getNextToken();
556 static void HandleTopLevelExpression() {
557 // Evaluate a top-level expression into an anonymous function.
558 if (auto FnAST = ParseTopLevelExpr()) {
559 if (auto *FnIR = FnAST->codegen()) {
560 fprintf(stderr, "Read top-level expression:");
561 FnIR->print(errs());
562 fprintf(stderr, "\n");
564 // Remove the anonymous expression.
565 FnIR->eraseFromParent();
567 } else {
568 // Skip token for error recovery.
569 getNextToken();
573 /// top ::= definition | external | expression | ';'
574 static void MainLoop() {
575 while (true) {
576 fprintf(stderr, "ready> ");
577 switch (CurTok) {
578 case tok_eof:
579 return;
580 case ';': // ignore top-level semicolons.
581 getNextToken();
582 break;
583 case tok_def:
584 HandleDefinition();
585 break;
586 case tok_extern:
587 HandleExtern();
588 break;
589 default:
590 HandleTopLevelExpression();
591 break;
596 //===----------------------------------------------------------------------===//
597 // Main driver code.
598 //===----------------------------------------------------------------------===//
600 int main() {
601 // Install standard binary operators.
602 // 1 is lowest precedence.
603 BinopPrecedence['<'] = 10;
604 BinopPrecedence['+'] = 20;
605 BinopPrecedence['-'] = 20;
606 BinopPrecedence['*'] = 40; // highest.
608 // Prime the first token.
609 fprintf(stderr, "ready> ");
610 getNextToken();
612 // Make the module, which holds all the code.
613 InitializeModule();
615 // Run the main "interpreter loop" now.
616 MainLoop();
618 // Print out all of the generated code.
619 TheModule->print(errs(), nullptr);
621 return 0;