[ARM] MVE sext and widen/narrow tests from larger types. NFC
[llvm-core.git] / examples / Kaleidoscope / Chapter6 / toy.cpp
blob 5b3dd5a6c4e3ee28415a6960a2cede4ba375974e
1 #include "../include/KaleidoscopeJIT.h"
2 #include "llvm/ADT/APFloat.h"
3 #include "llvm/ADT/STLExtras.h"
4 #include "llvm/IR/BasicBlock.h"
5 #include "llvm/IR/Constants.h"
6 #include "llvm/IR/DerivedTypes.h"
7 #include "llvm/IR/Function.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/Instructions.h"
10 #include "llvm/IR/LLVMContext.h"
11 #include "llvm/IR/LegacyPassManager.h"
12 #include "llvm/IR/Module.h"
13 #include "llvm/IR/Type.h"
14 #include "llvm/IR/Verifier.h"
15 #include "llvm/Support/TargetSelect.h"
16 #include "llvm/Target/TargetMachine.h"
17 #include "llvm/Transforms/InstCombine/InstCombine.h"
18 #include "llvm/Transforms/Scalar.h"
19 #include "llvm/Transforms/Scalar/GVN.h"
20 #include <algorithm>
21 #include <cassert>
22 #include <cctype>
23 #include <cstdint>
24 #include <cstdio>
25 #include <cstdlib>
26 #include <map>
27 #include <memory>
28 #include <string>
29 #include <vector>
31 using namespace llvm;
32 using namespace llvm::orc;
34 //===----------------------------------------------------------------------===//
35 // Lexer
36 //===----------------------------------------------------------------------===//
// The lexer hands back either the raw character value in [0-255] for anything
// it does not recognise, or one of the negative Token values below for the
// constructs it knows about.
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5,

  // control
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // operators
  tok_binary = -11,
  tok_unary = -12
};

static std::string IdentifierStr; // Spelling of the most recent tok_identifier.
static double NumVal;             // Value of the most recent tok_number.

/// gettok - Lex and return the next token from standard input.
///
/// Identifiers and numbers additionally store their spelling/value in
/// IdentifierStr / NumVal. A '#' comment runs to the end of the line and is
/// skipped by looping back to the top rather than by recursing, so a long run
/// of comment lines does not deepen the call stack.
static int gettok() {
  // One character of look-ahead that survives between calls.
  static int LastChar = ' ';

  // Reserved words, tried in this order once an identifier has been read.
  struct Keyword {
    const char *Spelling;
    int Kind;
  };
  static const Keyword Keywords[] = {
      {"def", tok_def},   {"extern", tok_extern}, {"if", tok_if},
      {"then", tok_then}, {"else", tok_else},     {"for", tok_for},
      {"in", tok_in},     {"binary", tok_binary}, {"unary", tok_unary}};

  for (;;) {
    // Whitespace only separates tokens.
    while (isspace(LastChar))
      LastChar = getchar();

    // identifier: [a-zA-Z][a-zA-Z0-9]*
    if (isalpha(LastChar)) {
      IdentifierStr = static_cast<char>(LastChar);
      for (LastChar = getchar(); isalnum(LastChar); LastChar = getchar())
        IdentifierStr += static_cast<char>(LastChar);

      for (const Keyword &KW : Keywords)
        if (IdentifierStr == KW.Spelling)
          return KW.Kind;
      return tok_identifier;
    }

    // Number: [0-9.]+ (handed to strtod as-is).
    if (isdigit(LastChar) || LastChar == '.') {
      std::string NumStr;
      while (isdigit(LastChar) || LastChar == '.') {
        NumStr += static_cast<char>(LastChar);
        LastChar = getchar();
      }
      NumVal = strtod(NumStr.c_str(), nullptr);
      return tok_number;
    }

    // Comment: discard up to the end of the line, then start over unless the
    // input ran out inside the comment.
    if (LastChar == '#') {
      do {
        LastChar = getchar();
      } while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
      if (LastChar != EOF)
        continue;
    }

    // End of input. LastChar is left at EOF so later calls report it again.
    if (LastChar == EOF)
      return tok_eof;

    // Anything else is returned as its character value.
    const int ThisChar = LastChar;
    LastChar = getchar();
    return ThisChar;
  }
}
131 //===----------------------------------------------------------------------===//
132 // Abstract Syntax Tree (aka Parse Tree)
133 //===----------------------------------------------------------------------===//
135 namespace {
137 /// ExprAST - Base class for all expression nodes.
138 class ExprAST {
139 public:
140 virtual ~ExprAST() = default;
142 virtual Value *codegen() = 0;
145 /// NumberExprAST - Expression class for numeric literals like "1.0".
146 class NumberExprAST : public ExprAST {
147 double Val;
149 public:
150 NumberExprAST(double Val) : Val(Val) {}
152 Value *codegen() override;
155 /// VariableExprAST - Expression class for referencing a variable, like "a".
156 class VariableExprAST : public ExprAST {
157 std::string Name;
159 public:
160 VariableExprAST(const std::string &Name) : Name(Name) {}
162 Value *codegen() override;
165 /// UnaryExprAST - Expression class for a unary operator.
166 class UnaryExprAST : public ExprAST {
167 char Opcode;
168 std::unique_ptr<ExprAST> Operand;
170 public:
171 UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
172 : Opcode(Opcode), Operand(std::move(Operand)) {}
174 Value *codegen() override;
177 /// BinaryExprAST - Expression class for a binary operator.
178 class BinaryExprAST : public ExprAST {
179 char Op;
180 std::unique_ptr<ExprAST> LHS, RHS;
182 public:
183 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
184 std::unique_ptr<ExprAST> RHS)
185 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
187 Value *codegen() override;
190 /// CallExprAST - Expression class for function calls.
191 class CallExprAST : public ExprAST {
192 std::string Callee;
193 std::vector<std::unique_ptr<ExprAST>> Args;
195 public:
196 CallExprAST(const std::string &Callee,
197 std::vector<std::unique_ptr<ExprAST>> Args)
198 : Callee(Callee), Args(std::move(Args)) {}
200 Value *codegen() override;
203 /// IfExprAST - Expression class for if/then/else.
204 class IfExprAST : public ExprAST {
205 std::unique_ptr<ExprAST> Cond, Then, Else;
207 public:
208 IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
209 std::unique_ptr<ExprAST> Else)
210 : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
212 Value *codegen() override;
215 /// ForExprAST - Expression class for for/in.
216 class ForExprAST : public ExprAST {
217 std::string VarName;
218 std::unique_ptr<ExprAST> Start, End, Step, Body;
220 public:
221 ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
222 std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
223 std::unique_ptr<ExprAST> Body)
224 : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
225 Step(std::move(Step)), Body(std::move(Body)) {}
227 Value *codegen() override;
230 /// PrototypeAST - This class represents the "prototype" for a function,
231 /// which captures its name, and its argument names (thus implicitly the number
232 /// of arguments the function takes), as well as if it is an operator.
233 class PrototypeAST {
234 std::string Name;
235 std::vector<std::string> Args;
236 bool IsOperator;
237 unsigned Precedence; // Precedence if a binary op.
239 public:
240 PrototypeAST(const std::string &Name, std::vector<std::string> Args,
241 bool IsOperator = false, unsigned Prec = 0)
242 : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
243 Precedence(Prec) {}
245 Function *codegen();
246 const std::string &getName() const { return Name; }
248 bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
249 bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
251 char getOperatorName() const {
252 assert(isUnaryOp() || isBinaryOp());
253 return Name[Name.size() - 1];
256 unsigned getBinaryPrecedence() const { return Precedence; }
259 /// FunctionAST - This class represents a function definition itself.
260 class FunctionAST {
261 std::unique_ptr<PrototypeAST> Proto;
262 std::unique_ptr<ExprAST> Body;
264 public:
265 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
266 std::unique_ptr<ExprAST> Body)
267 : Proto(std::move(Proto)), Body(std::move(Body)) {}
269 Function *codegen();
272 } // end anonymous namespace
274 //===----------------------------------------------------------------------===//
275 // Parser
276 //===----------------------------------------------------------------------===//
278 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
279 /// token the parser is looking at. getNextToken reads another token from the
280 /// lexer and updates CurTok with its results.
281 static int CurTok;
282 static int getNextToken() { return CurTok = gettok(); }
284 /// BinopPrecedence - This holds the precedence for each binary operator that is
285 /// defined.
286 static std::map<char, int> BinopPrecedence;
288 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
289 static int GetTokPrecedence() {
290 if (!isascii(CurTok))
291 return -1;
293 // Make sure it's a declared binop.
294 int TokPrec = BinopPrecedence[CurTok];
295 if (TokPrec <= 0)
296 return -1;
297 return TokPrec;
300 /// Error* - These are little helper functions for error handling.
301 std::unique_ptr<ExprAST> LogError(const char *Str) {
302 fprintf(stderr, "Error: %s\n", Str);
303 return nullptr;
306 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
307 LogError(Str);
308 return nullptr;
311 static std::unique_ptr<ExprAST> ParseExpression();
313 /// numberexpr ::= number
314 static std::unique_ptr<ExprAST> ParseNumberExpr() {
315 auto Result = std::make_unique<NumberExprAST>(NumVal);
316 getNextToken(); // consume the number
317 return std::move(Result);
320 /// parenexpr ::= '(' expression ')'
321 static std::unique_ptr<ExprAST> ParseParenExpr() {
322 getNextToken(); // eat (.
323 auto V = ParseExpression();
324 if (!V)
325 return nullptr;
327 if (CurTok != ')')
328 return LogError("expected ')'");
329 getNextToken(); // eat ).
330 return V;
333 /// identifierexpr
334 /// ::= identifier
335 /// ::= identifier '(' expression* ')'
336 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
337 std::string IdName = IdentifierStr;
339 getNextToken(); // eat identifier.
341 if (CurTok != '(') // Simple variable ref.
342 return std::make_unique<VariableExprAST>(IdName);
344 // Call.
345 getNextToken(); // eat (
346 std::vector<std::unique_ptr<ExprAST>> Args;
347 if (CurTok != ')') {
348 while (true) {
349 if (auto Arg = ParseExpression())
350 Args.push_back(std::move(Arg));
351 else
352 return nullptr;
354 if (CurTok == ')')
355 break;
357 if (CurTok != ',')
358 return LogError("Expected ')' or ',' in argument list");
359 getNextToken();
363 // Eat the ')'.
364 getNextToken();
366 return std::make_unique<CallExprAST>(IdName, std::move(Args));
369 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
370 static std::unique_ptr<ExprAST> ParseIfExpr() {
371 getNextToken(); // eat the if.
373 // condition.
374 auto Cond = ParseExpression();
375 if (!Cond)
376 return nullptr;
378 if (CurTok != tok_then)
379 return LogError("expected then");
380 getNextToken(); // eat the then
382 auto Then = ParseExpression();
383 if (!Then)
384 return nullptr;
386 if (CurTok != tok_else)
387 return LogError("expected else");
389 getNextToken();
391 auto Else = ParseExpression();
392 if (!Else)
393 return nullptr;
395 return std::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
396 std::move(Else));
399 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
400 static std::unique_ptr<ExprAST> ParseForExpr() {
401 getNextToken(); // eat the for.
403 if (CurTok != tok_identifier)
404 return LogError("expected identifier after for");
406 std::string IdName = IdentifierStr;
407 getNextToken(); // eat identifier.
409 if (CurTok != '=')
410 return LogError("expected '=' after for");
411 getNextToken(); // eat '='.
413 auto Start = ParseExpression();
414 if (!Start)
415 return nullptr;
416 if (CurTok != ',')
417 return LogError("expected ',' after for start value");
418 getNextToken();
420 auto End = ParseExpression();
421 if (!End)
422 return nullptr;
424 // The step value is optional.
425 std::unique_ptr<ExprAST> Step;
426 if (CurTok == ',') {
427 getNextToken();
428 Step = ParseExpression();
429 if (!Step)
430 return nullptr;
433 if (CurTok != tok_in)
434 return LogError("expected 'in' after for");
435 getNextToken(); // eat 'in'.
437 auto Body = ParseExpression();
438 if (!Body)
439 return nullptr;
441 return std::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
442 std::move(Step), std::move(Body));
445 /// primary
446 /// ::= identifierexpr
447 /// ::= numberexpr
448 /// ::= parenexpr
449 /// ::= ifexpr
450 /// ::= forexpr
451 static std::unique_ptr<ExprAST> ParsePrimary() {
452 switch (CurTok) {
453 default:
454 return LogError("unknown token when expecting an expression");
455 case tok_identifier:
456 return ParseIdentifierExpr();
457 case tok_number:
458 return ParseNumberExpr();
459 case '(':
460 return ParseParenExpr();
461 case tok_if:
462 return ParseIfExpr();
463 case tok_for:
464 return ParseForExpr();
468 /// unary
469 /// ::= primary
470 /// ::= '!' unary
471 static std::unique_ptr<ExprAST> ParseUnary() {
472 // If the current token is not an operator, it must be a primary expr.
473 if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
474 return ParsePrimary();
476 // If this is a unary operator, read it.
477 int Opc = CurTok;
478 getNextToken();
479 if (auto Operand = ParseUnary())
480 return std::make_unique<UnaryExprAST>(Opc, std::move(Operand));
481 return nullptr;
484 /// binoprhs
485 /// ::= ('+' unary)*
486 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
487 std::unique_ptr<ExprAST> LHS) {
488 // If this is a binop, find its precedence.
489 while (true) {
490 int TokPrec = GetTokPrecedence();
492 // If this is a binop that binds at least as tightly as the current binop,
493 // consume it, otherwise we are done.
494 if (TokPrec < ExprPrec)
495 return LHS;
497 // Okay, we know this is a binop.
498 int BinOp = CurTok;
499 getNextToken(); // eat binop
501 // Parse the unary expression after the binary operator.
502 auto RHS = ParseUnary();
503 if (!RHS)
504 return nullptr;
506 // If BinOp binds less tightly with RHS than the operator after RHS, let
507 // the pending operator take RHS as its LHS.
508 int NextPrec = GetTokPrecedence();
509 if (TokPrec < NextPrec) {
510 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
511 if (!RHS)
512 return nullptr;
515 // Merge LHS/RHS.
516 LHS =
517 std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
521 /// expression
522 /// ::= unary binoprhs
524 static std::unique_ptr<ExprAST> ParseExpression() {
525 auto LHS = ParseUnary();
526 if (!LHS)
527 return nullptr;
529 return ParseBinOpRHS(0, std::move(LHS));
532 /// prototype
533 /// ::= id '(' id* ')'
534 /// ::= binary LETTER number? (id, id)
535 /// ::= unary LETTER (id)
536 static std::unique_ptr<PrototypeAST> ParsePrototype() {
537 std::string FnName;
539 unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
540 unsigned BinaryPrecedence = 30;
542 switch (CurTok) {
543 default:
544 return LogErrorP("Expected function name in prototype");
545 case tok_identifier:
546 FnName = IdentifierStr;
547 Kind = 0;
548 getNextToken();
549 break;
550 case tok_unary:
551 getNextToken();
552 if (!isascii(CurTok))
553 return LogErrorP("Expected unary operator");
554 FnName = "unary";
555 FnName += (char)CurTok;
556 Kind = 1;
557 getNextToken();
558 break;
559 case tok_binary:
560 getNextToken();
561 if (!isascii(CurTok))
562 return LogErrorP("Expected binary operator");
563 FnName = "binary";
564 FnName += (char)CurTok;
565 Kind = 2;
566 getNextToken();
568 // Read the precedence if present.
569 if (CurTok == tok_number) {
570 if (NumVal < 1 || NumVal > 100)
571 return LogErrorP("Invalid precedence: must be 1..100");
572 BinaryPrecedence = (unsigned)NumVal;
573 getNextToken();
575 break;
578 if (CurTok != '(')
579 return LogErrorP("Expected '(' in prototype");
581 std::vector<std::string> ArgNames;
582 while (getNextToken() == tok_identifier)
583 ArgNames.push_back(IdentifierStr);
584 if (CurTok != ')')
585 return LogErrorP("Expected ')' in prototype");
587 // success.
588 getNextToken(); // eat ')'.
590 // Verify right number of names for operator.
591 if (Kind && ArgNames.size() != Kind)
592 return LogErrorP("Invalid number of operands for operator");
594 return std::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
595 BinaryPrecedence);
598 /// definition ::= 'def' prototype expression
599 static std::unique_ptr<FunctionAST> ParseDefinition() {
600 getNextToken(); // eat def.
601 auto Proto = ParsePrototype();
602 if (!Proto)
603 return nullptr;
605 if (auto E = ParseExpression())
606 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
607 return nullptr;
610 /// toplevelexpr ::= expression
611 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
612 if (auto E = ParseExpression()) {
613 // Make an anonymous proto.
614 auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
615 std::vector<std::string>());
616 return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
618 return nullptr;
621 /// external ::= 'extern' prototype
622 static std::unique_ptr<PrototypeAST> ParseExtern() {
623 getNextToken(); // eat extern.
624 return ParsePrototype();
627 //===----------------------------------------------------------------------===//
628 // Code Generation
629 //===----------------------------------------------------------------------===//
// Code-generation state shared by all codegen routines. Definition order
// matters: Builder is constructed from TheContext, and these statics are
// destroyed in the reverse of this order at exit.
static LLVMContext TheContext;
// Instruction builder; each codegen routine moves its insertion point.
static IRBuilder<> Builder(TheContext);
// Module currently being filled. It is handed to the JIT and replaced after
// each successfully generated definition or top-level expression.
static std::unique_ptr<Module> TheModule;
// Parameters and loop variables in scope in the function being generated.
static std::map<std::string, Value *> NamedValues;
// Function pass pipeline, recreated together with each new module.
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
// Latest prototype seen for each function name (from 'def' or 'extern'); used
// to re-declare functions in modules created later.
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
639 Value *LogErrorV(const char *Str) {
640 LogError(Str);
641 return nullptr;
644 Function *getFunction(std::string Name) {
645 // First, see if the function has already been added to the current module.
646 if (auto *F = TheModule->getFunction(Name))
647 return F;
649 // If not, check whether we can codegen the declaration from some existing
650 // prototype.
651 auto FI = FunctionProtos.find(Name);
652 if (FI != FunctionProtos.end())
653 return FI->second->codegen();
655 // If no existing prototype exists, return null.
656 return nullptr;
659 Value *NumberExprAST::codegen() {
660 return ConstantFP::get(TheContext, APFloat(Val));
663 Value *VariableExprAST::codegen() {
664 // Look this variable up in the function.
665 Value *V = NamedValues[Name];
666 if (!V)
667 return LogErrorV("Unknown variable name");
668 return V;
671 Value *UnaryExprAST::codegen() {
672 Value *OperandV = Operand->codegen();
673 if (!OperandV)
674 return nullptr;
676 Function *F = getFunction(std::string("unary") + Opcode);
677 if (!F)
678 return LogErrorV("Unknown unary operator");
680 return Builder.CreateCall(F, OperandV, "unop");
683 Value *BinaryExprAST::codegen() {
684 Value *L = LHS->codegen();
685 Value *R = RHS->codegen();
686 if (!L || !R)
687 return nullptr;
689 switch (Op) {
690 case '+':
691 return Builder.CreateFAdd(L, R, "addtmp");
692 case '-':
693 return Builder.CreateFSub(L, R, "subtmp");
694 case '*':
695 return Builder.CreateFMul(L, R, "multmp");
696 case '<':
697 L = Builder.CreateFCmpULT(L, R, "cmptmp");
698 // Convert bool 0/1 to double 0.0 or 1.0
699 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
700 default:
701 break;
704 // If it wasn't a builtin binary operator, it must be a user defined one. Emit
705 // a call to it.
706 Function *F = getFunction(std::string("binary") + Op);
707 assert(F && "binary operator not found!");
709 Value *Ops[] = {L, R};
710 return Builder.CreateCall(F, Ops, "binop");
713 Value *CallExprAST::codegen() {
714 // Look up the name in the global module table.
715 Function *CalleeF = getFunction(Callee);
716 if (!CalleeF)
717 return LogErrorV("Unknown function referenced");
719 // If argument mismatch error.
720 if (CalleeF->arg_size() != Args.size())
721 return LogErrorV("Incorrect # arguments passed");
723 std::vector<Value *> ArgsV;
724 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
725 ArgsV.push_back(Args[i]->codegen());
726 if (!ArgsV.back())
727 return nullptr;
730 return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
733 Value *IfExprAST::codegen() {
734 Value *CondV = Cond->codegen();
735 if (!CondV)
736 return nullptr;
738 // Convert condition to a bool by comparing non-equal to 0.0.
739 CondV = Builder.CreateFCmpONE(
740 CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
742 Function *TheFunction = Builder.GetInsertBlock()->getParent();
744 // Create blocks for the then and else cases. Insert the 'then' block at the
745 // end of the function.
746 BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
747 BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
748 BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");
750 Builder.CreateCondBr(CondV, ThenBB, ElseBB);
752 // Emit then value.
753 Builder.SetInsertPoint(ThenBB);
755 Value *ThenV = Then->codegen();
756 if (!ThenV)
757 return nullptr;
759 Builder.CreateBr(MergeBB);
760 // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
761 ThenBB = Builder.GetInsertBlock();
763 // Emit else block.
764 TheFunction->getBasicBlockList().push_back(ElseBB);
765 Builder.SetInsertPoint(ElseBB);
767 Value *ElseV = Else->codegen();
768 if (!ElseV)
769 return nullptr;
771 Builder.CreateBr(MergeBB);
772 // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
773 ElseBB = Builder.GetInsertBlock();
775 // Emit merge block.
776 TheFunction->getBasicBlockList().push_back(MergeBB);
777 Builder.SetInsertPoint(MergeBB);
778 PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
780 PN->addIncoming(ThenV, ThenBB);
781 PN->addIncoming(ElseV, ElseBB);
782 return PN;
785 // Output for-loop as:
786 // ...
787 // start = startexpr
788 // goto loop
789 // loop:
790 // variable = phi [start, loopheader], [nextvariable, loopend]
791 // ...
792 // bodyexpr
793 // ...
794 // loopend:
795 // step = stepexpr
796 // nextvariable = variable + step
797 // endcond = endexpr
798 // br endcond, loop, endloop
799 // outloop:
800 Value *ForExprAST::codegen() {
801 // Emit the start code first, without 'variable' in scope.
802 Value *StartVal = Start->codegen();
803 if (!StartVal)
804 return nullptr;
806 // Make the new basic block for the loop header, inserting after current
807 // block.
808 Function *TheFunction = Builder.GetInsertBlock()->getParent();
809 BasicBlock *PreheaderBB = Builder.GetInsertBlock();
810 BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
812 // Insert an explicit fall through from the current block to the LoopBB.
813 Builder.CreateBr(LoopBB);
815 // Start insertion in LoopBB.
816 Builder.SetInsertPoint(LoopBB);
818 // Start the PHI node with an entry for Start.
819 PHINode *Variable =
820 Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, VarName);
821 Variable->addIncoming(StartVal, PreheaderBB);
823 // Within the loop, the variable is defined equal to the PHI node. If it
824 // shadows an existing variable, we have to restore it, so save it now.
825 Value *OldVal = NamedValues[VarName];
826 NamedValues[VarName] = Variable;
828 // Emit the body of the loop. This, like any other expr, can change the
829 // current BB. Note that we ignore the value computed by the body, but don't
830 // allow an error.
831 if (!Body->codegen())
832 return nullptr;
834 // Emit the step value.
835 Value *StepVal = nullptr;
836 if (Step) {
837 StepVal = Step->codegen();
838 if (!StepVal)
839 return nullptr;
840 } else {
841 // If not specified, use 1.0.
842 StepVal = ConstantFP::get(TheContext, APFloat(1.0));
845 Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
847 // Compute the end condition.
848 Value *EndCond = End->codegen();
849 if (!EndCond)
850 return nullptr;
852 // Convert condition to a bool by comparing non-equal to 0.0.
853 EndCond = Builder.CreateFCmpONE(
854 EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
856 // Create the "after loop" block and insert it.
857 BasicBlock *LoopEndBB = Builder.GetInsertBlock();
858 BasicBlock *AfterBB =
859 BasicBlock::Create(TheContext, "afterloop", TheFunction);
861 // Insert the conditional branch into the end of LoopEndBB.
862 Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
864 // Any new code will be inserted in AfterBB.
865 Builder.SetInsertPoint(AfterBB);
867 // Add a new entry to the PHI node for the backedge.
868 Variable->addIncoming(NextVar, LoopEndBB);
870 // Restore the unshadowed variable.
871 if (OldVal)
872 NamedValues[VarName] = OldVal;
873 else
874 NamedValues.erase(VarName);
876 // for expr always returns 0.0.
877 return Constant::getNullValue(Type::getDoubleTy(TheContext));
880 Function *PrototypeAST::codegen() {
881 // Make the function type: double(double,double) etc.
882 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
883 FunctionType *FT =
884 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
886 Function *F =
887 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
889 // Set names for all arguments.
890 unsigned Idx = 0;
891 for (auto &Arg : F->args())
892 Arg.setName(Args[Idx++]);
894 return F;
897 Function *FunctionAST::codegen() {
898 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
899 // reference to it for use below.
900 auto &P = *Proto;
901 FunctionProtos[Proto->getName()] = std::move(Proto);
902 Function *TheFunction = getFunction(P.getName());
903 if (!TheFunction)
904 return nullptr;
906 // If this is an operator, install it.
907 if (P.isBinaryOp())
908 BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
910 // Create a new basic block to start insertion into.
911 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
912 Builder.SetInsertPoint(BB);
914 // Record the function arguments in the NamedValues map.
915 NamedValues.clear();
916 for (auto &Arg : TheFunction->args())
917 NamedValues[Arg.getName()] = &Arg;
919 if (Value *RetVal = Body->codegen()) {
920 // Finish off the function.
921 Builder.CreateRet(RetVal);
923 // Validate the generated code, checking for consistency.
924 verifyFunction(*TheFunction);
926 // Run the optimizer on the function.
927 TheFPM->run(*TheFunction);
929 return TheFunction;
932 // Error reading body, remove function.
933 TheFunction->eraseFromParent();
935 if (P.isBinaryOp())
936 BinopPrecedence.erase(P.getOperatorName());
937 return nullptr;
940 //===----------------------------------------------------------------------===//
941 // Top-Level parsing and JIT Driver
942 //===----------------------------------------------------------------------===//
944 static void InitializeModuleAndPassManager() {
945 // Open a new module.
946 TheModule = std::make_unique<Module>("my cool jit", TheContext);
947 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
949 // Create a new pass manager attached to it.
950 TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
952 // Do simple "peephole" optimizations and bit-twiddling optzns.
953 TheFPM->add(createInstructionCombiningPass());
954 // Reassociate expressions.
955 TheFPM->add(createReassociatePass());
956 // Eliminate Common SubExpressions.
957 TheFPM->add(createGVNPass());
958 // Simplify the control flow graph (deleting unreachable blocks, etc).
959 TheFPM->add(createCFGSimplificationPass());
961 TheFPM->doInitialization();
964 static void HandleDefinition() {
965 if (auto FnAST = ParseDefinition()) {
966 if (auto *FnIR = FnAST->codegen()) {
967 fprintf(stderr, "Read function definition:");
968 FnIR->print(errs());
969 fprintf(stderr, "\n");
970 TheJIT->addModule(std::move(TheModule));
971 InitializeModuleAndPassManager();
973 } else {
974 // Skip token for error recovery.
975 getNextToken();
979 static void HandleExtern() {
980 if (auto ProtoAST = ParseExtern()) {
981 if (auto *FnIR = ProtoAST->codegen()) {
982 fprintf(stderr, "Read extern: ");
983 FnIR->print(errs());
984 fprintf(stderr, "\n");
985 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
987 } else {
988 // Skip token for error recovery.
989 getNextToken();
993 static void HandleTopLevelExpression() {
994 // Evaluate a top-level expression into an anonymous function.
995 if (auto FnAST = ParseTopLevelExpr()) {
996 if (FnAST->codegen()) {
997 // JIT the module containing the anonymous expression, keeping a handle so
998 // we can free it later.
999 auto H = TheJIT->addModule(std::move(TheModule));
1000 InitializeModuleAndPassManager();
1002 // Search the JIT for the __anon_expr symbol.
1003 auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
1004 assert(ExprSymbol && "Function not found");
1006 // Get the symbol's address and cast it to the right type (takes no
1007 // arguments, returns a double) so we can call it as a native function.
1008 double (*FP)() = (double (*)())(intptr_t)cantFail(ExprSymbol.getAddress());
1009 fprintf(stderr, "Evaluated to %f\n", FP());
1011 // Delete the anonymous expression module from the JIT.
1012 TheJIT->removeModule(H);
1014 } else {
1015 // Skip token for error recovery.
1016 getNextToken();
1020 /// top ::= definition | external | expression | ';'
1021 static void MainLoop() {
1022 while (true) {
1023 fprintf(stderr, "ready> ");
1024 switch (CurTok) {
1025 case tok_eof:
1026 return;
1027 case ';': // ignore top-level semicolons.
1028 getNextToken();
1029 break;
1030 case tok_def:
1031 HandleDefinition();
1032 break;
1033 case tok_extern:
1034 HandleExtern();
1035 break;
1036 default:
1037 HandleTopLevelExpression();
1038 break;
1043 //===----------------------------------------------------------------------===//
1044 // "Library" functions that can be "extern'd" from user code.
1045 //===----------------------------------------------------------------------===//
// These helpers are meant to be 'extern'-ed from Kaleidoscope code. On Windows
// they are explicitly exported (presumably so the JIT can resolve them from
// the host executable -- NOTE(review): confirm against the JIT's lookup).
#if defined(_WIN32)
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - Write X, converted to char, to stderr as a character; returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  const char C = static_cast<char>(X);
  fputc(C, stderr);
  return 0;
}

/// printd - Print X to stderr using the "%f\n" format; returns 0.
extern "C" DLLEXPORT double printd(double X) {
  fprintf(stderr, "%f\n", X);
  return 0;
}
1065 //===----------------------------------------------------------------------===//
1066 // Main driver code.
1067 //===----------------------------------------------------------------------===//
1069 int main() {
1070 InitializeNativeTarget();
1071 InitializeNativeTargetAsmPrinter();
1072 InitializeNativeTargetAsmParser();
1074 // Install standard binary operators.
1075 // 1 is lowest precedence.
1076 BinopPrecedence['<'] = 10;
1077 BinopPrecedence['+'] = 20;
1078 BinopPrecedence['-'] = 20;
1079 BinopPrecedence['*'] = 40; // highest.
1081 // Prime the first token.
1082 fprintf(stderr, "ready> ");
1083 getNextToken();
1085 TheJIT = std::make_unique<KaleidoscopeJIT>();
1087 InitializeModuleAndPassManager();
1089 // Run the main "interpreter loop" now.
1090 MainLoop();
1092 return 0;