//===- Coroutines.cpp -----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the common infrastructure for Coroutine Passes.
//
//===----------------------------------------------------------------------===//
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#include <cstddef>
#include <utility>

using namespace llvm;
37 // Construct the lowerer base class and initialize its members.
38 coro::LowererBase::LowererBase(Module
&M
)
39 : TheModule(M
), Context(M
.getContext()),
40 Int8Ptr(PointerType::get(Context
, 0)),
41 ResumeFnType(FunctionType::get(Type::getVoidTy(Context
), Int8Ptr
,
43 NullPtr(ConstantPointerNull::get(Int8Ptr
)) {}
45 // Creates a call to llvm.coro.subfn.addr to obtain a resume function address.
46 // It generates the following:
48 // call ptr @llvm.coro.subfn.addr(ptr %Arg, i8 %index)
50 Value
*coro::LowererBase::makeSubFnCall(Value
*Arg
, int Index
,
51 Instruction
*InsertPt
) {
52 auto *IndexVal
= ConstantInt::get(Type::getInt8Ty(Context
), Index
);
53 auto *Fn
= Intrinsic::getDeclaration(&TheModule
, Intrinsic::coro_subfn_addr
);
55 assert(Index
>= CoroSubFnInst::IndexFirst
&&
56 Index
< CoroSubFnInst::IndexLast
&&
57 "makeSubFnCall: Index value out of range");
58 return CallInst::Create(Fn
, {Arg
, IndexVal
}, "", InsertPt
);
// Names of every coroutine intrinsic. declaresAnyIntrinsic() scans this whole
// table, and isCoroutineIntrinsicName() looks names up in it with
// Intrinsic::lookupLLVMIntrinsicByName, which requires the table to be in
// sorted order.
// NOTE: Must be sorted!
static const char *const CoroIntrinsics[] = {
    "llvm.coro.align",
    "llvm.coro.alloc",
    "llvm.coro.async.context.alloc",
    "llvm.coro.async.context.dealloc",
    "llvm.coro.async.resume",
    "llvm.coro.async.size.replace",
    "llvm.coro.async.store_resume",
    "llvm.coro.begin",
    "llvm.coro.destroy",
    "llvm.coro.done",
    "llvm.coro.end",
    "llvm.coro.end.async",
    "llvm.coro.frame",
    "llvm.coro.free",
    "llvm.coro.id",
    "llvm.coro.id.async",
    "llvm.coro.id.retcon",
    "llvm.coro.id.retcon.once",
    "llvm.coro.noop",
    "llvm.coro.prepare.async",
    "llvm.coro.prepare.retcon",
    "llvm.coro.promise",
    "llvm.coro.resume",
    "llvm.coro.save",
    "llvm.coro.size",
    "llvm.coro.subfn.addr",
    "llvm.coro.suspend",
    "llvm.coro.suspend.async",
    "llvm.coro.suspend.retcon",
};
95 static bool isCoroutineIntrinsicName(StringRef Name
) {
96 return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics
, Name
) != -1;
100 bool coro::declaresAnyIntrinsic(const Module
&M
) {
101 for (StringRef Name
: CoroIntrinsics
) {
102 assert(isCoroutineIntrinsicName(Name
) && "not a coroutine intrinsic");
103 if (M
.getNamedValue(Name
))
110 // Verifies if a module has named values listed. Also, in debug mode verifies
111 // that names are intrinsic names.
112 bool coro::declaresIntrinsics(const Module
&M
,
113 const std::initializer_list
<StringRef
> List
) {
114 for (StringRef Name
: List
) {
115 assert(isCoroutineIntrinsicName(Name
) && "not a coroutine intrinsic");
116 if (M
.getNamedValue(Name
))
123 // Replace all coro.frees associated with the provided CoroId either with 'null'
124 // if Elide is true and with its frame parameter otherwise.
125 void coro::replaceCoroFree(CoroIdInst
*CoroId
, bool Elide
) {
126 SmallVector
<CoroFreeInst
*, 4> CoroFrees
;
127 for (User
*U
: CoroId
->users())
128 if (auto CF
= dyn_cast
<CoroFreeInst
>(U
))
129 CoroFrees
.push_back(CF
);
131 if (CoroFrees
.empty())
136 ? ConstantPointerNull::get(PointerType::get(CoroId
->getContext(), 0))
137 : CoroFrees
.front()->getFrame();
139 for (CoroFreeInst
*CF
: CoroFrees
) {
140 CF
->replaceAllUsesWith(Replacement
);
141 CF
->eraseFromParent();
145 static void clear(coro::Shape
&Shape
) {
146 Shape
.CoroBegin
= nullptr;
147 Shape
.CoroEnds
.clear();
148 Shape
.CoroSizes
.clear();
149 Shape
.CoroSuspends
.clear();
151 Shape
.FrameTy
= nullptr;
152 Shape
.FramePtr
= nullptr;
153 Shape
.AllocaSpillBlock
= nullptr;
156 static CoroSaveInst
*createCoroSave(CoroBeginInst
*CoroBegin
,
157 CoroSuspendInst
*SuspendInst
) {
158 Module
*M
= SuspendInst
->getModule();
159 auto *Fn
= Intrinsic::getDeclaration(M
, Intrinsic::coro_save
);
161 cast
<CoroSaveInst
>(CallInst::Create(Fn
, CoroBegin
, "", SuspendInst
));
162 assert(!SuspendInst
->getCoroSave());
163 SuspendInst
->setArgOperand(0, SaveInst
);
167 // Collect "interesting" coroutine intrinsics.
168 void coro::Shape::buildFrom(Function
&F
) {
169 bool HasFinalSuspend
= false;
170 bool HasUnwindCoroEnd
= false;
171 size_t FinalSuspendIndex
= 0;
173 SmallVector
<CoroFrameInst
*, 8> CoroFrames
;
174 SmallVector
<CoroSaveInst
*, 2> UnusedCoroSaves
;
176 for (Instruction
&I
: instructions(F
)) {
177 if (auto II
= dyn_cast
<IntrinsicInst
>(&I
)) {
178 switch (II
->getIntrinsicID()) {
181 case Intrinsic::coro_size
:
182 CoroSizes
.push_back(cast
<CoroSizeInst
>(II
));
184 case Intrinsic::coro_align
:
185 CoroAligns
.push_back(cast
<CoroAlignInst
>(II
));
187 case Intrinsic::coro_frame
:
188 CoroFrames
.push_back(cast
<CoroFrameInst
>(II
));
190 case Intrinsic::coro_save
:
191 // After optimizations, coro_suspends using this coro_save might have
192 // been removed, remember orphaned coro_saves to remove them later.
194 UnusedCoroSaves
.push_back(cast
<CoroSaveInst
>(II
));
196 case Intrinsic::coro_suspend_async
: {
197 auto *Suspend
= cast
<CoroSuspendAsyncInst
>(II
);
198 Suspend
->checkWellFormed();
199 CoroSuspends
.push_back(Suspend
);
202 case Intrinsic::coro_suspend_retcon
: {
203 auto Suspend
= cast
<CoroSuspendRetconInst
>(II
);
204 CoroSuspends
.push_back(Suspend
);
207 case Intrinsic::coro_suspend
: {
208 auto Suspend
= cast
<CoroSuspendInst
>(II
);
209 CoroSuspends
.push_back(Suspend
);
210 if (Suspend
->isFinal()) {
213 "Only one suspend point can be marked as final");
214 HasFinalSuspend
= true;
215 FinalSuspendIndex
= CoroSuspends
.size() - 1;
219 case Intrinsic::coro_begin
: {
220 auto CB
= cast
<CoroBeginInst
>(II
);
222 // Ignore coro id's that aren't pre-split.
223 auto Id
= dyn_cast
<CoroIdInst
>(CB
->getId());
224 if (Id
&& !Id
->getInfo().isPreSplit())
229 "coroutine should have exactly one defining @llvm.coro.begin");
230 CB
->addRetAttr(Attribute::NonNull
);
231 CB
->addRetAttr(Attribute::NoAlias
);
232 CB
->removeFnAttr(Attribute::NoDuplicate
);
236 case Intrinsic::coro_end_async
:
237 case Intrinsic::coro_end
:
238 CoroEnds
.push_back(cast
<AnyCoroEndInst
>(II
));
239 if (auto *AsyncEnd
= dyn_cast
<CoroAsyncEndInst
>(II
)) {
240 AsyncEnd
->checkWellFormed();
243 if (CoroEnds
.back()->isUnwind())
244 HasUnwindCoroEnd
= true;
246 if (CoroEnds
.back()->isFallthrough() && isa
<CoroEndInst
>(II
)) {
247 // Make sure that the fallthrough coro.end is the first element in the
249 // Note: I don't think this is neccessary anymore.
250 if (CoroEnds
.size() > 1) {
251 if (CoroEnds
.front()->isFallthrough())
253 "Only one coro.end can be marked as fallthrough");
254 std::swap(CoroEnds
.front(), CoroEnds
.back());
262 // If for some reason, we were not able to find coro.begin, bailout.
264 // Replace coro.frame which are supposed to be lowered to the result of
265 // coro.begin with undef.
266 auto *Undef
= UndefValue::get(PointerType::get(F
.getContext(), 0));
267 for (CoroFrameInst
*CF
: CoroFrames
) {
268 CF
->replaceAllUsesWith(Undef
);
269 CF
->eraseFromParent();
272 // Replace all coro.suspend with undef and remove related coro.saves if
274 for (AnyCoroSuspendInst
*CS
: CoroSuspends
) {
275 CS
->replaceAllUsesWith(UndefValue::get(CS
->getType()));
276 CS
->eraseFromParent();
277 if (auto *CoroSave
= CS
->getCoroSave())
278 CoroSave
->eraseFromParent();
281 // Replace all coro.ends with unreachable instruction.
282 for (AnyCoroEndInst
*CE
: CoroEnds
)
283 changeToUnreachable(CE
);
288 auto Id
= CoroBegin
->getId();
289 switch (auto IdIntrinsic
= Id
->getIntrinsicID()) {
290 case Intrinsic::coro_id
: {
291 auto SwitchId
= cast
<CoroIdInst
>(Id
);
292 this->ABI
= coro::ABI::Switch
;
293 this->SwitchLowering
.HasFinalSuspend
= HasFinalSuspend
;
294 this->SwitchLowering
.HasUnwindCoroEnd
= HasUnwindCoroEnd
;
295 this->SwitchLowering
.ResumeSwitch
= nullptr;
296 this->SwitchLowering
.PromiseAlloca
= SwitchId
->getPromise();
297 this->SwitchLowering
.ResumeEntryBlock
= nullptr;
299 for (auto *AnySuspend
: CoroSuspends
) {
300 auto Suspend
= dyn_cast
<CoroSuspendInst
>(AnySuspend
);
305 report_fatal_error("coro.id must be paired with coro.suspend");
308 if (!Suspend
->getCoroSave())
309 createCoroSave(CoroBegin
, Suspend
);
313 case Intrinsic::coro_id_async
: {
314 auto *AsyncId
= cast
<CoroIdAsyncInst
>(Id
);
315 AsyncId
->checkWellFormed();
316 this->ABI
= coro::ABI::Async
;
317 this->AsyncLowering
.Context
= AsyncId
->getStorage();
318 this->AsyncLowering
.ContextArgNo
= AsyncId
->getStorageArgumentIndex();
319 this->AsyncLowering
.ContextHeaderSize
= AsyncId
->getStorageSize();
320 this->AsyncLowering
.ContextAlignment
=
321 AsyncId
->getStorageAlignment().value();
322 this->AsyncLowering
.AsyncFuncPointer
= AsyncId
->getAsyncFunctionPointer();
323 this->AsyncLowering
.AsyncCC
= F
.getCallingConv();
326 case Intrinsic::coro_id_retcon
:
327 case Intrinsic::coro_id_retcon_once
: {
328 auto ContinuationId
= cast
<AnyCoroIdRetconInst
>(Id
);
329 ContinuationId
->checkWellFormed();
330 this->ABI
= (IdIntrinsic
== Intrinsic::coro_id_retcon
332 : coro::ABI::RetconOnce
);
333 auto Prototype
= ContinuationId
->getPrototype();
334 this->RetconLowering
.ResumePrototype
= Prototype
;
335 this->RetconLowering
.Alloc
= ContinuationId
->getAllocFunction();
336 this->RetconLowering
.Dealloc
= ContinuationId
->getDeallocFunction();
337 this->RetconLowering
.ReturnBlock
= nullptr;
338 this->RetconLowering
.IsFrameInlineInStorage
= false;
340 // Determine the result value types, and make sure they match up with
341 // the values passed to the suspends.
342 auto ResultTys
= getRetconResultTypes();
343 auto ResumeTys
= getRetconResumeTypes();
345 for (auto *AnySuspend
: CoroSuspends
) {
346 auto Suspend
= dyn_cast
<CoroSuspendRetconInst
>(AnySuspend
);
351 report_fatal_error("coro.id.retcon.* must be paired with "
352 "coro.suspend.retcon");
355 // Check that the argument types of the suspend match the results.
356 auto SI
= Suspend
->value_begin(), SE
= Suspend
->value_end();
357 auto RI
= ResultTys
.begin(), RE
= ResultTys
.end();
358 for (; SI
!= SE
&& RI
!= RE
; ++SI
, ++RI
) {
359 auto SrcTy
= (*SI
)->getType();
361 // The optimizer likes to eliminate bitcasts leading into variadic
362 // calls, but that messes with our invariants. Re-insert the
363 // bitcast and ignore this type mismatch.
364 if (CastInst::isBitCastable(SrcTy
, *RI
)) {
365 auto BCI
= new BitCastInst(*SI
, *RI
, "", Suspend
);
372 Prototype
->getFunctionType()->dump();
374 report_fatal_error("argument to coro.suspend.retcon does not "
375 "match corresponding prototype function result");
378 if (SI
!= SE
|| RI
!= RE
) {
381 Prototype
->getFunctionType()->dump();
383 report_fatal_error("wrong number of arguments to coro.suspend.retcon");
386 // Check that the result type of the suspend matches the resume types.
387 Type
*SResultTy
= Suspend
->getType();
388 ArrayRef
<Type
*> SuspendResultTys
;
389 if (SResultTy
->isVoidTy()) {
390 // leave as empty array
391 } else if (auto SResultStructTy
= dyn_cast
<StructType
>(SResultTy
)) {
392 SuspendResultTys
= SResultStructTy
->elements();
394 // forms an ArrayRef using SResultTy, be careful
395 SuspendResultTys
= SResultTy
;
397 if (SuspendResultTys
.size() != ResumeTys
.size()) {
400 Prototype
->getFunctionType()->dump();
402 report_fatal_error("wrong number of results from coro.suspend.retcon");
404 for (size_t I
= 0, E
= ResumeTys
.size(); I
!= E
; ++I
) {
405 if (SuspendResultTys
[I
] != ResumeTys
[I
]) {
408 Prototype
->getFunctionType()->dump();
410 report_fatal_error("result from coro.suspend.retcon does not "
411 "match corresponding prototype function param");
419 llvm_unreachable("coro.begin is not dependent on a coro.id call");
422 // The coro.free intrinsic is always lowered to the result of coro.begin.
423 for (CoroFrameInst
*CF
: CoroFrames
) {
424 CF
->replaceAllUsesWith(CoroBegin
);
425 CF
->eraseFromParent();
428 // Move final suspend to be the last element in the CoroSuspends vector.
429 if (ABI
== coro::ABI::Switch
&&
430 SwitchLowering
.HasFinalSuspend
&&
431 FinalSuspendIndex
!= CoroSuspends
.size() - 1)
432 std::swap(CoroSuspends
[FinalSuspendIndex
], CoroSuspends
.back());
434 // Remove orphaned coro.saves.
435 for (CoroSaveInst
*CoroSave
: UnusedCoroSaves
)
436 CoroSave
->eraseFromParent();
439 static void propagateCallAttrsFromCallee(CallInst
*Call
, Function
*Callee
) {
440 Call
->setCallingConv(Callee
->getCallingConv());
444 static void addCallToCallGraph(CallGraph
*CG
, CallInst
*Call
, Function
*Callee
){
446 (*CG
)[Call
->getFunction()]->addCalledFunction(Call
, (*CG
)[Callee
]);
449 Value
*coro::Shape::emitAlloc(IRBuilder
<> &Builder
, Value
*Size
,
450 CallGraph
*CG
) const {
452 case coro::ABI::Switch
:
453 llvm_unreachable("can't allocate memory in coro switch-lowering");
455 case coro::ABI::Retcon
:
456 case coro::ABI::RetconOnce
: {
457 auto Alloc
= RetconLowering
.Alloc
;
458 Size
= Builder
.CreateIntCast(Size
,
459 Alloc
->getFunctionType()->getParamType(0),
460 /*is signed*/ false);
461 auto *Call
= Builder
.CreateCall(Alloc
, Size
);
462 propagateCallAttrsFromCallee(Call
, Alloc
);
463 addCallToCallGraph(CG
, Call
, Alloc
);
466 case coro::ABI::Async
:
467 llvm_unreachable("can't allocate memory in coro async-lowering");
469 llvm_unreachable("Unknown coro::ABI enum");
472 void coro::Shape::emitDealloc(IRBuilder
<> &Builder
, Value
*Ptr
,
473 CallGraph
*CG
) const {
475 case coro::ABI::Switch
:
476 llvm_unreachable("can't allocate memory in coro switch-lowering");
478 case coro::ABI::Retcon
:
479 case coro::ABI::RetconOnce
: {
480 auto Dealloc
= RetconLowering
.Dealloc
;
481 Ptr
= Builder
.CreateBitCast(Ptr
,
482 Dealloc
->getFunctionType()->getParamType(0));
483 auto *Call
= Builder
.CreateCall(Dealloc
, Ptr
);
484 propagateCallAttrsFromCallee(Call
, Dealloc
);
485 addCallToCallGraph(CG
, Call
, Dealloc
);
488 case coro::ABI::Async
:
489 llvm_unreachable("can't allocate memory in coro async-lowering");
491 llvm_unreachable("Unknown coro::ABI enum");
494 [[noreturn
]] static void fail(const Instruction
*I
, const char *Reason
,
499 errs() << " Value: ";
500 V
->printAsOperand(llvm::errs());
504 report_fatal_error(Reason
);
507 /// Check that the given value is a well-formed prototype for the
508 /// llvm.coro.id.retcon.* intrinsics.
509 static void checkWFRetconPrototype(const AnyCoroIdRetconInst
*I
, Value
*V
) {
510 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
512 fail(I
, "llvm.coro.id.retcon.* prototype not a Function", V
);
514 auto FT
= F
->getFunctionType();
516 if (isa
<CoroIdRetconInst
>(I
)) {
518 if (FT
->getReturnType()->isPointerTy()) {
520 } else if (auto SRetTy
= dyn_cast
<StructType
>(FT
->getReturnType())) {
521 ResultOkay
= (!SRetTy
->isOpaque() &&
522 SRetTy
->getNumElements() > 0 &&
523 SRetTy
->getElementType(0)->isPointerTy());
528 fail(I
, "llvm.coro.id.retcon prototype must return pointer as first "
531 if (FT
->getReturnType() !=
532 I
->getFunction()->getFunctionType()->getReturnType())
533 fail(I
, "llvm.coro.id.retcon prototype return type must be same as"
534 "current function return type", F
);
536 // No meaningful validation to do here for llvm.coro.id.unique.once.
539 if (FT
->getNumParams() == 0 || !FT
->getParamType(0)->isPointerTy())
540 fail(I
, "llvm.coro.id.retcon.* prototype must take pointer as "
541 "its first parameter", F
);
544 /// Check that the given value is a well-formed allocator.
545 static void checkWFAlloc(const Instruction
*I
, Value
*V
) {
546 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
548 fail(I
, "llvm.coro.* allocator not a Function", V
);
550 auto FT
= F
->getFunctionType();
551 if (!FT
->getReturnType()->isPointerTy())
552 fail(I
, "llvm.coro.* allocator must return a pointer", F
);
554 if (FT
->getNumParams() != 1 ||
555 !FT
->getParamType(0)->isIntegerTy())
556 fail(I
, "llvm.coro.* allocator must take integer as only param", F
);
559 /// Check that the given value is a well-formed deallocator.
560 static void checkWFDealloc(const Instruction
*I
, Value
*V
) {
561 auto F
= dyn_cast
<Function
>(V
->stripPointerCasts());
563 fail(I
, "llvm.coro.* deallocator not a Function", V
);
565 auto FT
= F
->getFunctionType();
566 if (!FT
->getReturnType()->isVoidTy())
567 fail(I
, "llvm.coro.* deallocator must return void", F
);
569 if (FT
->getNumParams() != 1 ||
570 !FT
->getParamType(0)->isPointerTy())
571 fail(I
, "llvm.coro.* deallocator must take pointer as only param", F
);
574 static void checkConstantInt(const Instruction
*I
, Value
*V
,
575 const char *Reason
) {
576 if (!isa
<ConstantInt
>(V
)) {
581 void AnyCoroIdRetconInst::checkWellFormed() const {
582 checkConstantInt(this, getArgOperand(SizeArg
),
583 "size argument to coro.id.retcon.* must be constant");
584 checkConstantInt(this, getArgOperand(AlignArg
),
585 "alignment argument to coro.id.retcon.* must be constant");
586 checkWFRetconPrototype(this, getArgOperand(PrototypeArg
));
587 checkWFAlloc(this, getArgOperand(AllocArg
));
588 checkWFDealloc(this, getArgOperand(DeallocArg
));
591 static void checkAsyncFuncPointer(const Instruction
*I
, Value
*V
) {
592 auto *AsyncFuncPtrAddr
= dyn_cast
<GlobalVariable
>(V
->stripPointerCasts());
593 if (!AsyncFuncPtrAddr
)
594 fail(I
, "llvm.coro.id.async async function pointer not a global", V
);
597 void CoroIdAsyncInst::checkWellFormed() const {
598 checkConstantInt(this, getArgOperand(SizeArg
),
599 "size argument to coro.id.async must be constant");
600 checkConstantInt(this, getArgOperand(AlignArg
),
601 "alignment argument to coro.id.async must be constant");
602 checkConstantInt(this, getArgOperand(StorageArg
),
603 "storage argument offset to coro.id.async must be constant");
604 checkAsyncFuncPointer(this, getArgOperand(AsyncFuncPtrArg
));
607 static void checkAsyncContextProjectFunction(const Instruction
*I
,
609 auto *FunTy
= cast
<FunctionType
>(F
->getValueType());
610 if (!FunTy
->getReturnType()->isPointerTy())
612 "llvm.coro.suspend.async resume function projection function must "
615 if (FunTy
->getNumParams() != 1 || !FunTy
->getParamType(0)->isPointerTy())
617 "llvm.coro.suspend.async resume function projection function must "
618 "take one ptr type as parameter",
622 void CoroSuspendAsyncInst::checkWellFormed() const {
623 checkAsyncContextProjectFunction(this, getAsyncContextProjectionFunction());
626 void CoroAsyncEndInst::checkWellFormed() const {
627 auto *MustTailCallFunc
= getMustTailCallFunction();
628 if (!MustTailCallFunc
)
630 auto *FnTy
= MustTailCallFunc
->getFunctionType();
631 if (FnTy
->getNumParams() != (arg_size() - 3))
633 "llvm.coro.end.async must tail call function argument type must "
634 "match the tail arguments",