[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / clang / lib / AST / Interp / ByteCodeStmtGen.cpp
blobb1ab5fcf9cb64c34c4aaf85de24d3fa170526105
//===--- ByteCodeStmtGen.cpp - Code generator for statements ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "ByteCodeStmtGen.h"
10 #include "ByteCodeEmitter.h"
11 #include "ByteCodeGenError.h"
12 #include "Context.h"
13 #include "Function.h"
14 #include "PrimType.h"
16 using namespace clang;
17 using namespace clang::interp;
19 namespace clang {
20 namespace interp {
22 /// Scope managing label targets.
23 template <class Emitter> class LabelScope {
24 public:
25 virtual ~LabelScope() { }
27 protected:
28 LabelScope(ByteCodeStmtGen<Emitter> *Ctx) : Ctx(Ctx) {}
29 /// ByteCodeStmtGen instance.
30 ByteCodeStmtGen<Emitter> *Ctx;
33 /// Sets the context for break/continue statements.
34 template <class Emitter> class LoopScope final : public LabelScope<Emitter> {
35 public:
36 using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
37 using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
39 LoopScope(ByteCodeStmtGen<Emitter> *Ctx, LabelTy BreakLabel,
40 LabelTy ContinueLabel)
41 : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
42 OldContinueLabel(Ctx->ContinueLabel) {
43 this->Ctx->BreakLabel = BreakLabel;
44 this->Ctx->ContinueLabel = ContinueLabel;
47 ~LoopScope() {
48 this->Ctx->BreakLabel = OldBreakLabel;
49 this->Ctx->ContinueLabel = OldContinueLabel;
52 private:
53 OptLabelTy OldBreakLabel;
54 OptLabelTy OldContinueLabel;
57 // Sets the context for a switch scope, mapping labels.
58 template <class Emitter> class SwitchScope final : public LabelScope<Emitter> {
59 public:
60 using LabelTy = typename ByteCodeStmtGen<Emitter>::LabelTy;
61 using OptLabelTy = typename ByteCodeStmtGen<Emitter>::OptLabelTy;
62 using CaseMap = typename ByteCodeStmtGen<Emitter>::CaseMap;
64 SwitchScope(ByteCodeStmtGen<Emitter> *Ctx, CaseMap &&CaseLabels,
65 LabelTy BreakLabel, OptLabelTy DefaultLabel)
66 : LabelScope<Emitter>(Ctx), OldBreakLabel(Ctx->BreakLabel),
67 OldDefaultLabel(this->Ctx->DefaultLabel),
68 OldCaseLabels(std::move(this->Ctx->CaseLabels)) {
69 this->Ctx->BreakLabel = BreakLabel;
70 this->Ctx->DefaultLabel = DefaultLabel;
71 this->Ctx->CaseLabels = std::move(CaseLabels);
74 ~SwitchScope() {
75 this->Ctx->BreakLabel = OldBreakLabel;
76 this->Ctx->DefaultLabel = OldDefaultLabel;
77 this->Ctx->CaseLabels = std::move(OldCaseLabels);
80 private:
81 OptLabelTy OldBreakLabel;
82 OptLabelTy OldDefaultLabel;
83 CaseMap OldCaseLabels;
86 } // namespace interp
87 } // namespace clang
89 template <class Emitter>
90 bool ByteCodeStmtGen<Emitter>::emitLambdaStaticInvokerBody(
91 const CXXMethodDecl *MD) {
92 assert(MD->isLambdaStaticInvoker());
93 assert(MD->hasBody());
94 assert(cast<CompoundStmt>(MD->getBody())->body_empty());
96 const CXXRecordDecl *ClosureClass = MD->getParent();
97 const CXXMethodDecl *LambdaCallOp = ClosureClass->getLambdaCallOperator();
98 assert(ClosureClass->captures_begin() == ClosureClass->captures_end());
99 const Function *Func = this->getFunction(LambdaCallOp);
100 if (!Func)
101 return false;
102 assert(Func->hasThisPointer());
103 assert(Func->getNumParams() == (MD->getNumParams() + 1 + Func->hasRVO()));
105 if (Func->hasRVO()) {
106 if (!this->emitRVOPtr(MD))
107 return false;
110 // The lambda call operator needs an instance pointer, but we don't have
111 // one here, and we don't need one either because the lambda cannot have
112 // any captures, as verified above. Emit a null pointer. This is then
113 // special-cased when interpreting to not emit any misleading diagnostics.
114 if (!this->emitNullPtr(MD))
115 return false;
117 // Forward all arguments from the static invoker to the lambda call operator.
118 for (const ParmVarDecl *PVD : MD->parameters()) {
119 auto It = this->Params.find(PVD);
120 assert(It != this->Params.end());
122 // We do the lvalue-to-rvalue conversion manually here, so no need
123 // to care about references.
124 PrimType ParamType = this->classify(PVD->getType()).value_or(PT_Ptr);
125 if (!this->emitGetParam(ParamType, It->second.Offset, MD))
126 return false;
129 if (!this->emitCall(Func, LambdaCallOp))
130 return false;
132 this->emitCleanup();
133 if (ReturnType)
134 return this->emitRet(*ReturnType, MD);
136 // Nothing to do, since we emitted the RVO pointer above.
137 return this->emitRetVoid(MD);
140 template <class Emitter>
141 bool ByteCodeStmtGen<Emitter>::visitFunc(const FunctionDecl *F) {
142 // Classify the return type.
143 ReturnType = this->classify(F->getReturnType());
145 // Emit custom code if this is a lambda static invoker.
146 if (const auto *MD = dyn_cast<CXXMethodDecl>(F);
147 MD && MD->isLambdaStaticInvoker())
148 return this->emitLambdaStaticInvokerBody(MD);
150 // Constructor. Set up field initializers.
151 if (const auto *Ctor = dyn_cast<CXXConstructorDecl>(F)) {
152 const RecordDecl *RD = Ctor->getParent();
153 const Record *R = this->getRecord(RD);
154 if (!R)
155 return false;
157 for (const auto *Init : Ctor->inits()) {
158 // Scope needed for the initializers.
159 BlockScope<Emitter> Scope(this);
161 const Expr *InitExpr = Init->getInit();
162 if (const FieldDecl *Member = Init->getMember()) {
163 const Record::Field *F = R->getField(Member);
165 if (std::optional<PrimType> T = this->classify(InitExpr)) {
166 if (!this->visit(InitExpr))
167 return false;
169 if (F->isBitField()) {
170 if (!this->emitInitThisBitField(*T, F, InitExpr))
171 return false;
172 } else {
173 if (!this->emitInitThisField(*T, F->Offset, InitExpr))
174 return false;
176 } else {
177 // Non-primitive case. Get a pointer to the field-to-initialize
178 // on the stack and call visitInitialzer() for it.
179 if (!this->emitGetPtrThisField(F->Offset, InitExpr))
180 return false;
182 if (!this->visitInitializer(InitExpr))
183 return false;
185 if (!this->emitPopPtr(InitExpr))
186 return false;
188 } else if (const Type *Base = Init->getBaseClass()) {
189 // Base class initializer.
190 // Get This Base and call initializer on it.
191 const auto *BaseDecl = Base->getAsCXXRecordDecl();
192 assert(BaseDecl);
193 const Record::Base *B = R->getBase(BaseDecl);
194 assert(B);
195 if (!this->emitGetPtrThisBase(B->Offset, InitExpr))
196 return false;
197 if (!this->visitInitializer(InitExpr))
198 return false;
199 if (!this->emitInitPtrPop(InitExpr))
200 return false;
201 } else {
202 assert(Init->isDelegatingInitializer());
203 if (!this->emitThis(InitExpr))
204 return false;
205 if (!this->visitInitializer(Init->getInit()))
206 return false;
207 if (!this->emitPopPtr(InitExpr))
208 return false;
213 if (const auto *Body = F->getBody())
214 if (!visitStmt(Body))
215 return false;
217 // Emit a guard return to protect against a code path missing one.
218 if (F->getReturnType()->isVoidType())
219 return this->emitRetVoid(SourceInfo{});
220 else
221 return this->emitNoRet(SourceInfo{});
224 template <class Emitter>
225 bool ByteCodeStmtGen<Emitter>::visitStmt(const Stmt *S) {
226 switch (S->getStmtClass()) {
227 case Stmt::CompoundStmtClass:
228 return visitCompoundStmt(cast<CompoundStmt>(S));
229 case Stmt::DeclStmtClass:
230 return visitDeclStmt(cast<DeclStmt>(S));
231 case Stmt::ReturnStmtClass:
232 return visitReturnStmt(cast<ReturnStmt>(S));
233 case Stmt::IfStmtClass:
234 return visitIfStmt(cast<IfStmt>(S));
235 case Stmt::WhileStmtClass:
236 return visitWhileStmt(cast<WhileStmt>(S));
237 case Stmt::DoStmtClass:
238 return visitDoStmt(cast<DoStmt>(S));
239 case Stmt::ForStmtClass:
240 return visitForStmt(cast<ForStmt>(S));
241 case Stmt::CXXForRangeStmtClass:
242 return visitCXXForRangeStmt(cast<CXXForRangeStmt>(S));
243 case Stmt::BreakStmtClass:
244 return visitBreakStmt(cast<BreakStmt>(S));
245 case Stmt::ContinueStmtClass:
246 return visitContinueStmt(cast<ContinueStmt>(S));
247 case Stmt::SwitchStmtClass:
248 return visitSwitchStmt(cast<SwitchStmt>(S));
249 case Stmt::CaseStmtClass:
250 return visitCaseStmt(cast<CaseStmt>(S));
251 case Stmt::DefaultStmtClass:
252 return visitDefaultStmt(cast<DefaultStmt>(S));
253 case Stmt::GCCAsmStmtClass:
254 case Stmt::MSAsmStmtClass:
255 return visitAsmStmt(cast<AsmStmt>(S));
256 case Stmt::AttributedStmtClass:
257 return visitAttributedStmt(cast<AttributedStmt>(S));
258 case Stmt::CXXTryStmtClass:
259 return visitCXXTryStmt(cast<CXXTryStmt>(S));
260 case Stmt::NullStmtClass:
261 return true;
262 default: {
263 if (auto *Exp = dyn_cast<Expr>(S))
264 return this->discard(Exp);
265 return this->bail(S);
270 /// Visits the given statment without creating a variable
271 /// scope for it in case it is a compound statement.
272 template <class Emitter>
273 bool ByteCodeStmtGen<Emitter>::visitLoopBody(const Stmt *S) {
274 if (isa<NullStmt>(S))
275 return true;
277 if (const auto *CS = dyn_cast<CompoundStmt>(S)) {
278 for (auto *InnerStmt : CS->body())
279 if (!visitStmt(InnerStmt))
280 return false;
281 return true;
284 return this->visitStmt(S);
287 template <class Emitter>
288 bool ByteCodeStmtGen<Emitter>::visitCompoundStmt(
289 const CompoundStmt *CompoundStmt) {
290 BlockScope<Emitter> Scope(this);
291 for (auto *InnerStmt : CompoundStmt->body())
292 if (!visitStmt(InnerStmt))
293 return false;
294 return true;
297 template <class Emitter>
298 bool ByteCodeStmtGen<Emitter>::visitDeclStmt(const DeclStmt *DS) {
299 for (auto *D : DS->decls()) {
300 if (isa<StaticAssertDecl, TagDecl, TypedefNameDecl>(D))
301 continue;
303 const auto *VD = dyn_cast<VarDecl>(D);
304 if (!VD)
305 return false;
306 if (!this->visitVarDecl(VD))
307 return false;
310 return true;
313 template <class Emitter>
314 bool ByteCodeStmtGen<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
315 if (const Expr *RE = RS->getRetValue()) {
316 ExprScope<Emitter> RetScope(this);
317 if (ReturnType) {
318 // Primitive types are simply returned.
319 if (!this->visit(RE))
320 return false;
321 this->emitCleanup();
322 return this->emitRet(*ReturnType, RS);
323 } else if (RE->getType()->isVoidType()) {
324 if (!this->visit(RE))
325 return false;
326 } else {
327 // RVO - construct the value in the return location.
328 if (!this->emitRVOPtr(RE))
329 return false;
330 if (!this->visitInitializer(RE))
331 return false;
332 if (!this->emitPopPtr(RE))
333 return false;
335 this->emitCleanup();
336 return this->emitRetVoid(RS);
340 // Void return.
341 this->emitCleanup();
342 return this->emitRetVoid(RS);
345 template <class Emitter>
346 bool ByteCodeStmtGen<Emitter>::visitIfStmt(const IfStmt *IS) {
347 BlockScope<Emitter> IfScope(this);
349 if (IS->isNonNegatedConsteval())
350 return visitStmt(IS->getThen());
351 if (IS->isNegatedConsteval())
352 return IS->getElse() ? visitStmt(IS->getElse()) : true;
354 if (auto *CondInit = IS->getInit())
355 if (!visitStmt(CondInit))
356 return false;
358 if (const DeclStmt *CondDecl = IS->getConditionVariableDeclStmt())
359 if (!visitDeclStmt(CondDecl))
360 return false;
362 if (!this->visitBool(IS->getCond()))
363 return false;
365 if (const Stmt *Else = IS->getElse()) {
366 LabelTy LabelElse = this->getLabel();
367 LabelTy LabelEnd = this->getLabel();
368 if (!this->jumpFalse(LabelElse))
369 return false;
370 if (!visitStmt(IS->getThen()))
371 return false;
372 if (!this->jump(LabelEnd))
373 return false;
374 this->emitLabel(LabelElse);
375 if (!visitStmt(Else))
376 return false;
377 this->emitLabel(LabelEnd);
378 } else {
379 LabelTy LabelEnd = this->getLabel();
380 if (!this->jumpFalse(LabelEnd))
381 return false;
382 if (!visitStmt(IS->getThen()))
383 return false;
384 this->emitLabel(LabelEnd);
387 return true;
390 template <class Emitter>
391 bool ByteCodeStmtGen<Emitter>::visitWhileStmt(const WhileStmt *S) {
392 const Expr *Cond = S->getCond();
393 const Stmt *Body = S->getBody();
395 LabelTy CondLabel = this->getLabel(); // Label before the condition.
396 LabelTy EndLabel = this->getLabel(); // Label after the loop.
397 LoopScope<Emitter> LS(this, EndLabel, CondLabel);
399 this->emitLabel(CondLabel);
400 if (!this->visitBool(Cond))
401 return false;
402 if (!this->jumpFalse(EndLabel))
403 return false;
405 LocalScope<Emitter> Scope(this);
407 DestructorScope<Emitter> DS(Scope);
408 if (!this->visitLoopBody(Body))
409 return false;
412 if (!this->jump(CondLabel))
413 return false;
414 this->emitLabel(EndLabel);
416 return true;
419 template <class Emitter>
420 bool ByteCodeStmtGen<Emitter>::visitDoStmt(const DoStmt *S) {
421 const Expr *Cond = S->getCond();
422 const Stmt *Body = S->getBody();
424 LabelTy StartLabel = this->getLabel();
425 LabelTy EndLabel = this->getLabel();
426 LabelTy CondLabel = this->getLabel();
427 LoopScope<Emitter> LS(this, EndLabel, CondLabel);
428 LocalScope<Emitter> Scope(this);
430 this->emitLabel(StartLabel);
432 DestructorScope<Emitter> DS(Scope);
434 if (!this->visitLoopBody(Body))
435 return false;
436 this->emitLabel(CondLabel);
437 if (!this->visitBool(Cond))
438 return false;
440 if (!this->jumpTrue(StartLabel))
441 return false;
443 this->emitLabel(EndLabel);
444 return true;
447 template <class Emitter>
448 bool ByteCodeStmtGen<Emitter>::visitForStmt(const ForStmt *S) {
449 // for (Init; Cond; Inc) { Body }
450 const Stmt *Init = S->getInit();
451 const Expr *Cond = S->getCond();
452 const Expr *Inc = S->getInc();
453 const Stmt *Body = S->getBody();
455 LabelTy EndLabel = this->getLabel();
456 LabelTy CondLabel = this->getLabel();
457 LabelTy IncLabel = this->getLabel();
458 LoopScope<Emitter> LS(this, EndLabel, IncLabel);
459 LocalScope<Emitter> Scope(this);
461 if (Init && !this->visitStmt(Init))
462 return false;
463 this->emitLabel(CondLabel);
464 if (Cond) {
465 if (!this->visitBool(Cond))
466 return false;
467 if (!this->jumpFalse(EndLabel))
468 return false;
472 DestructorScope<Emitter> DS(Scope);
474 if (Body && !this->visitLoopBody(Body))
475 return false;
476 this->emitLabel(IncLabel);
477 if (Inc && !this->discard(Inc))
478 return false;
481 if (!this->jump(CondLabel))
482 return false;
483 this->emitLabel(EndLabel);
484 return true;
487 template <class Emitter>
488 bool ByteCodeStmtGen<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *S) {
489 const Stmt *Init = S->getInit();
490 const Expr *Cond = S->getCond();
491 const Expr *Inc = S->getInc();
492 const Stmt *Body = S->getBody();
493 const Stmt *BeginStmt = S->getBeginStmt();
494 const Stmt *RangeStmt = S->getRangeStmt();
495 const Stmt *EndStmt = S->getEndStmt();
496 const VarDecl *LoopVar = S->getLoopVariable();
498 LabelTy EndLabel = this->getLabel();
499 LabelTy CondLabel = this->getLabel();
500 LabelTy IncLabel = this->getLabel();
501 LoopScope<Emitter> LS(this, EndLabel, IncLabel);
503 // Emit declarations needed in the loop.
504 if (Init && !this->visitStmt(Init))
505 return false;
506 if (!this->visitStmt(RangeStmt))
507 return false;
508 if (!this->visitStmt(BeginStmt))
509 return false;
510 if (!this->visitStmt(EndStmt))
511 return false;
513 // Now the condition as well as the loop variable assignment.
514 this->emitLabel(CondLabel);
515 if (!this->visitBool(Cond))
516 return false;
517 if (!this->jumpFalse(EndLabel))
518 return false;
520 if (!this->visitVarDecl(LoopVar))
521 return false;
523 // Body.
524 LocalScope<Emitter> Scope(this);
526 DestructorScope<Emitter> DS(Scope);
528 if (!this->visitLoopBody(Body))
529 return false;
530 this->emitLabel(IncLabel);
531 if (!this->discard(Inc))
532 return false;
534 if (!this->jump(CondLabel))
535 return false;
537 this->emitLabel(EndLabel);
538 return true;
541 template <class Emitter>
542 bool ByteCodeStmtGen<Emitter>::visitBreakStmt(const BreakStmt *S) {
543 if (!BreakLabel)
544 return false;
546 this->VarScope->emitDestructors();
547 return this->jump(*BreakLabel);
550 template <class Emitter>
551 bool ByteCodeStmtGen<Emitter>::visitContinueStmt(const ContinueStmt *S) {
552 if (!ContinueLabel)
553 return false;
555 this->VarScope->emitDestructors();
556 return this->jump(*ContinueLabel);
559 template <class Emitter>
560 bool ByteCodeStmtGen<Emitter>::visitSwitchStmt(const SwitchStmt *S) {
561 const Expr *Cond = S->getCond();
562 PrimType CondT = this->classifyPrim(Cond->getType());
564 LabelTy EndLabel = this->getLabel();
565 OptLabelTy DefaultLabel = std::nullopt;
566 unsigned CondVar = this->allocateLocalPrimitive(Cond, CondT, true, false);
568 if (const auto *CondInit = S->getInit())
569 if (!visitStmt(CondInit))
570 return false;
572 // Initialize condition variable.
573 if (!this->visit(Cond))
574 return false;
575 if (!this->emitSetLocal(CondT, CondVar, S))
576 return false;
578 CaseMap CaseLabels;
579 // Create labels and comparison ops for all case statements.
580 for (const SwitchCase *SC = S->getSwitchCaseList(); SC;
581 SC = SC->getNextSwitchCase()) {
582 if (const auto *CS = dyn_cast<CaseStmt>(SC)) {
583 // FIXME: Implement ranges.
584 if (CS->caseStmtIsGNURange())
585 return false;
586 CaseLabels[SC] = this->getLabel();
588 const Expr *Value = CS->getLHS();
589 PrimType ValueT = this->classifyPrim(Value->getType());
591 // Compare the case statement's value to the switch condition.
592 if (!this->emitGetLocal(CondT, CondVar, CS))
593 return false;
594 if (!this->visit(Value))
595 return false;
597 // Compare and jump to the case label.
598 if (!this->emitEQ(ValueT, S))
599 return false;
600 if (!this->jumpTrue(CaseLabels[CS]))
601 return false;
602 } else {
603 assert(!DefaultLabel);
604 DefaultLabel = this->getLabel();
608 // If none of the conditions above were true, fall through to the default
609 // statement or jump after the switch statement.
610 if (DefaultLabel) {
611 if (!this->jump(*DefaultLabel))
612 return false;
613 } else {
614 if (!this->jump(EndLabel))
615 return false;
618 SwitchScope<Emitter> SS(this, std::move(CaseLabels), EndLabel, DefaultLabel);
619 if (!this->visitStmt(S->getBody()))
620 return false;
621 this->emitLabel(EndLabel);
622 return true;
625 template <class Emitter>
626 bool ByteCodeStmtGen<Emitter>::visitCaseStmt(const CaseStmt *S) {
627 this->emitLabel(CaseLabels[S]);
628 return this->visitStmt(S->getSubStmt());
631 template <class Emitter>
632 bool ByteCodeStmtGen<Emitter>::visitDefaultStmt(const DefaultStmt *S) {
633 this->emitLabel(*DefaultLabel);
634 return this->visitStmt(S->getSubStmt());
637 template <class Emitter>
638 bool ByteCodeStmtGen<Emitter>::visitAsmStmt(const AsmStmt *S) {
639 return this->emitInvalid(S);
642 template <class Emitter>
643 bool ByteCodeStmtGen<Emitter>::visitAttributedStmt(const AttributedStmt *S) {
644 // Ignore all attributes.
645 return this->visitStmt(S->getSubStmt());
648 template <class Emitter>
649 bool ByteCodeStmtGen<Emitter>::visitCXXTryStmt(const CXXTryStmt *S) {
650 // Ignore all handlers.
651 return this->visitStmt(S->getTryBlock());
654 namespace clang {
655 namespace interp {
657 template class ByteCodeStmtGen<ByteCodeEmitter>;
659 } // namespace interp
660 } // namespace clang