1 //===- Coroutines.cpp -----------------------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the common infrastructure for Coroutine Passes.
11 //===----------------------------------------------------------------------===//
13 #include "CoroInstr.h"
14 #include "CoroInternal.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Analysis/CallGraph.h"
18 #include "llvm/IR/Attributes.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Transforms/Utils/Local.h"
37 // Construct the lowerer base class and initialize its members.
38 coro::LowererBase::LowererBase(Module
&M
)
39 : TheModule(M
), Context(M
.getContext()),
40 Int8Ptr(PointerType::get(Context
, 0)),
41 ResumeFnType(FunctionType::get(Type::getVoidTy(Context
), Int8Ptr
,
43 NullPtr(ConstantPointerNull::get(Int8Ptr
)) {}
45 // Creates a sequence of instructions to obtain a resume function address using
46 // llvm.coro.subfn.addr. It generates the following sequence:
48 // call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
49 // bitcast i8* %2 to void(i8*)*
51 Value
*coro::LowererBase::makeSubFnCall(Value
*Arg
, int Index
,
52 Instruction
*InsertPt
) {
53 auto *IndexVal
= ConstantInt::get(Type::getInt8Ty(Context
), Index
);
54 auto *Fn
= Intrinsic::getDeclaration(&TheModule
, Intrinsic::coro_subfn_addr
);
56 assert(Index
>= CoroSubFnInst::IndexFirst
&&
57 Index
< CoroSubFnInst::IndexLast
&&
58 "makeSubFnCall: Index value out of range");
59 auto *Call
= CallInst::Create(Fn
, {Arg
, IndexVal
}, "", InsertPt
);
62 new BitCastInst(Call
, ResumeFnType
->getPointerTo(), "", InsertPt
);
// Table of coroutine intrinsic names. isCoroutineIntrinsicName() hands it to
// Intrinsic::lookupLLVMIntrinsicByName(), and coro::declaresAnyIntrinsic()
// probes the module for every entry.
// The ordering requirement below presumably comes from that lookup expecting
// a sorted table -- keep any new entry in its sorted position.
66 // NOTE: Must be sorted!
67 static const char *const CoroIntrinsics
[] = {
70 "llvm.coro.async.context.alloc",
71 "llvm.coro.async.context.dealloc",
72 "llvm.coro.async.resume",
73 "llvm.coro.async.size.replace",
74 "llvm.coro.async.store_resume",
79 "llvm.coro.end.async",
84 "llvm.coro.id.retcon",
85 "llvm.coro.id.retcon.once",
87 "llvm.coro.prepare.async",
88 "llvm.coro.prepare.retcon",
93 "llvm.coro.subfn.addr",
95 "llvm.coro.suspend.async",
96 "llvm.coro.suspend.retcon",
100 static bool isCoroutineIntrinsicName(StringRef Name
) {
101 return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics
, Name
) != -1;
105 bool coro::declaresAnyIntrinsic(const Module
&M
) {
106 for (StringRef Name
: CoroIntrinsics
) {
107 assert(isCoroutineIntrinsicName(Name
) && "not a coroutine intrinsic");
108 if (M
.getNamedValue(Name
))
115 // Verifies if a module has named values listed. Also, in debug mode verifies
116 // that names are intrinsic names.
117 bool coro::declaresIntrinsics(const Module
&M
,
118 const std::initializer_list
<StringRef
> List
) {
119 for (StringRef Name
: List
) {
120 assert(isCoroutineIntrinsicName(Name
) && "not a coroutine intrinsic");
121 if (M
.getNamedValue(Name
))
128 // Replace all coro.frees associated with the provided CoroId either with 'null'
129 // if Elide is true and with its frame parameter otherwise.
130 void coro::replaceCoroFree(CoroIdInst
*CoroId
, bool Elide
) {
131 SmallVector
<CoroFreeInst
*, 4> CoroFrees
;
132 for (User
*U
: CoroId
->users())
133 if (auto CF
= dyn_cast
<CoroFreeInst
>(U
))
134 CoroFrees
.push_back(CF
);
136 if (CoroFrees
.empty())
141 ? ConstantPointerNull::get(PointerType::get(CoroId
->getContext(), 0))
142 : CoroFrees
.front()->getFrame();
144 for (CoroFreeInst
*CF
: CoroFrees
) {
145 CF
->replaceAllUsesWith(Replacement
);
146 CF
->eraseFromParent();
150 static void clear(coro::Shape
&Shape
) {
151 Shape
.CoroBegin
= nullptr;
152 Shape
.CoroEnds
.clear();
153 Shape
.CoroSizes
.clear();
154 Shape
.CoroSuspends
.clear();
156 Shape
.FrameTy
= nullptr;
157 Shape
.FramePtr
= nullptr;
158 Shape
.AllocaSpillBlock
= nullptr;
161 static CoroSaveInst
*createCoroSave(CoroBeginInst
*CoroBegin
,
162 CoroSuspendInst
*SuspendInst
) {
163 Module
*M
= SuspendInst
->getModule();
164 auto *Fn
= Intrinsic::getDeclaration(M
, Intrinsic::coro_save
);
166 cast
<CoroSaveInst
>(CallInst::Create(Fn
, CoroBegin
, "", SuspendInst
));
167 assert(!SuspendInst
->getCoroSave());
168 SuspendInst
->setArgOperand(0, SaveInst
);
172 // Collect "interesting" coroutine intrinsics.
173 void coro::Shape::buildFrom(Function
&F
) {
174 bool HasFinalSuspend
= false;
175 bool HasUnwindCoroEnd
= false;
176 size_t FinalSuspendIndex
= 0;
178 SmallVector
<CoroFrameInst
*, 8> CoroFrames
;
179 SmallVector
<CoroSaveInst
*, 2> UnusedCoroSaves
;
181 for (Instruction
&I
: instructions(F
)) {
182 if (auto II
= dyn_cast
<IntrinsicInst
>(&I
)) {
183 switch (II
->getIntrinsicID()) {
186 case Intrinsic::coro_size
:
187 CoroSizes
.push_back(cast
<CoroSizeInst
>(II
));
189 case Intrinsic::coro_align
:
190 CoroAligns
.push_back(cast
<CoroAlignInst
>(II
));
192 case Intrinsic::coro_frame
:
193 CoroFrames
.push_back(cast
<CoroFrameInst
>(II
));
195 case Intrinsic::coro_save
:
196 // After optimizations, coro_suspends using this coro_save might have
197 // been removed, remember orphaned coro_saves to remove them later.
199 UnusedCoroSaves
.push_back(cast
<CoroSaveInst
>(II
));
201 case Intrinsic::coro_suspend_async
: {
202 auto *Suspend
= cast
<CoroSuspendAsyncInst
>(II
);
203 Suspend
->checkWellFormed();
204 CoroSuspends
.push_back(Suspend
);
207 case Intrinsic::coro_suspend_retcon
: {
208 auto Suspend
= cast
<CoroSuspendRetconInst
>(II
);
209 CoroSuspends
.push_back(Suspend
);
212 case Intrinsic::coro_suspend
: {
213 auto Suspend
= cast
<CoroSuspendInst
>(II
);
214 CoroSuspends
.push_back(Suspend
);
215 if (Suspend
->isFinal()) {
218 "Only one suspend point can be marked as final");
219 HasFinalSuspend
= true;
220 FinalSuspendIndex
= CoroSuspends
.size() - 1;
224 case Intrinsic::coro_begin
: {
225 auto CB
= cast
<CoroBeginInst
>(II
);
227 // Ignore coro id's that aren't pre-split.
228 auto Id
= dyn_cast
<CoroIdInst
>(CB
->getId());
229 if (Id
&& !Id
->getInfo().isPreSplit())
234 "coroutine should have exactly one defining @llvm.coro.begin");
235 CB
->addRetAttr(Attribute::NonNull
);
236 CB
->addRetAttr(Attribute::NoAlias
);
237 CB
->removeFnAttr(Attribute::NoDuplicate
);
241 case Intrinsic::coro_end_async
:
242 case Intrinsic::coro_end
:
243 CoroEnds
.push_back(cast
<AnyCoroEndInst
>(II
));
244 if (auto *AsyncEnd
= dyn_cast
<CoroAsyncEndInst
>(II
)) {
245 AsyncEnd
->checkWellFormed();
248 if (CoroEnds
.back()->isUnwind())
249 HasUnwindCoroEnd
= true;
251 if (CoroEnds
.back()->isFallthrough() && isa
<CoroEndInst
>(II
)) {
252 // Make sure that the fallthrough coro.end is the first element in the
254 // Note: I don't think this is neccessary anymore.
255 if (CoroEnds
.size() > 1) {
256 if (CoroEnds
.front()->isFallthrough())
258 "Only one coro.end can be marked as fallthrough");
259 std::swap(CoroEnds
.front(), CoroEnds
.back());
267 // If for some reason, we were not able to find coro.begin, bailout.
269 // Replace coro.frame which are supposed to be lowered to the result of
270 // coro.begin with undef.
271 auto *Undef
= UndefValue::get(PointerType::get(F
.getContext(), 0));
272 for (CoroFrameInst
*CF
: CoroFrames
) {
273 CF
->replaceAllUsesWith(Undef
);
274 CF
->eraseFromParent();
277 // Replace all coro.suspend with undef and remove related coro.saves if
279 for (AnyCoroSuspendInst
*CS
: CoroSuspends
) {
280 CS
->replaceAllUsesWith(UndefValue::get(CS
->getType()));
281 CS
->eraseFromParent();
282 if (auto *CoroSave
= CS
->getCoroSave())
283 CoroSave
->eraseFromParent();
286 // Replace all coro.ends with unreachable instruction.
287 for (AnyCoroEndInst
*CE
: CoroEnds
)
288 changeToUnreachable(CE
);
293 auto Id
= CoroBegin
->getId();
294 switch (auto IdIntrinsic
= Id
->getIntrinsicID()) {
295 case Intrinsic::coro_id
: {
296 auto SwitchId
= cast
<CoroIdInst
>(Id
);
297 this->ABI
= coro::ABI::Switch
;
298 this->SwitchLowering
.HasFinalSuspend
= HasFinalSuspend
;
299 this->SwitchLowering
.HasUnwindCoroEnd
= HasUnwindCoroEnd
;
300 this->SwitchLowering
.ResumeSwitch
= nullptr;
301 this->SwitchLowering
.PromiseAlloca
= SwitchId
->getPromise();
302 this->SwitchLowering
.ResumeEntryBlock
= nullptr;
304 for (auto *AnySuspend
: CoroSuspends
) {
305 auto Suspend
= dyn_cast
<CoroSuspendInst
>(AnySuspend
);
310 report_fatal_error("coro.id must be paired with coro.suspend");
313 if (!Suspend
->getCoroSave())
314 createCoroSave(CoroBegin
, Suspend
);
318 case Intrinsic::coro_id_async
: {
319 auto *AsyncId
= cast
<CoroIdAsyncInst
>(Id
);
320 AsyncId
->checkWellFormed();
321 this->ABI
= coro::ABI::Async
;
322 this->AsyncLowering
.Context
= AsyncId
->getStorage();
323 this->AsyncLowering
.ContextArgNo
= AsyncId
->getStorageArgumentIndex();
324 this->AsyncLowering
.ContextHeaderSize
= AsyncId
->getStorageSize();
325 this->AsyncLowering
.ContextAlignment
=
326 AsyncId
->getStorageAlignment().value();
327 this->AsyncLowering
.AsyncFuncPointer
= AsyncId
->getAsyncFunctionPointer();
328 this->AsyncLowering
.AsyncCC
= F
.getCallingConv();
331 case Intrinsic::coro_id_retcon
:
332 case Intrinsic::coro_id_retcon_once
: {
333 auto ContinuationId
= cast
<AnyCoroIdRetconInst
>(Id
);
334 ContinuationId
->checkWellFormed();
335 this->ABI
= (IdIntrinsic
== Intrinsic::coro_id_retcon
337 : coro::ABI::RetconOnce
);
338 auto Prototype
= ContinuationId
->getPrototype();
339 this->RetconLowering
.ResumePrototype
= Prototype
;
340 this->RetconLowering
.Alloc
= ContinuationId
->getAllocFunction();
341 this->RetconLowering
.Dealloc
= ContinuationId
->getDeallocFunction();
342 this->RetconLowering
.ReturnBlock
= nullptr;
343 this->RetconLowering
.IsFrameInlineInStorage
= false;
345 // Determine the result value types, and make sure they match up with
346 // the values passed to the suspends.
347 auto ResultTys
= getRetconResultTypes();
348 auto ResumeTys
= getRetconResumeTypes();
350 for (auto *AnySuspend
: CoroSuspends
) {
351 auto Suspend
= dyn_cast
<CoroSuspendRetconInst
>(AnySuspend
);
356 report_fatal_error("coro.id.retcon.* must be paired with "
357 "coro.suspend.retcon");
360 // Check that the argument types of the suspend match the results.
361 auto SI
= Suspend
->value_begin(), SE
= Suspend
->value_end();
362 auto RI
= ResultTys
.begin(), RE
= ResultTys
.end();
363 for (; SI
!= SE
&& RI
!= RE
; ++SI
, ++RI
) {
364 auto SrcTy
= (*SI
)->getType();
366 // The optimizer likes to eliminate bitcasts leading into variadic
367 // calls, but that messes with our invariants. Re-insert the
368 // bitcast and ignore this type mismatch.
369 if (CastInst::isBitCastable(SrcTy
, *RI
)) {
370 auto BCI
= new BitCastInst(*SI
, *RI
, "", Suspend
);
377 Prototype
->getFunctionType()->dump();
379 report_fatal_error("argument to coro.suspend.retcon does not "
380 "match corresponding prototype function result");
383 if (SI
!= SE
|| RI
!= RE
) {
386 Prototype
->getFunctionType()->dump();
388 report_fatal_error("wrong number of arguments to coro.suspend.retcon");
391 // Check that the result type of the suspend matches the resume types.
392 Type
*SResultTy
= Suspend
->getType();
393 ArrayRef
<Type
*> SuspendResultTys
;
394 if (SResultTy
->isVoidTy()) {
395 // leave as empty array
396 } else if (auto SResultStructTy
= dyn_cast
<StructType
>(SResultTy
)) {
397 SuspendResultTys
= SResultStructTy
->elements();
399 // forms an ArrayRef using SResultTy, be careful
400 SuspendResultTys
= SResultTy
;
402 if (SuspendResultTys
.size() != ResumeTys
.size()) {
405 Prototype
->getFunctionType()->dump();
407 report_fatal_error("wrong number of results from coro.suspend.retcon");
409 for (size_t I
= 0, E
= ResumeTys
.size(); I
!= E
; ++I
) {
410 if (SuspendResultTys
[I
] != ResumeTys
[I
]) {
413 Prototype
->getFunctionType()->dump();
415 report_fatal_error("result from coro.suspend.retcon does not "
416 "match corresponding prototype function param");
424 llvm_unreachable("coro.begin is not dependent on a coro.id call");
427 // The coro.free intrinsic is always lowered to the result of coro.begin.
428 for (CoroFrameInst
*CF
: CoroFrames
) {
429 CF
->replaceAllUsesWith(CoroBegin
);
430 CF
->eraseFromParent();
433 // Move final suspend to be the last element in the CoroSuspends vector.
434 if (ABI
== coro::ABI::Switch
&&
435 SwitchLowering
.HasFinalSuspend
&&
436 FinalSuspendIndex
!= CoroSuspends
.size() - 1)
437 std::swap(CoroSuspends
[FinalSuspendIndex
], CoroSuspends
.back());
439 // Remove orphaned coro.saves.
440 for (CoroSaveInst
*CoroSave
: UnusedCoroSaves
)
441 CoroSave
->eraseFromParent();
444 static void propagateCallAttrsFromCallee(CallInst
*Call
, Function
*Callee
) {
445 Call
->setCallingConv(Callee
->getCallingConv());
449 static void addCallToCallGraph(CallGraph
*CG
, CallInst
*Call
, Function
*Callee
){
451 (*CG
)[Call
->getFunction()]->addCalledFunction(Call
, (*CG
)[Callee
]);
454 Value
*coro::Shape::emitAlloc(IRBuilder
<> &Builder
, Value
*Size
,
455 CallGraph
*CG
) const {
457 case coro::ABI::Switch
:
458 llvm_unreachable("can't allocate memory in coro switch-lowering");
460 case coro::ABI::Retcon
:
461 case coro::ABI::RetconOnce
: {
462 auto Alloc
= RetconLowering
.Alloc
;
463 Size
= Builder
.CreateIntCast(Size
,
464 Alloc
->getFunctionType()->getParamType(0),
465 /*is signed*/ false);
466 auto *Call
= Builder
.CreateCall(Alloc
, Size
);
467 propagateCallAttrsFromCallee(Call
, Alloc
);
468 addCallToCallGraph(CG
, Call
, Alloc
);
471 case coro::ABI::Async
:
472 llvm_unreachable("can't allocate memory in coro async-lowering");
474 llvm_unreachable("Unknown coro::ABI enum");
477 void coro::Shape::emitDealloc(IRBuilder
<> &Builder
, Value
*Ptr
,
478 CallGraph
*CG
) const {
480 case coro::ABI::Switch
:
481 llvm_unreachable("can't allocate memory in coro switch-lowering");
483 case coro::ABI::Retcon
:
484 case coro::ABI::RetconOnce
: {
485 auto Dealloc
= RetconLowering
.Dealloc
;
486 Ptr
= Builder
.CreateBitCast(Ptr
,
487 Dealloc
->getFunctionType()->getParamType(0));
488 auto *Call
= Builder
.CreateCall(Dealloc
, Ptr
);
489 propagateCallAttrsFromCallee(Call
, Dealloc
);
490 addCallToCallGraph(CG
, Call
, Dealloc
);
493 case coro::ABI::Async
:
494 llvm_unreachable("can't allocate memory in coro async-lowering");
496 llvm_unreachable("Unknown coro::ABI enum");
// Reports a fatal error with message Reason via report_fatal_error(). Before
// that, the visible statements write a " Value: " label followed by V, printed
// as an operand, to errs(). The attribute marks the function as never
// returning.
// NOTE(review): the tail of the parameter list and whatever guards the
// diagnostic output are not visible in this view -- confirm against the
// complete definition.
499 [[noreturn
]] static void fail(const Instruction
*I
, const char *Reason
,
504 errs() << " Value: ";
505 V
->printAsOperand(llvm::errs());
509 report_fatal_error(Reason
);
512 /// Check that the given value is a well-formed prototype for the
513 /// llvm.coro.id.retcon.* intrinsics.
514 static void checkWFRetconPrototype(const AnyCoroIdRetconInst
*I
, Value
*V
) {
515 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
517 fail(I
, "llvm.coro.id.retcon.* prototype not a Function", V
);
519 auto FT
= F
->getFunctionType();
521 if (isa
<CoroIdRetconInst
>(I
)) {
523 if (FT
->getReturnType()->isPointerTy()) {
525 } else if (auto SRetTy
= dyn_cast
<StructType
>(FT
->getReturnType())) {
526 ResultOkay
= (!SRetTy
->isOpaque() &&
527 SRetTy
->getNumElements() > 0 &&
528 SRetTy
->getElementType(0)->isPointerTy());
533 fail(I
, "llvm.coro.id.retcon prototype must return pointer as first "
536 if (FT
->getReturnType() !=
537 I
->getFunction()->getFunctionType()->getReturnType())
538 fail(I
, "llvm.coro.id.retcon prototype return type must be same as"
539 "current function return type", F
);
541 // No meaningful validation to do here for llvm.coro.id.unique.once.
544 if (FT
->getNumParams() == 0 || !FT
->getParamType(0)->isPointerTy())
545 fail(I
, "llvm.coro.id.retcon.* prototype must take pointer as "
546 "its first parameter", F
);
549 /// Check that the given value is a well-formed allocator.
550 static void checkWFAlloc(const Instruction
*I
, Value
*V
) {
551 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
553 fail(I
, "llvm.coro.* allocator not a Function", V
);
555 auto FT
= F
->getFunctionType();
556 if (!FT
->getReturnType()->isPointerTy())
557 fail(I
, "llvm.coro.* allocator must return a pointer", F
);
559 if (FT
->getNumParams() != 1 ||
560 !FT
->getParamType(0)->isIntegerTy())
561 fail(I
, "llvm.coro.* allocator must take integer as only param", F
);
564 /// Check that the given value is a well-formed deallocator.
565 static void checkWFDealloc(const Instruction
*I
, Value
*V
) {
566 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
568 fail(I
, "llvm.coro.* deallocator not a Function", V
);
570 auto FT
= F
->getFunctionType();
571 if (!FT
->getReturnType()->isVoidTy())
572 fail(I
, "llvm.coro.* deallocator must return void", F
);
574 if (FT
->getNumParams() != 1 ||
575 !FT
->getParamType(0)->isPointerTy())
576 fail(I
, "llvm.coro.* deallocator must take pointer as only param", F
);
579 static void checkConstantInt(const Instruction
*I
, Value
*V
,
580 const char *Reason
) {
581 if (!isa
<ConstantInt
>(V
)) {
586 void AnyCoroIdRetconInst::checkWellFormed() const {
587 checkConstantInt(this, getArgOperand(SizeArg
),
588 "size argument to coro.id.retcon.* must be constant");
589 checkConstantInt(this, getArgOperand(AlignArg
),
590 "alignment argument to coro.id.retcon.* must be constant");
591 checkWFRetconPrototype(this, getArgOperand(PrototypeArg
));
592 checkWFAlloc(this, getArgOperand(AllocArg
));
593 checkWFDealloc(this, getArgOperand(DeallocArg
));
596 static void checkAsyncFuncPointer(const Instruction
*I
, Value
*V
) {
597 auto *AsyncFuncPtrAddr
= dyn_cast
<GlobalVariable
>(V
->stripPointerCasts());
598 if (!AsyncFuncPtrAddr
)
599 fail(I
, "llvm.coro.id.async async function pointer not a global", V
);
602 void CoroIdAsyncInst::checkWellFormed() const {
603 checkConstantInt(this, getArgOperand(SizeArg
),
604 "size argument to coro.id.async must be constant");
605 checkConstantInt(this, getArgOperand(AlignArg
),
606 "alignment argument to coro.id.async must be constant");
607 checkConstantInt(this, getArgOperand(StorageArg
),
608 "storage argument offset to coro.id.async must be constant");
609 checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg
));
// Validates the function type of the context projection function of a
// coro.suspend.async on behalf of instruction I: its return type must be a
// pointer, and it must take exactly one parameter, which must be a pointer.
// NOTE(review): the second parameter (used below as F) and the calls that
// carry the two messages are not visible in this view; presumably each
// violation is reported through fail() -- confirm against the complete
// definition.
612 static void checkAsyncContextProjectFunction(const Instruction
*I
,
// F's value type is cast<> to FunctionType unconditionally, so a
// non-function value type trips the cast's assertion rather than a
// diagnostic.
614 auto *FunTy
= cast
<FunctionType
>(F
->getValueType());
// Return type check.
615 if (!FunTy
->getReturnType()->isPointerTy())
617 "llvm.coro.suspend.async resume function projection function must "
// Parameter list check: exactly one parameter, of pointer type.
620 if (FunTy
->getNumParams() != 1 || !FunTy
->getParamType(0)->isPointerTy())
622 "llvm.coro.suspend.async resume function projection function must "
623 "take one ptr type as parameter",
627 void CoroSuspendAsyncInst::checkWellFormed() const {
628 checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
631 void CoroAsyncEndInst::checkWellFormed() const {
632 auto *MustTailCallFunc
= getMustTailCallFunction();
633 if (!MustTailCallFunc
)
635 auto *FnTy
= MustTailCallFunc
->getFunctionType();
636 if (FnTy
->getNumParams() != (arg_size() - 3))
638 "llvm.coro.end.async must tail call function argument type must "
639 "match the tail arguments",