Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / examples / Kaleidoscope / Chapter4 / toy.cpp
blob1bbc294bf35263774ba1c25d5f2202c9f276b3f7
1 #include "../include/KaleidoscopeJIT.h"
2 #include "llvm/ADT/APFloat.h"
3 #include "llvm/ADT/STLExtras.h"
4 #include "llvm/IR/BasicBlock.h"
5 #include "llvm/IR/Constants.h"
6 #include "llvm/IR/DerivedTypes.h"
7 #include "llvm/IR/Function.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/LLVMContext.h"
10 #include "llvm/IR/Module.h"
11 #include "llvm/IR/PassManager.h"
12 #include "llvm/IR/Type.h"
13 #include "llvm/IR/Verifier.h"
14 #include "llvm/Passes/PassBuilder.h"
15 #include "llvm/Passes/StandardInstrumentations.h"
16 #include "llvm/Support/TargetSelect.h"
17 #include "llvm/Target/TargetMachine.h"
18 #include "llvm/Transforms/InstCombine/InstCombine.h"
19 #include "llvm/Transforms/Scalar.h"
20 #include "llvm/Transforms/Scalar/GVN.h"
21 #include "llvm/Transforms/Scalar/Reassociate.h"
22 #include "llvm/Transforms/Scalar/SimplifyCFG.h"
23 #include <algorithm>
24 #include <cassert>
25 #include <cctype>
26 #include <cstdint>
27 #include <cstdio>
28 #include <cstdlib>
29 #include <map>
30 #include <memory>
31 #include <string>
32 #include <vector>
34 using namespace llvm;
35 using namespace llvm::orc;
37 //===----------------------------------------------------------------------===//
38 // Lexer
39 //===----------------------------------------------------------------------===//
/// Token - Codes the lexer hands back for the things it recognizes. Every
/// known token is negative; a character the lexer does not recognize is
/// returned as its own value in the range [0-255] instead.
enum Token {
  // End of input.
  tok_eof = -1,

  // Keywords that start a top-level command.
  tok_def = -2,
  tok_extern = -3,

  // Primary expression tokens.
  tok_identifier = -4,
  tok_number = -5
};
/// Text of the most recent identifier; set by the lexer whenever it returns
/// tok_identifier.
static std::string IdentifierStr;

/// Value of the most recent numeric literal; set by the lexer whenever it
/// returns tok_number.
static double NumVal;
58 /// gettok - Return the next token from standard input.
59 static int gettok() {
60 static int LastChar = ' ';
62 // Skip any whitespace.
63 while (isspace(LastChar))
64 LastChar = getchar();
66 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
67 IdentifierStr = LastChar;
68 while (isalnum((LastChar = getchar())))
69 IdentifierStr += LastChar;
71 if (IdentifierStr == "def")
72 return tok_def;
73 if (IdentifierStr == "extern")
74 return tok_extern;
75 return tok_identifier;
78 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
79 std::string NumStr;
80 do {
81 NumStr += LastChar;
82 LastChar = getchar();
83 } while (isdigit(LastChar) || LastChar == '.');
85 NumVal = strtod(NumStr.c_str(), nullptr);
86 return tok_number;
89 if (LastChar == '#') {
90 // Comment until end of line.
92 LastChar = getchar();
93 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
95 if (LastChar != EOF)
96 return gettok();
99 // Check for end of file. Don't eat the EOF.
100 if (LastChar == EOF)
101 return tok_eof;
103 // Otherwise, just return the character as its ascii value.
104 int ThisChar = LastChar;
105 LastChar = getchar();
106 return ThisChar;
109 //===----------------------------------------------------------------------===//
110 // Abstract Syntax Tree (aka Parse Tree)
111 //===----------------------------------------------------------------------===//
113 namespace {
115 /// ExprAST - Base class for all expression nodes.
116 class ExprAST {
117 public:
118 virtual ~ExprAST() = default;
120 virtual Value *codegen() = 0;
123 /// NumberExprAST - Expression class for numeric literals like "1.0".
124 class NumberExprAST : public ExprAST {
125 double Val;
127 public:
128 NumberExprAST(double Val) : Val(Val) {}
130 Value *codegen() override;
133 /// VariableExprAST - Expression class for referencing a variable, like "a".
134 class VariableExprAST : public ExprAST {
135 std::string Name;
137 public:
138 VariableExprAST(const std::string &Name) : Name(Name) {}
140 Value *codegen() override;
143 /// BinaryExprAST - Expression class for a binary operator.
144 class BinaryExprAST : public ExprAST {
145 char Op;
146 std::unique_ptr<ExprAST> LHS, RHS;
148 public:
149 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
150 std::unique_ptr<ExprAST> RHS)
151 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
153 Value *codegen() override;
156 /// CallExprAST - Expression class for function calls.
157 class CallExprAST : public ExprAST {
158 std::string Callee;
159 std::vector<std::unique_ptr<ExprAST>> Args;
161 public:
162 CallExprAST(const std::string &Callee,
163 std::vector<std::unique_ptr<ExprAST>> Args)
164 : Callee(Callee), Args(std::move(Args)) {}
166 Value *codegen() override;
169 /// PrototypeAST - This class represents the "prototype" for a function,
170 /// which captures its name, and its argument names (thus implicitly the number
171 /// of arguments the function takes).
172 class PrototypeAST {
173 std::string Name;
174 std::vector<std::string> Args;
176 public:
177 PrototypeAST(const std::string &Name, std::vector<std::string> Args)
178 : Name(Name), Args(std::move(Args)) {}
180 Function *codegen();
181 const std::string &getName() const { return Name; }
184 /// FunctionAST - This class represents a function definition itself.
185 class FunctionAST {
186 std::unique_ptr<PrototypeAST> Proto;
187 std::unique_ptr<ExprAST> Body;
189 public:
190 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
191 std::unique_ptr<ExprAST> Body)
192 : Proto(std::move(Proto)), Body(std::move(Body)) {}
194 Function *codegen();
197 } // end anonymous namespace
199 //===----------------------------------------------------------------------===//
200 // Parser
201 //===----------------------------------------------------------------------===//
203 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
204 /// token the parser is looking at. getNextToken reads another token from the
205 /// lexer and updates CurTok with its results.
206 static int CurTok;
207 static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - Precedence table for the binary operators that have been
/// defined, keyed by the operator character.
static std::map<char, int> BinopPrecedence;
213 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
214 static int GetTokPrecedence() {
215 if (!isascii(CurTok))
216 return -1;
218 // Make sure it's a declared binop.
219 int TokPrec = BinopPrecedence[CurTok];
220 if (TokPrec <= 0)
221 return -1;
222 return TokPrec;
225 /// LogError* - These are little helper functions for error handling.
226 std::unique_ptr<ExprAST> LogError(const char *Str) {
227 fprintf(stderr, "Error: %s\n", Str);
228 return nullptr;
231 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
232 LogError(Str);
233 return nullptr;
236 static std::unique_ptr<ExprAST> ParseExpression();
238 /// numberexpr ::= number
239 static std::unique_ptr<ExprAST> ParseNumberExpr() {
240 auto Result = std::make_unique<NumberExprAST>(NumVal);
241 getNextToken(); // consume the number
242 return std::move(Result);
245 /// parenexpr ::= '(' expression ')'
246 static std::unique_ptr<ExprAST> ParseParenExpr() {
247 getNextToken(); // eat (.
248 auto V = ParseExpression();
249 if (!V)
250 return nullptr;
252 if (CurTok != ')')
253 return LogError("expected ')'");
254 getNextToken(); // eat ).
255 return V;
258 /// identifierexpr
259 /// ::= identifier
260 /// ::= identifier '(' expression* ')'
261 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
262 std::string IdName = IdentifierStr;
264 getNextToken(); // eat identifier.
266 if (CurTok != '(') // Simple variable ref.
267 return std::make_unique<VariableExprAST>(IdName);
269 // Call.
270 getNextToken(); // eat (
271 std::vector<std::unique_ptr<ExprAST>> Args;
272 if (CurTok != ')') {
273 while (true) {
274 if (auto Arg = ParseExpression())
275 Args.push_back(std::move(Arg));
276 else
277 return nullptr;
279 if (CurTok == ')')
280 break;
282 if (CurTok != ',')
283 return LogError("Expected ')' or ',' in argument list");
284 getNextToken();
288 // Eat the ')'.
289 getNextToken();
291 return std::make_unique<CallExprAST>(IdName, std::move(Args));
294 /// primary
295 /// ::= identifierexpr
296 /// ::= numberexpr
297 /// ::= parenexpr
298 static std::unique_ptr<ExprAST> ParsePrimary() {
299 switch (CurTok) {
300 default:
301 return LogError("unknown token when expecting an expression");
302 case tok_identifier:
303 return ParseIdentifierExpr();
304 case tok_number:
305 return ParseNumberExpr();
306 case '(':
307 return ParseParenExpr();
311 /// binoprhs
312 /// ::= ('+' primary)*
313 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
314 std::unique_ptr<ExprAST> LHS) {
315 // If this is a binop, find its precedence.
316 while (true) {
317 int TokPrec = GetTokPrecedence();
319 // If this is a binop that binds at least as tightly as the current binop,
320 // consume it, otherwise we are done.
321 if (TokPrec < ExprPrec)
322 return LHS;
324 // Okay, we know this is a binop.
325 int BinOp = CurTok;
326 getNextToken(); // eat binop
328 // Parse the primary expression after the binary operator.
329 auto RHS = ParsePrimary();
330 if (!RHS)
331 return nullptr;
333 // If BinOp binds less tightly with RHS than the operator after RHS, let
334 // the pending operator take RHS as its LHS.
335 int NextPrec = GetTokPrecedence();
336 if (TokPrec < NextPrec) {
337 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
338 if (!RHS)
339 return nullptr;
342 // Merge LHS/RHS.
343 LHS =
344 std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
348 /// expression
349 /// ::= primary binoprhs
351 static std::unique_ptr<ExprAST> ParseExpression() {
352 auto LHS = ParsePrimary();
353 if (!LHS)
354 return nullptr;
356 return ParseBinOpRHS(0, std::move(LHS));
359 /// prototype
360 /// ::= id '(' id* ')'
361 static std::unique_ptr<PrototypeAST> ParsePrototype() {
362 if (CurTok != tok_identifier)
363 return LogErrorP("Expected function name in prototype");
365 std::string FnName = IdentifierStr;
366 getNextToken();
368 if (CurTok != '(')
369 return LogErrorP("Expected '(' in prototype");
371 std::vector<std::string> ArgNames;
372 while (getNextToken() == tok_identifier)
373 ArgNames.push_back(IdentifierStr);
374 if (CurTok != ')')
375 return LogErrorP("Expected ')' in prototype");
377 // success.
378 getNextToken(); // eat ')'.
380 return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
383 /// definition ::= 'def' prototype expression
384 static std::unique_ptr<FunctionAST> ParseDefinition() {
385 getNextToken(); // eat def.
386 auto Proto = ParsePrototype();
387 if (!Proto)
388 return nullptr;
390 if (auto E = ParseExpression())
391 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
392 return nullptr;
395 /// toplevelexpr ::= expression
396 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
397 if (auto E = ParseExpression()) {
398 // Make an anonymous proto.
399 auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
400 std::vector<std::string>());
401 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
403 return nullptr;
406 /// external ::= 'extern' prototype
407 static std::unique_ptr<PrototypeAST> ParseExtern() {
408 getNextToken(); // eat extern.
409 return ParsePrototype();
412 //===----------------------------------------------------------------------===//
413 // Code Generation
414 //===----------------------------------------------------------------------===//
416 static std::unique_ptr<LLVMContext> TheContext;
417 static std::unique_ptr<Module> TheModule;
418 static std::unique_ptr<IRBuilder<>> Builder;
419 static std::map<std::string, Value *> NamedValues;
420 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
421 static std::unique_ptr<FunctionPassManager> TheFPM;
422 static std::unique_ptr<LoopAnalysisManager> TheLAM;
423 static std::unique_ptr<FunctionAnalysisManager> TheFAM;
424 static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
425 static std::unique_ptr<ModuleAnalysisManager> TheMAM;
426 static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
427 static std::unique_ptr<StandardInstrumentations> TheSI;
428 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
429 static ExitOnError ExitOnErr;
431 Value *LogErrorV(const char *Str) {
432 LogError(Str);
433 return nullptr;
436 Function *getFunction(std::string Name) {
437 // First, see if the function has already been added to the current module.
438 if (auto *F = TheModule->getFunction(Name))
439 return F;
441 // If not, check whether we can codegen the declaration from some existing
442 // prototype.
443 auto FI = FunctionProtos.find(Name);
444 if (FI != FunctionProtos.end())
445 return FI->second->codegen();
447 // If no existing prototype exists, return null.
448 return nullptr;
451 Value *NumberExprAST::codegen() {
452 return ConstantFP::get(*TheContext, APFloat(Val));
455 Value *VariableExprAST::codegen() {
456 // Look this variable up in the function.
457 Value *V = NamedValues[Name];
458 if (!V)
459 return LogErrorV("Unknown variable name");
460 return V;
463 Value *BinaryExprAST::codegen() {
464 Value *L = LHS->codegen();
465 Value *R = RHS->codegen();
466 if (!L || !R)
467 return nullptr;
469 switch (Op) {
470 case '+':
471 return Builder->CreateFAdd(L, R, "addtmp");
472 case '-':
473 return Builder->CreateFSub(L, R, "subtmp");
474 case '*':
475 return Builder->CreateFMul(L, R, "multmp");
476 case '<':
477 L = Builder->CreateFCmpULT(L, R, "cmptmp");
478 // Convert bool 0/1 to double 0.0 or 1.0
479 return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
480 default:
481 return LogErrorV("invalid binary operator");
485 Value *CallExprAST::codegen() {
486 // Look up the name in the global module table.
487 Function *CalleeF = getFunction(Callee);
488 if (!CalleeF)
489 return LogErrorV("Unknown function referenced");
491 // If argument mismatch error.
492 if (CalleeF->arg_size() != Args.size())
493 return LogErrorV("Incorrect # arguments passed");
495 std::vector<Value *> ArgsV;
496 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
497 ArgsV.push_back(Args[i]->codegen());
498 if (!ArgsV.back())
499 return nullptr;
502 return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
505 Function *PrototypeAST::codegen() {
506 // Make the function type: double(double,double) etc.
507 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
508 FunctionType *FT =
509 FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
511 Function *F =
512 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
514 // Set names for all arguments.
515 unsigned Idx = 0;
516 for (auto &Arg : F->args())
517 Arg.setName(Args[Idx++]);
519 return F;
522 Function *FunctionAST::codegen() {
523 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
524 // reference to it for use below.
525 auto &P = *Proto;
526 FunctionProtos[Proto->getName()] = std::move(Proto);
527 Function *TheFunction = getFunction(P.getName());
528 if (!TheFunction)
529 return nullptr;
531 // Create a new basic block to start insertion into.
532 BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
533 Builder->SetInsertPoint(BB);
535 // Record the function arguments in the NamedValues map.
536 NamedValues.clear();
537 for (auto &Arg : TheFunction->args())
538 NamedValues[std::string(Arg.getName())] = &Arg;
540 if (Value *RetVal = Body->codegen()) {
541 // Finish off the function.
542 Builder->CreateRet(RetVal);
544 // Validate the generated code, checking for consistency.
545 verifyFunction(*TheFunction);
547 // Run the optimizer on the function.
548 TheFPM->run(*TheFunction, *TheFAM);
550 return TheFunction;
553 // Error reading body, remove function.
554 TheFunction->eraseFromParent();
555 return nullptr;
558 //===----------------------------------------------------------------------===//
559 // Top-Level parsing and JIT Driver
560 //===----------------------------------------------------------------------===//
562 static void InitializeModuleAndManagers() {
563 // Open a new context and module.
564 TheContext = std::make_unique<LLVMContext>();
565 TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
566 TheModule->setDataLayout(TheJIT->getDataLayout());
568 // Create a new builder for the module.
569 Builder = std::make_unique<IRBuilder<>>(*TheContext);
571 // Create new pass and analysis managers.
572 TheFPM = std::make_unique<FunctionPassManager>();
573 TheLAM = std::make_unique<LoopAnalysisManager>();
574 TheFAM = std::make_unique<FunctionAnalysisManager>();
575 TheCGAM = std::make_unique<CGSCCAnalysisManager>();
576 TheMAM = std::make_unique<ModuleAnalysisManager>();
577 ThePIC = std::make_unique<PassInstrumentationCallbacks>();
578 TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
579 /*DebugLogging*/ true);
580 TheSI->registerCallbacks(*ThePIC, TheMAM.get());
582 // Add transform passes.
583 // Do simple "peephole" optimizations and bit-twiddling optzns.
584 TheFPM->addPass(InstCombinePass());
585 // Reassociate expressions.
586 TheFPM->addPass(ReassociatePass());
587 // Eliminate Common SubExpressions.
588 TheFPM->addPass(GVNPass());
589 // Simplify the control flow graph (deleting unreachable blocks, etc).
590 TheFPM->addPass(SimplifyCFGPass());
592 // Register analysis passes used in these transform passes.
593 PassBuilder PB;
594 PB.registerModuleAnalyses(*TheMAM);
595 PB.registerFunctionAnalyses(*TheFAM);
596 PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
599 static void HandleDefinition() {
600 if (auto FnAST = ParseDefinition()) {
601 if (auto *FnIR = FnAST->codegen()) {
602 fprintf(stderr, "Read function definition:");
603 FnIR->print(errs());
604 fprintf(stderr, "\n");
605 ExitOnErr(TheJIT->addModule(
606 ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
607 InitializeModuleAndManagers();
609 } else {
610 // Skip token for error recovery.
611 getNextToken();
615 static void HandleExtern() {
616 if (auto ProtoAST = ParseExtern()) {
617 if (auto *FnIR = ProtoAST->codegen()) {
618 fprintf(stderr, "Read extern: ");
619 FnIR->print(errs());
620 fprintf(stderr, "\n");
621 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
623 } else {
624 // Skip token for error recovery.
625 getNextToken();
629 static void HandleTopLevelExpression() {
630 // Evaluate a top-level expression into an anonymous function.
631 if (auto FnAST = ParseTopLevelExpr()) {
632 if (FnAST->codegen()) {
633 // Create a ResourceTracker to track JIT'd memory allocated to our
634 // anonymous expression -- that way we can free it after executing.
635 auto RT = TheJIT->getMainJITDylib().createResourceTracker();
637 auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
638 ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
639 InitializeModuleAndManagers();
641 // Search the JIT for the __anon_expr symbol.
642 auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
644 // Get the symbol's address and cast it to the right type (takes no
645 // arguments, returns a double) so we can call it as a native function.
646 double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
647 fprintf(stderr, "Evaluated to %f\n", FP());
649 // Delete the anonymous expression module from the JIT.
650 ExitOnErr(RT->remove());
652 } else {
653 // Skip token for error recovery.
654 getNextToken();
658 /// top ::= definition | external | expression | ';'
659 static void MainLoop() {
660 while (true) {
661 fprintf(stderr, "ready> ");
662 switch (CurTok) {
663 case tok_eof:
664 return;
665 case ';': // ignore top-level semicolons.
666 getNextToken();
667 break;
668 case tok_def:
669 HandleDefinition();
670 break;
671 case tok_extern:
672 HandleExtern();
673 break;
674 default:
675 HandleTopLevelExpression();
676 break;
681 //===----------------------------------------------------------------------===//
682 // "Library" functions that can be "extern'd" from user code.
683 //===----------------------------------------------------------------------===//
685 #ifdef _WIN32
686 #define DLLEXPORT __declspec(dllexport)
687 #else
688 #define DLLEXPORT
689 #endif
691 /// putchard - putchar that takes a double and returns 0.
692 extern "C" DLLEXPORT double putchard(double X) {
693 fputc((char)X, stderr);
694 return 0;
697 /// printd - printf that takes a double prints it as "%f\n", returning 0.
698 extern "C" DLLEXPORT double printd(double X) {
699 fprintf(stderr, "%f\n", X);
700 return 0;
703 //===----------------------------------------------------------------------===//
704 // Main driver code.
705 //===----------------------------------------------------------------------===//
707 int main() {
708 InitializeNativeTarget();
709 InitializeNativeTargetAsmPrinter();
710 InitializeNativeTargetAsmParser();
712 // Install standard binary operators.
713 // 1 is lowest precedence.
714 BinopPrecedence['<'] = 10;
715 BinopPrecedence['+'] = 20;
716 BinopPrecedence['-'] = 20;
717 BinopPrecedence['*'] = 40; // highest.
719 // Prime the first token.
720 fprintf(stderr, "ready> ");
721 getNextToken();
723 TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
725 InitializeModuleAndManagers();
727 // Run the main "interpreter loop" now.
728 MainLoop();
730 return 0;