Fix test failures introduced by PR #113697 (#116941)
[llvm-project.git] / llvm / examples / Kaleidoscope / Chapter6 / toy.cpp
blobebe4322287b21f3afd84f448a2ac8bccf439ea50
1 #include "../include/KaleidoscopeJIT.h"
2 #include "llvm/ADT/APFloat.h"
3 #include "llvm/ADT/STLExtras.h"
4 #include "llvm/IR/BasicBlock.h"
5 #include "llvm/IR/Constants.h"
6 #include "llvm/IR/DerivedTypes.h"
7 #include "llvm/IR/Function.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/Instructions.h"
10 #include "llvm/IR/LLVMContext.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/IR/Type.h"
14 #include "llvm/IR/Verifier.h"
15 #include "llvm/Passes/PassBuilder.h"
16 #include "llvm/Passes/StandardInstrumentations.h"
17 #include "llvm/Support/TargetSelect.h"
18 #include "llvm/Target/TargetMachine.h"
19 #include "llvm/Transforms/InstCombine/InstCombine.h"
20 #include "llvm/Transforms/Scalar.h"
21 #include "llvm/Transforms/Scalar/GVN.h"
22 #include "llvm/Transforms/Scalar/Reassociate.h"
23 #include "llvm/Transforms/Scalar/SimplifyCFG.h"
24 #include <algorithm>
25 #include <cassert>
26 #include <cctype>
27 #include <cstdint>
28 #include <cstdio>
29 #include <cstdlib>
30 #include <map>
31 #include <memory>
32 #include <string>
33 #include <vector>
35 using namespace llvm;
36 using namespace llvm::orc;
38 //===----------------------------------------------------------------------===//
39 // Lexer
40 //===----------------------------------------------------------------------===//
/// Token codes produced by the lexer. Every known token has a negative code;
/// an unrecognised character is returned as its own character value [0-255].
enum Token {
  // End of input.
  tok_eof = -1,

  // Top-level commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // Keywords that introduce user-defined operators.
  tok_binary = -11,
  tok_unary = -12
};
/// Spelling of the word most recently lexed by gettok (meaningful when the
/// returned token is tok_identifier).
static std::string IdentifierStr;

/// Value of the numeric literal most recently lexed by gettok (meaningful when
/// the returned token is tok_number).
static double NumVal;
70 /// gettok - Return the next token from standard input.
71 static int gettok() {
72 static int LastChar = ' ';
74 // Skip any whitespace.
75 while (isspace(LastChar))
76 LastChar = getchar();
78 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
79 IdentifierStr = LastChar;
80 while (isalnum((LastChar = getchar())))
81 IdentifierStr += LastChar;
83 if (IdentifierStr == "def")
84 return tok_def;
85 if (IdentifierStr == "extern")
86 return tok_extern;
87 if (IdentifierStr == "if")
88 return tok_if;
89 if (IdentifierStr == "then")
90 return tok_then;
91 if (IdentifierStr == "else")
92 return tok_else;
93 if (IdentifierStr == "for")
94 return tok_for;
95 if (IdentifierStr == "in")
96 return tok_in;
97 if (IdentifierStr == "binary")
98 return tok_binary;
99 if (IdentifierStr == "unary")
100 return tok_unary;
101 return tok_identifier;
104 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
105 std::string NumStr;
106 do {
107 NumStr += LastChar;
108 LastChar = getchar();
109 } while (isdigit(LastChar) || LastChar == '.');
111 NumVal = strtod(NumStr.c_str(), nullptr);
112 return tok_number;
115 if (LastChar == '#') {
116 // Comment until end of line.
118 LastChar = getchar();
119 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
121 if (LastChar != EOF)
122 return gettok();
125 // Check for end of file. Don't eat the EOF.
126 if (LastChar == EOF)
127 return tok_eof;
129 // Otherwise, just return the character as its ascii value.
130 int ThisChar = LastChar;
131 LastChar = getchar();
132 return ThisChar;
135 //===----------------------------------------------------------------------===//
136 // Abstract Syntax Tree (aka Parse Tree)
137 //===----------------------------------------------------------------------===//
139 namespace {
141 /// ExprAST - Base class for all expression nodes.
142 class ExprAST {
143 public:
144 virtual ~ExprAST() = default;
146 virtual Value *codegen() = 0;
149 /// NumberExprAST - Expression class for numeric literals like "1.0".
150 class NumberExprAST : public ExprAST {
151 double Val;
153 public:
154 NumberExprAST(double Val) : Val(Val) {}
156 Value *codegen() override;
159 /// VariableExprAST - Expression class for referencing a variable, like "a".
160 class VariableExprAST : public ExprAST {
161 std::string Name;
163 public:
164 VariableExprAST(const std::string &Name) : Name(Name) {}
166 Value *codegen() override;
169 /// UnaryExprAST - Expression class for a unary operator.
170 class UnaryExprAST : public ExprAST {
171 char Opcode;
172 std::unique_ptr<ExprAST> Operand;
174 public:
175 UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
176 : Opcode(Opcode), Operand(std::move(Operand)) {}
178 Value *codegen() override;
181 /// BinaryExprAST - Expression class for a binary operator.
182 class BinaryExprAST : public ExprAST {
183 char Op;
184 std::unique_ptr<ExprAST> LHS, RHS;
186 public:
187 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
188 std::unique_ptr<ExprAST> RHS)
189 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
191 Value *codegen() override;
194 /// CallExprAST - Expression class for function calls.
195 class CallExprAST : public ExprAST {
196 std::string Callee;
197 std::vector<std::unique_ptr<ExprAST>> Args;
199 public:
200 CallExprAST(const std::string &Callee,
201 std::vector<std::unique_ptr<ExprAST>> Args)
202 : Callee(Callee), Args(std::move(Args)) {}
204 Value *codegen() override;
207 /// IfExprAST - Expression class for if/then/else.
208 class IfExprAST : public ExprAST {
209 std::unique_ptr<ExprAST> Cond, Then, Else;
211 public:
212 IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
213 std::unique_ptr<ExprAST> Else)
214 : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
216 Value *codegen() override;
219 /// ForExprAST - Expression class for for/in.
220 class ForExprAST : public ExprAST {
221 std::string VarName;
222 std::unique_ptr<ExprAST> Start, End, Step, Body;
224 public:
225 ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
226 std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
227 std::unique_ptr<ExprAST> Body)
228 : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
229 Step(std::move(Step)), Body(std::move(Body)) {}
231 Value *codegen() override;
234 /// PrototypeAST - This class represents the "prototype" for a function,
235 /// which captures its name, and its argument names (thus implicitly the number
236 /// of arguments the function takes), as well as if it is an operator.
237 class PrototypeAST {
238 std::string Name;
239 std::vector<std::string> Args;
240 bool IsOperator;
241 unsigned Precedence; // Precedence if a binary op.
243 public:
244 PrototypeAST(const std::string &Name, std::vector<std::string> Args,
245 bool IsOperator = false, unsigned Prec = 0)
246 : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
247 Precedence(Prec) {}
249 Function *codegen();
250 const std::string &getName() const { return Name; }
252 bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
253 bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
255 char getOperatorName() const {
256 assert(isUnaryOp() || isBinaryOp());
257 return Name[Name.size() - 1];
260 unsigned getBinaryPrecedence() const { return Precedence; }
263 /// FunctionAST - This class represents a function definition itself.
264 class FunctionAST {
265 std::unique_ptr<PrototypeAST> Proto;
266 std::unique_ptr<ExprAST> Body;
268 public:
269 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
270 std::unique_ptr<ExprAST> Body)
271 : Proto(std::move(Proto)), Body(std::move(Body)) {}
273 Function *codegen();
276 } // end anonymous namespace
278 //===----------------------------------------------------------------------===//
279 // Parser
280 //===----------------------------------------------------------------------===//
282 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
283 /// token the parser is looking at. getNextToken reads another token from the
284 /// lexer and updates CurTok with its results.
285 static int CurTok;
286 static int getNextToken() { return CurTok = gettok(); }
288 /// BinopPrecedence - This holds the precedence for each binary operator that is
289 /// defined.
290 static std::map<char, int> BinopPrecedence;
292 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
293 static int GetTokPrecedence() {
294 if (!isascii(CurTok))
295 return -1;
297 // Make sure it's a declared binop.
298 int TokPrec = BinopPrecedence[CurTok];
299 if (TokPrec <= 0)
300 return -1;
301 return TokPrec;
304 /// Error* - These are little helper functions for error handling.
305 std::unique_ptr<ExprAST> LogError(const char *Str) {
306 fprintf(stderr, "Error: %s\n", Str);
307 return nullptr;
310 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
311 LogError(Str);
312 return nullptr;
315 static std::unique_ptr<ExprAST> ParseExpression();
317 /// numberexpr ::= number
318 static std::unique_ptr<ExprAST> ParseNumberExpr() {
319 auto Result = std::make_unique<NumberExprAST>(NumVal);
320 getNextToken(); // consume the number
321 return std::move(Result);
324 /// parenexpr ::= '(' expression ')'
325 static std::unique_ptr<ExprAST> ParseParenExpr() {
326 getNextToken(); // eat (.
327 auto V = ParseExpression();
328 if (!V)
329 return nullptr;
331 if (CurTok != ')')
332 return LogError("expected ')'");
333 getNextToken(); // eat ).
334 return V;
337 /// identifierexpr
338 /// ::= identifier
339 /// ::= identifier '(' expression* ')'
340 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
341 std::string IdName = IdentifierStr;
343 getNextToken(); // eat identifier.
345 if (CurTok != '(') // Simple variable ref.
346 return std::make_unique<VariableExprAST>(IdName);
348 // Call.
349 getNextToken(); // eat (
350 std::vector<std::unique_ptr<ExprAST>> Args;
351 if (CurTok != ')') {
352 while (true) {
353 if (auto Arg = ParseExpression())
354 Args.push_back(std::move(Arg));
355 else
356 return nullptr;
358 if (CurTok == ')')
359 break;
361 if (CurTok != ',')
362 return LogError("Expected ')' or ',' in argument list");
363 getNextToken();
367 // Eat the ')'.
368 getNextToken();
370 return std::make_unique<CallExprAST>(IdName, std::move(Args));
373 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
374 static std::unique_ptr<ExprAST> ParseIfExpr() {
375 getNextToken(); // eat the if.
377 // condition.
378 auto Cond = ParseExpression();
379 if (!Cond)
380 return nullptr;
382 if (CurTok != tok_then)
383 return LogError("expected then");
384 getNextToken(); // eat the then
386 auto Then = ParseExpression();
387 if (!Then)
388 return nullptr;
390 if (CurTok != tok_else)
391 return LogError("expected else");
393 getNextToken();
395 auto Else = ParseExpression();
396 if (!Else)
397 return nullptr;
399 return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
400 std::move(Else));
403 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
404 static std::unique_ptr<ExprAST> ParseForExpr() {
405 getNextToken(); // eat the for.
407 if (CurTok != tok_identifier)
408 return LogError("expected identifier after for");
410 std::string IdName = IdentifierStr;
411 getNextToken(); // eat identifier.
413 if (CurTok != '=')
414 return LogError("expected '=' after for");
415 getNextToken(); // eat '='.
417 auto Start = ParseExpression();
418 if (!Start)
419 return nullptr;
420 if (CurTok != ',')
421 return LogError("expected ',' after for start value");
422 getNextToken();
424 auto End = ParseExpression();
425 if (!End)
426 return nullptr;
428 // The step value is optional.
429 std::unique_ptr<ExprAST> Step;
430 if (CurTok == ',') {
431 getNextToken();
432 Step = ParseExpression();
433 if (!Step)
434 return nullptr;
437 if (CurTok != tok_in)
438 return LogError("expected 'in' after for");
439 getNextToken(); // eat 'in'.
441 auto Body = ParseExpression();
442 if (!Body)
443 return nullptr;
445 return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
446 std::move(Step), std::move(Body));
449 /// primary
450 /// ::= identifierexpr
451 /// ::= numberexpr
452 /// ::= parenexpr
453 /// ::= ifexpr
454 /// ::= forexpr
455 static std::unique_ptr<ExprAST> ParsePrimary() {
456 switch (CurTok) {
457 default:
458 return LogError("unknown token when expecting an expression");
459 case tok_identifier:
460 return ParseIdentifierExpr();
461 case tok_number:
462 return ParseNumberExpr();
463 case '(':
464 return ParseParenExpr();
465 case tok_if:
466 return ParseIfExpr();
467 case tok_for:
468 return ParseForExpr();
472 /// unary
473 /// ::= primary
474 /// ::= '!' unary
475 static std::unique_ptr<ExprAST> ParseUnary() {
476 // If the current token is not an operator, it must be a primary expr.
477 if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
478 return ParsePrimary();
480 // If this is a unary operator, read it.
481 int Opc = CurTok;
482 getNextToken();
483 if (auto Operand = ParseUnary())
484 return std::make_unique<UnaryExprAST>(Opc, std::move(Operand));
485 return nullptr;
488 /// binoprhs
489 /// ::= ('+' unary)*
490 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
491 std::unique_ptr<ExprAST> LHS) {
492 // If this is a binop, find its precedence.
493 while (true) {
494 int TokPrec = GetTokPrecedence();
496 // If this is a binop that binds at least as tightly as the current binop,
497 // consume it, otherwise we are done.
498 if (TokPrec < ExprPrec)
499 return LHS;
501 // Okay, we know this is a binop.
502 int BinOp = CurTok;
503 getNextToken(); // eat binop
505 // Parse the unary expression after the binary operator.
506 auto RHS = ParseUnary();
507 if (!RHS)
508 return nullptr;
510 // If BinOp binds less tightly with RHS than the operator after RHS, let
511 // the pending operator take RHS as its LHS.
512 int NextPrec = GetTokPrecedence();
513 if (TokPrec < NextPrec) {
514 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
515 if (!RHS)
516 return nullptr;
519 // Merge LHS/RHS.
520 LHS =
521 std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
525 /// expression
526 /// ::= unary binoprhs
528 static std::unique_ptr<ExprAST> ParseExpression() {
529 auto LHS = ParseUnary();
530 if (!LHS)
531 return nullptr;
533 return ParseBinOpRHS(0, std::move(LHS));
536 /// prototype
537 /// ::= id '(' id* ')'
538 /// ::= binary LETTER number? (id, id)
539 /// ::= unary LETTER (id)
540 static std::unique_ptr<PrototypeAST> ParsePrototype() {
541 std::string FnName;
543 unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
544 unsigned BinaryPrecedence = 30;
546 switch (CurTok) {
547 default:
548 return LogErrorP("Expected function name in prototype");
549 case tok_identifier:
550 FnName = IdentifierStr;
551 Kind = 0;
552 getNextToken();
553 break;
554 case tok_unary:
555 getNextToken();
556 if (!isascii(CurTok))
557 return LogErrorP("Expected unary operator");
558 FnName = "unary";
559 FnName += (char)CurTok;
560 Kind = 1;
561 getNextToken();
562 break;
563 case tok_binary:
564 getNextToken();
565 if (!isascii(CurTok))
566 return LogErrorP("Expected binary operator");
567 FnName = "binary";
568 FnName += (char)CurTok;
569 Kind = 2;
570 getNextToken();
572 // Read the precedence if present.
573 if (CurTok == tok_number) {
574 if (NumVal < 1 || NumVal > 100)
575 return LogErrorP("Invalid precedence: must be 1..100");
576 BinaryPrecedence = (unsigned)NumVal;
577 getNextToken();
579 break;
582 if (CurTok != '(')
583 return LogErrorP("Expected '(' in prototype");
585 std::vector<std::string> ArgNames;
586 while (getNextToken() == tok_identifier)
587 ArgNames.push_back(IdentifierStr);
588 if (CurTok != ')')
589 return LogErrorP("Expected ')' in prototype");
591 // success.
592 getNextToken(); // eat ')'.
594 // Verify right number of names for operator.
595 if (Kind && ArgNames.size() != Kind)
596 return LogErrorP("Invalid number of operands for operator");
598 return std::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
599 BinaryPrecedence);
602 /// definition ::= 'def' prototype expression
603 static std::unique_ptr<FunctionAST> ParseDefinition() {
604 getNextToken(); // eat def.
605 auto Proto = ParsePrototype();
606 if (!Proto)
607 return nullptr;
609 if (auto E = ParseExpression())
610 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
611 return nullptr;
614 /// toplevelexpr ::= expression
615 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
616 if (auto E = ParseExpression()) {
617 // Make an anonymous proto.
618 auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
619 std::vector<std::string>());
620 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
622 return nullptr;
625 /// external ::= 'extern' prototype
626 static std::unique_ptr<PrototypeAST> ParseExtern() {
627 getNextToken(); // eat extern.
628 return ParsePrototype();
631 //===----------------------------------------------------------------------===//
632 // Code Generation
633 //===----------------------------------------------------------------------===//
635 static std::unique_ptr<LLVMContext> TheContext;
636 static std::unique_ptr<Module> TheModule;
637 static std::unique_ptr<IRBuilder<>> Builder;
638 static std::map<std::string, Value *> NamedValues;
639 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
640 static std::unique_ptr<FunctionPassManager> TheFPM;
641 static std::unique_ptr<LoopAnalysisManager> TheLAM;
642 static std::unique_ptr<FunctionAnalysisManager> TheFAM;
643 static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
644 static std::unique_ptr<ModuleAnalysisManager> TheMAM;
645 static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
646 static std::unique_ptr<StandardInstrumentations> TheSI;
647 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
648 static ExitOnError ExitOnErr;
650 Value *LogErrorV(const char *Str) {
651 LogError(Str);
652 return nullptr;
655 Function *getFunction(std::string Name) {
656 // First, see if the function has already been added to the current module.
657 if (auto *F = TheModule->getFunction(Name))
658 return F;
660 // If not, check whether we can codegen the declaration from some existing
661 // prototype.
662 auto FI = FunctionProtos.find(Name);
663 if (FI != FunctionProtos.end())
664 return FI->second->codegen();
666 // If no existing prototype exists, return null.
667 return nullptr;
670 Value *NumberExprAST::codegen() {
671 return ConstantFP::get(*TheContext, APFloat(Val));
674 Value *VariableExprAST::codegen() {
675 // Look this variable up in the function.
676 Value *V = NamedValues[Name];
677 if (!V)
678 return LogErrorV("Unknown variable name");
679 return V;
682 Value *UnaryExprAST::codegen() {
683 Value *OperandV = Operand->codegen();
684 if (!OperandV)
685 return nullptr;
687 Function *F = getFunction(std::string("unary") + Opcode);
688 if (!F)
689 return LogErrorV("Unknown unary operator");
691 return Builder->CreateCall(F, OperandV, "unop");
694 Value *BinaryExprAST::codegen() {
695 Value *L = LHS->codegen();
696 Value *R = RHS->codegen();
697 if (!L || !R)
698 return nullptr;
700 switch (Op) {
701 case '+':
702 return Builder->CreateFAdd(L, R, "addtmp");
703 case '-':
704 return Builder->CreateFSub(L, R, "subtmp");
705 case '*':
706 return Builder->CreateFMul(L, R, "multmp");
707 case '<':
708 L = Builder->CreateFCmpULT(L, R, "cmptmp");
709 // Convert bool 0/1 to double 0.0 or 1.0
710 return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
711 default:
712 break;
715 // If it wasn't a builtin binary operator, it must be a user defined one. Emit
716 // a call to it.
717 Function *F = getFunction(std::string("binary") + Op);
718 assert(F && "binary operator not found!");
720 Value *Ops[] = {L, R};
721 return Builder->CreateCall(F, Ops, "binop");
724 Value *CallExprAST::codegen() {
725 // Look up the name in the global module table.
726 Function *CalleeF = getFunction(Callee);
727 if (!CalleeF)
728 return LogErrorV("Unknown function referenced");
730 // If argument mismatch error.
731 if (CalleeF->arg_size() != Args.size())
732 return LogErrorV("Incorrect # arguments passed");
734 std::vector<Value *> ArgsV;
735 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
736 ArgsV.push_back(Args[i]->codegen());
737 if (!ArgsV.back())
738 return nullptr;
741 return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
744 Value *IfExprAST::codegen() {
745 Value *CondV = Cond->codegen();
746 if (!CondV)
747 return nullptr;
749 // Convert condition to a bool by comparing non-equal to 0.0.
750 CondV = Builder->CreateFCmpONE(
751 CondV, ConstantFP::get(*TheContext, APFloat(0.0)), "ifcond");
753 Function *TheFunction = Builder->GetInsertBlock()->getParent();
755 // Create blocks for the then and else cases. Insert the 'then' block at the
756 // end of the function.
757 BasicBlock *ThenBB = BasicBlock::Create(*TheContext, "then", TheFunction);
758 BasicBlock *ElseBB = BasicBlock::Create(*TheContext, "else");
759 BasicBlock *MergeBB = BasicBlock::Create(*TheContext, "ifcont");
761 Builder->CreateCondBr(CondV, ThenBB, ElseBB);
763 // Emit then value.
764 Builder->SetInsertPoint(ThenBB);
766 Value *ThenV = Then->codegen();
767 if (!ThenV)
768 return nullptr;
770 Builder->CreateBr(MergeBB);
771 // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
772 ThenBB = Builder->GetInsertBlock();
774 // Emit else block.
775 TheFunction->insert(TheFunction->end(), ElseBB);
776 Builder->SetInsertPoint(ElseBB);
778 Value *ElseV = Else->codegen();
779 if (!ElseV)
780 return nullptr;
782 Builder->CreateBr(MergeBB);
783 // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
784 ElseBB = Builder->GetInsertBlock();
786 // Emit merge block.
787 TheFunction->insert(TheFunction->end(), MergeBB);
788 Builder->SetInsertPoint(MergeBB);
789 PHINode *PN = Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, "iftmp");
791 PN->addIncoming(ThenV, ThenBB);
792 PN->addIncoming(ElseV, ElseBB);
793 return PN;
796 // Output for-loop as:
797 // ...
798 // start = startexpr
799 // goto loop
800 // loop:
801 // variable = phi [start, loopheader], [nextvariable, loopend]
802 // ...
803 // bodyexpr
804 // ...
805 // loopend:
806 // step = stepexpr
807 // nextvariable = variable + step
808 // endcond = endexpr
809 // br endcond, loop, endloop
810 // outloop:
811 Value *ForExprAST::codegen() {
812 // Emit the start code first, without 'variable' in scope.
813 Value *StartVal = Start->codegen();
814 if (!StartVal)
815 return nullptr;
817 // Make the new basic block for the loop header, inserting after current
818 // block.
819 Function *TheFunction = Builder->GetInsertBlock()->getParent();
820 BasicBlock *PreheaderBB = Builder->GetInsertBlock();
821 BasicBlock *LoopBB = BasicBlock::Create(*TheContext, "loop", TheFunction);
823 // Insert an explicit fall through from the current block to the LoopBB.
824 Builder->CreateBr(LoopBB);
826 // Start insertion in LoopBB.
827 Builder->SetInsertPoint(LoopBB);
829 // Start the PHI node with an entry for Start.
830 PHINode *Variable =
831 Builder->CreatePHI(Type::getDoubleTy(*TheContext), 2, VarName);
832 Variable->addIncoming(StartVal, PreheaderBB);
834 // Within the loop, the variable is defined equal to the PHI node. If it
835 // shadows an existing variable, we have to restore it, so save it now.
836 Value *OldVal = NamedValues[VarName];
837 NamedValues[VarName] = Variable;
839 // Emit the body of the loop. This, like any other expr, can change the
840 // current BB. Note that we ignore the value computed by the body, but don't
841 // allow an error.
842 if (!Body->codegen())
843 return nullptr;
845 // Emit the step value.
846 Value *StepVal = nullptr;
847 if (Step) {
848 StepVal = Step->codegen();
849 if (!StepVal)
850 return nullptr;
851 } else {
852 // If not specified, use 1.0.
853 StepVal = ConstantFP::get(*TheContext, APFloat(1.0));
856 Value *NextVar = Builder->CreateFAdd(Variable, StepVal, "nextvar");
858 // Compute the end condition.
859 Value *EndCond = End->codegen();
860 if (!EndCond)
861 return nullptr;
863 // Convert condition to a bool by comparing non-equal to 0.0.
864 EndCond = Builder->CreateFCmpONE(
865 EndCond, ConstantFP::get(*TheContext, APFloat(0.0)), "loopcond");
867 // Create the "after loop" block and insert it.
868 BasicBlock *LoopEndBB = Builder->GetInsertBlock();
869 BasicBlock *AfterBB =
870 BasicBlock::Create(*TheContext, "afterloop", TheFunction);
872 // Insert the conditional branch into the end of LoopEndBB.
873 Builder->CreateCondBr(EndCond, LoopBB, AfterBB);
875 // Any new code will be inserted in AfterBB.
876 Builder->SetInsertPoint(AfterBB);
878 // Add a new entry to the PHI node for the backedge.
879 Variable->addIncoming(NextVar, LoopEndBB);
881 // Restore the unshadowed variable.
882 if (OldVal)
883 NamedValues[VarName] = OldVal;
884 else
885 NamedValues.erase(VarName);
887 // for expr always returns 0.0.
888 return Constant::getNullValue(Type::getDoubleTy(*TheContext));
891 Function *PrototypeAST::codegen() {
892 // Make the function type: double(double,double) etc.
893 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
894 FunctionType *FT =
895 FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
897 Function *F =
898 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
900 // Set names for all arguments.
901 unsigned Idx = 0;
902 for (auto &Arg : F->args())
903 Arg.setName(Args[Idx++]);
905 return F;
908 Function *FunctionAST::codegen() {
909 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
910 // reference to it for use below.
911 auto &P = *Proto;
912 FunctionProtos[Proto->getName()] = std::move(Proto);
913 Function *TheFunction = getFunction(P.getName());
914 if (!TheFunction)
915 return nullptr;
917 // If this is an operator, install it.
918 if (P.isBinaryOp())
919 BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
921 // Create a new basic block to start insertion into.
922 BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
923 Builder->SetInsertPoint(BB);
925 // Record the function arguments in the NamedValues map.
926 NamedValues.clear();
927 for (auto &Arg : TheFunction->args())
928 NamedValues[std::string(Arg.getName())] = &Arg;
930 if (Value *RetVal = Body->codegen()) {
931 // Finish off the function.
932 Builder->CreateRet(RetVal);
934 // Validate the generated code, checking for consistency.
935 verifyFunction(*TheFunction);
937 // Run the optimizer on the function.
938 TheFPM->run(*TheFunction, *TheFAM);
940 return TheFunction;
943 // Error reading body, remove function.
944 TheFunction->eraseFromParent();
946 if (P.isBinaryOp())
947 BinopPrecedence.erase(P.getOperatorName());
948 return nullptr;
951 //===----------------------------------------------------------------------===//
952 // Top-Level parsing and JIT Driver
953 //===----------------------------------------------------------------------===//
955 static void InitializeModuleAndManagers() {
956 // Open a new context and module.
957 TheContext = std::make_unique<LLVMContext>();
958 TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
959 TheModule->setDataLayout(TheJIT->getDataLayout());
961 // Create a new builder for the module.
962 Builder = std::make_unique<IRBuilder<>>(*TheContext);
964 // Create new pass and analysis managers.
965 TheFPM = std::make_unique<FunctionPassManager>();
966 TheLAM = std::make_unique<LoopAnalysisManager>();
967 TheFAM = std::make_unique<FunctionAnalysisManager>();
968 TheCGAM = std::make_unique<CGSCCAnalysisManager>();
969 TheMAM = std::make_unique<ModuleAnalysisManager>();
970 ThePIC = std::make_unique<PassInstrumentationCallbacks>();
971 TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
972 /*DebugLogging*/ true);
973 TheSI->registerCallbacks(*ThePIC, TheMAM.get());
975 // Add transform passes.
976 // Do simple "peephole" optimizations and bit-twiddling optzns.
977 TheFPM->addPass(InstCombinePass());
978 // Reassociate expressions.
979 TheFPM->addPass(ReassociatePass());
980 // Eliminate Common SubExpressions.
981 TheFPM->addPass(GVNPass());
982 // Simplify the control flow graph (deleting unreachable blocks, etc).
983 TheFPM->addPass(SimplifyCFGPass());
985 // Register analysis passes used in these transform passes.
986 PassBuilder PB;
987 PB.registerModuleAnalyses(*TheMAM);
988 PB.registerFunctionAnalyses(*TheFAM);
989 PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
992 static void HandleDefinition() {
993 if (auto FnAST = ParseDefinition()) {
994 if (auto *FnIR = FnAST->codegen()) {
995 fprintf(stderr, "Read function definition:");
996 FnIR->print(errs());
997 fprintf(stderr, "\n");
998 ExitOnErr(TheJIT->addModule(
999 ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
1000 InitializeModuleAndManagers();
1002 } else {
1003 // Skip token for error recovery.
1004 getNextToken();
1008 static void HandleExtern() {
1009 if (auto ProtoAST = ParseExtern()) {
1010 if (auto *FnIR = ProtoAST->codegen()) {
1011 fprintf(stderr, "Read extern: ");
1012 FnIR->print(errs());
1013 fprintf(stderr, "\n");
1014 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1016 } else {
1017 // Skip token for error recovery.
1018 getNextToken();
1022 static void HandleTopLevelExpression() {
1023 // Evaluate a top-level expression into an anonymous function.
1024 if (auto FnAST = ParseTopLevelExpr()) {
1025 if (FnAST->codegen()) {
1026 // Create a ResourceTracker to track JIT'd memory allocated to our
1027 // anonymous expression -- that way we can free it after executing.
1028 auto RT = TheJIT->getMainJITDylib().createResourceTracker();
1030 auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
1031 ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
1032 InitializeModuleAndManagers();
1034 // Search the JIT for the __anon_expr symbol.
1035 auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
1037 // Get the symbol's address and cast it to the right type (takes no
1038 // arguments, returns a double) so we can call it as a native function.
1039 double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
1040 fprintf(stderr, "Evaluated to %f\n", FP());
1042 // Delete the anonymous expression module from the JIT.
1043 ExitOnErr(RT->remove());
1045 } else {
1046 // Skip token for error recovery.
1047 getNextToken();
1051 /// top ::= definition | external | expression | ';'
1052 static void MainLoop() {
1053 while (true) {
1054 fprintf(stderr, "ready> ");
1055 switch (CurTok) {
1056 case tok_eof:
1057 return;
1058 case ';': // ignore top-level semicolons.
1059 getNextToken();
1060 break;
1061 case tok_def:
1062 HandleDefinition();
1063 break;
1064 case tok_extern:
1065 HandleExtern();
1066 break;
1067 default:
1068 HandleTopLevelExpression();
1069 break;
1074 //===----------------------------------------------------------------------===//
1075 // "Library" functions that can be "extern'd" from user code.
1076 //===----------------------------------------------------------------------===//
// DLLEXPORT expands to __declspec(dllexport) when _WIN32 is defined and to
// nothing otherwise.
#if defined(_WIN32)
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
1084 /// putchard - putchar that takes a double and returns 0.
1085 extern "C" DLLEXPORT double putchard(double X) {
1086 fputc((char)X, stderr);
1087 return 0;
1090 /// printd - printf that takes a double prints it as "%f\n", returning 0.
1091 extern "C" DLLEXPORT double printd(double X) {
1092 fprintf(stderr, "%f\n", X);
1093 return 0;
1096 //===----------------------------------------------------------------------===//
1097 // Main driver code.
1098 //===----------------------------------------------------------------------===//
1100 int main() {
1101 InitializeNativeTarget();
1102 InitializeNativeTargetAsmPrinter();
1103 InitializeNativeTargetAsmParser();
1105 // Install standard binary operators.
1106 // 1 is lowest precedence.
1107 BinopPrecedence['<'] = 10;
1108 BinopPrecedence['+'] = 20;
1109 BinopPrecedence['-'] = 20;
1110 BinopPrecedence['*'] = 40; // highest.
1112 // Prime the first token.
1113 fprintf(stderr, "ready> ");
1114 getNextToken();
1116 TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
1118 InitializeModuleAndManagers();
1120 // Run the main "interpreter loop" now.
1121 MainLoop();
1123 return 0;