1 //===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
8 // This pass builds the coroutine frame and outlines resume and destroy parts
9 // of the coroutine into separate functions.
11 // We present a coroutine to an LLVM as an ordinary function with suspension
12 // points marked up with intrinsics. We let the optimizer party on the coroutine
13 // as a single function for as long as possible. Shortly before the coroutine is
14 // eligible to be inlined into its callers, we split up the coroutine into parts
15 // corresponding to an initial, resume and destroy invocations of the coroutine,
16 // add them to the current SCC and restart the IPO pipeline to optimize the
17 // coroutine subfunctions we extracted before proceeding to the caller of the
19 //===----------------------------------------------------------------------===//
21 #include "CoroInstr.h"
22 #include "CoroInternal.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/SmallPtrSet.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/ADT/Twine.h"
28 #include "llvm/Analysis/CallGraph.h"
29 #include "llvm/Analysis/CallGraphSCCPass.h"
30 #include "llvm/Transforms/Utils/Local.h"
31 #include "llvm/IR/Argument.h"
32 #include "llvm/IR/Attributes.h"
33 #include "llvm/IR/BasicBlock.h"
34 #include "llvm/IR/CFG.h"
35 #include "llvm/IR/CallSite.h"
36 #include "llvm/IR/CallingConv.h"
37 #include "llvm/IR/Constants.h"
38 #include "llvm/IR/DataLayout.h"
39 #include "llvm/IR/DerivedTypes.h"
40 #include "llvm/IR/Function.h"
41 #include "llvm/IR/GlobalValue.h"
42 #include "llvm/IR/GlobalVariable.h"
43 #include "llvm/IR/IRBuilder.h"
44 #include "llvm/IR/InstIterator.h"
45 #include "llvm/IR/InstrTypes.h"
46 #include "llvm/IR/Instruction.h"
47 #include "llvm/IR/Instructions.h"
48 #include "llvm/IR/IntrinsicInst.h"
49 #include "llvm/IR/LLVMContext.h"
50 #include "llvm/IR/LegacyPassManager.h"
51 #include "llvm/IR/Module.h"
52 #include "llvm/IR/Type.h"
53 #include "llvm/IR/Value.h"
54 #include "llvm/IR/Verifier.h"
55 #include "llvm/Pass.h"
56 #include "llvm/Support/Casting.h"
57 #include "llvm/Support/Debug.h"
58 #include "llvm/Support/raw_ostream.h"
59 #include "llvm/Transforms/Scalar.h"
60 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
61 #include "llvm/Transforms/Utils/Cloning.h"
62 #include "llvm/Transforms/Utils/ValueMapper.h"
66 #include <initializer_list>
71 #define DEBUG_TYPE "coro-split"
// NOTE(review): this block appears whitespace-mangled and several original
// source lines are missing from view (e.g. the 'auto *Switch =' declaration
// that 'Shape.ResumeSwitch = Switch;' depends on, and the if() guarding the
// final-suspend null-ResumeFn store). Code is left byte-identical; only
// comments were added.
//
// Builds a "resume.entry" block that loads the suspend index out of the
// coroutine frame and switches to the resume point for that index. For each
// suspend point: replaces coro.save with a store of the index into the frame,
// splits the suspend's block into resume.N / resume.N.landing, adds a switch
// case, and rethreads uses of the suspend result through a PHI.
73 // Create an entry block for a resume function with a switch that will jump to
75 static BasicBlock
*createResumeEntryBlock(Function
&F
, coro::Shape
&Shape
) {
76 LLVMContext
&C
= F
.getContext();
79 // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
81 // % index = load i32, i32* %index.addr
82 // switch i32 %index, label %unreachable [
83 // i32 0, label %resume.0
84 // i32 1, label %resume.1
// The default destination is an unreachable block: resuming with an index
// that has no case is undefined behavior.
88 auto *NewEntry
= BasicBlock::Create(C
, "resume.entry", &F
);
89 auto *UnreachBB
= BasicBlock::Create(C
, "unreachable", &F
);
91 IRBuilder
<> Builder(NewEntry
);
92 auto *FramePtr
= Shape
.FramePtr
;
93 auto *FrameTy
= Shape
.FrameTy
;
94 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(
95 FrameTy
, FramePtr
, 0, coro::Shape::IndexField
, "index.addr");
96 auto *Index
= Builder
.CreateLoad(Shape
.getIndexType(), GepIndex
, "index");
98 Builder
.CreateSwitch(Index
, UnreachBB
, Shape
.CoroSuspends
.size());
99 Shape
.ResumeSwitch
= Switch
;
// Walk every suspend point; SuspendIndex numbers them in CoroSuspends order.
101 size_t SuspendIndex
= 0;
102 for (CoroSuspendInst
*S
: Shape
.CoroSuspends
) {
103 ConstantInt
*IndexVal
= Shape
.getIndex(SuspendIndex
);
105 // Replace CoroSave with a store to Index:
106 // %index.addr = getelementptr %f.frame... (index field number)
107 // store i32 0, i32* %index.addr1
108 auto *Save
= S
->getCoroSave();
109 Builder
.SetInsertPoint(Save
);
111 // Final suspend point is represented by storing zero in ResumeFnAddr.
// NOTE(review): presumably this null store is under an S->isFinal() check
// on a line not visible here — confirm against the complete source.
112 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(FrameTy
, FramePtr
, 0,
114 auto *NullPtr
= ConstantPointerNull::get(cast
<PointerType
>(
115 cast
<PointerType
>(GepIndex
->getType())->getElementType()));
116 Builder
.CreateStore(NullPtr
, GepIndex
);
118 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(
119 FrameTy
, FramePtr
, 0, coro::Shape::IndexField
, "index.addr");
120 Builder
.CreateStore(IndexVal
, GepIndex
);
// coro.save produces a token; all remaining uses get token none before the
// save instruction is erased.
122 Save
->replaceAllUsesWith(ConstantTokenNone::get(C
));
123 Save
->eraseFromParent();
125 // Split block before and after coro.suspend and add a jump from an entry
130 // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
131 // switch i8 %0, label %suspend[i8 0, label %resume
132 // i8 1, label %cleanup]
137 // br label %resume.0.landing
139 // resume.0: ; <--- jump from the switch in the resume.entry
140 // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
141 // br label %resume.0.landing
144 // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
145 // switch i8 % 1, label %suspend [i8 0, label %resume
146 // i8 1, label %cleanup]
148 auto *SuspendBB
= S
->getParent();
150 SuspendBB
->splitBasicBlock(S
, "resume." + Twine(SuspendIndex
));
151 auto *LandingBB
= ResumeBB
->splitBasicBlock(
152 S
->getNextNode(), ResumeBB
->getName() + Twine(".landing"));
153 Switch
->addCase(IndexVal
, ResumeBB
);
155 cast
<BranchInst
>(SuspendBB
->getTerminator())->setSuccessor(0, LandingBB
);
// PHI merges -1 (came from fallthrough, i.e. first entry) with the actual
// coro.suspend result (came from the resume switch).
156 auto *PN
= PHINode::Create(Builder
.getInt8Ty(), 2, "", &LandingBB
->front());
157 S
->replaceAllUsesWith(PN
);
158 PN
->addIncoming(Builder
.getInt8(-1), SuspendBB
);
159 PN
->addIncoming(S
, ResumeBB
);
// Terminate the default switch destination with unreachable.
164 Builder
.SetInsertPoint(UnreachBB
);
165 Builder
.CreateUnreachable();
170 // In Resumers, we replace fallthrough coro.end with ret void and delete the
171 // rest of the block.
172 static void replaceFallthroughCoroEnd(IntrinsicInst
*End
,
173 ValueToValueMapTy
&VMap
) {
174 auto *NewE
= cast
<IntrinsicInst
>(VMap
[End
]);
175 ReturnInst::Create(NewE
->getContext(), nullptr, NewE
);
177 // Remove the rest of the block, by splitting it into an unreachable block.
178 auto *BB
= NewE
->getParent();
179 BB
->splitBasicBlock(NewE
);
180 BB
->getTerminator()->eraseFromParent();
// NOTE(review): whitespace-mangled block; the early 'return;' after the
// empty() check and several closing braces are on lines not visible here.
// Code left byte-identical; comments only.
//
// In each cloned resumer, rewrites every unwind coro.end: if it carries a
// funclet bundle, inserts a cleanupret from the associated pad (splitting the
// block and dropping the split-created terminator); then replaces all uses of
// the cloned coro.end with 'true' and erases it.
183 // In Resumers, we replace unwind coro.end with True to force the immediate
185 static void replaceUnwindCoroEnds(coro::Shape
&Shape
, ValueToValueMapTy
&VMap
) {
186 if (Shape
.CoroEnds
.empty())
189 LLVMContext
&Context
= Shape
.CoroEnds
.front()->getContext();
190 auto *True
= ConstantInt::getTrue(Context
);
191 for (CoroEndInst
*CE
: Shape
.CoroEnds
) {
// Map the original coro.end to its clone in the outlined function.
195 auto *NewCE
= cast
<IntrinsicInst
>(VMap
[CE
]);
197 // If coro.end has an associated bundle, add cleanupret instruction.
198 if (auto Bundle
= NewCE
->getOperandBundle(LLVMContext::OB_funclet
)) {
199 Value
*FromPad
= Bundle
->Inputs
[0];
200 auto *CleanupRet
= CleanupReturnInst::Create(FromPad
, nullptr, NewCE
);
201 NewCE
->getParent()->splitBasicBlock(NewCE
);
202 CleanupRet
->getParent()->getTerminator()->eraseFromParent();
205 NewCE
->replaceAllUsesWith(True
);
206 NewCE
->eraseFromParent();
// NOTE(review): whitespace-mangled block; the final parameter of the
// signature (line 221, presumably 'bool IsDestroy' given the call site in
// createClone) and the 'auto *NullPtr =' declaration (line 234) are missing
// from view. Code left byte-identical; comments only.
210 // Rewrite final suspend point handling. We do not use suspend index to
211 // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
212 // coroutine frame, since it is undefined behavior to resume a coroutine
213 // suspended at the final suspend point. Thus, in the resume function, we can
214 // simply remove the last case (when coro::Shape is built, the final suspend
215 // point (if present) is always the last element of CoroSuspends array).
216 // In the destroy function, we add a code sequence to check if ResumeFnAddress
217 // is Null, and if so, jump to the appropriate label to handle cleanup from the
218 // final suspend point.
219 static void handleFinalSuspend(IRBuilder
<> &Builder
, Value
*FramePtr
,
220 coro::Shape
&Shape
, SwitchInst
*Switch
,
222 assert(Shape
.HasFinalSuspend
);
// The final suspend point is always the last switch case; detach it.
223 auto FinalCaseIt
= std::prev(Switch
->case_end());
224 BasicBlock
*ResumeBB
= FinalCaseIt
->getCaseSuccessor();
225 Switch
->removeCase(FinalCaseIt
);
// Split before the switch so a null-check on ResumeFn can be inserted in
// front of it: null ResumeFn means "suspended at the final suspend point".
227 BasicBlock
*OldSwitchBB
= Switch
->getParent();
228 auto *NewSwitchBB
= OldSwitchBB
->splitBasicBlock(Switch
, "Switch");
229 Builder
.SetInsertPoint(OldSwitchBB
->getTerminator());
230 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(Shape
.FrameTy
, FramePtr
,
231 0, 0, "ResumeFn.addr");
232 auto *Load
= Builder
.CreateLoad(
233 Shape
.FrameTy
->getElementType(coro::Shape::ResumeField
), GepIndex
);
235 ConstantPointerNull::get(cast
<PointerType
>(Load
->getType()));
236 auto *Cond
= Builder
.CreateICmpEQ(Load
, NullPtr
);
// null ResumeFn -> branch straight to the (removed-case) final resume block;
// otherwise fall through to the index switch.
237 Builder
.CreateCondBr(Cond
, ResumeBB
, NewSwitchBB
);
238 OldSwitchBB
->getTerminator()->eraseFromParent();
// NOTE(review): whitespace-mangled block; several original lines are missing
// from view, including the 'auto *NewF =' declaration that the
// Function::Create on line 253 initializes, the final 'return NewF;', and
// assorted closing braces. Code left byte-identical; comments only.
//
// FnIndex selects the clone flavor: 0 = resume, non-zero = destroy/cleanup
// (see the calls in splitCoroutine with indices 0/1/2); the index feeds the
// coro.suspend replacement value and the coro.free elision flag below.
242 // Create a resume clone by cloning the body of the original function, setting
243 // new entry block and replacing coro.suspend an appropriate value to force
244 // resume or cleanup pass for every suspend point.
245 static Function
*createClone(Function
&F
, Twine Suffix
, coro::Shape
&Shape
,
246 BasicBlock
*ResumeEntry
, int8_t FnIndex
) {
247 Module
*M
= F
.getParent();
// The resumer signature is recovered from the frame's first field, which
// holds the resume function pointer.
248 auto *FrameTy
= Shape
.FrameTy
;
249 auto *FnPtrTy
= cast
<PointerType
>(FrameTy
->getElementType(0));
250 auto *FnTy
= cast
<FunctionType
>(FnPtrTy
->getElementType());
253 Function::Create(FnTy
, GlobalValue::LinkageTypes::ExternalLinkage
,
254 F
.getName() + Suffix
, M
);
// The single parameter is the frame pointer; it is never null and never
// aliases anything else.
255 NewF
->addParamAttr(0, Attribute::NonNull
);
256 NewF
->addParamAttr(0, Attribute::NoAlias
);
258 ValueToValueMapTy VMap
;
259 // Replace all args with undefs. The buildCoroutineFrame algorithm already
260 // rewritten access to the args that occurs after suspend points with loads
261 // and stores to/from the coroutine frame.
262 for (Argument
&A
: F
.args())
263 VMap
[&A
] = UndefValue::get(A
.getType());
265 SmallVector
<ReturnInst
*, 4> Returns
;
267 CloneFunctionInto(NewF
, &F
, VMap
, /*ModuleLevelChanges=*/true, Returns
);
268 NewF
->setLinkage(GlobalValue::LinkageTypes::InternalLinkage
);
270 // Remove old returns.
271 for (ReturnInst
*Return
: Returns
)
272 changeToUnreachable(Return
, /*UseLLVMTrap=*/false);
274 // Remove old return attributes.
275 NewF
->removeAttributes(
276 AttributeList::ReturnIndex
,
277 AttributeFuncs::typeIncompatible(NewF
->getReturnType()));
279 // Make AllocaSpillBlock the new entry block.
280 auto *SwitchBB
= cast
<BasicBlock
>(VMap
[ResumeEntry
]);
281 auto *Entry
= cast
<BasicBlock
>(VMap
[Shape
.AllocaSpillBlock
]);
282 Entry
->moveBefore(&NewF
->getEntryBlock());
283 Entry
->getTerminator()->eraseFromParent();
284 BranchInst::Create(SwitchBB
, Entry
);
285 Entry
->setName("entry" + Suffix
);
287 // Clear all predecessors of the new entry block.
288 auto *Switch
= cast
<SwitchInst
>(VMap
[Shape
.ResumeSwitch
]);
289 Entry
->replaceAllUsesWith(Switch
->getDefaultDest());
291 IRBuilder
<> Builder(&NewF
->getEntryBlock().front());
293 // Remap frame pointer.
294 Argument
*NewFramePtr
= &*NewF
->arg_begin();
295 Value
*OldFramePtr
= cast
<Value
>(VMap
[Shape
.FramePtr
]);
296 NewFramePtr
->takeName(OldFramePtr
);
297 OldFramePtr
->replaceAllUsesWith(NewFramePtr
);
299 // Remap vFrame pointer.
300 auto *NewVFrame
= Builder
.CreateBitCast(
301 NewFramePtr
, Type::getInt8PtrTy(Builder
.getContext()), "vFrame");
302 Value
*OldVFrame
= cast
<Value
>(VMap
[Shape
.CoroBegin
]);
303 OldVFrame
->replaceAllUsesWith(NewVFrame
);
305 // Rewrite final suspend handling as it is not done via switch (allows to
306 // remove final case from the switch, since it is undefined behavior to resume
307 // the coroutine suspended at the final suspend point.
308 if (Shape
.HasFinalSuspend
) {
309 auto *Switch
= cast
<SwitchInst
>(VMap
[Shape
.ResumeSwitch
]);
310 bool IsDestroy
= FnIndex
!= 0;
311 handleFinalSuspend(Builder
, NewFramePtr
, Shape
, Switch
, IsDestroy
);
314 // Replace coro suspend with the appropriate resume index.
315 // Replacing coro.suspend with (0) will result in control flow proceeding to
316 // a resume label associated with a suspend point, replacing it with (1) will
317 // result in control flow proceeding to a cleanup label associated with this
319 auto *NewValue
= Builder
.getInt8(FnIndex
? 1 : 0);
320 for (CoroSuspendInst
*CS
: Shape
.CoroSuspends
) {
321 auto *MappedCS
= cast
<CoroSuspendInst
>(VMap
[CS
]);
322 MappedCS
->replaceAllUsesWith(NewValue
);
323 MappedCS
->eraseFromParent();
326 // Remove coro.end intrinsics.
327 replaceFallthroughCoroEnd(Shape
.CoroEnds
.front(), VMap
);
328 replaceUnwindCoroEnds(Shape
, VMap
);
329 // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
330 // to suppress deallocation code.
331 coro::replaceCoroFree(cast
<CoroIdInst
>(VMap
[Shape
.CoroBegin
->getId()]),
332 /*Elide=*/FnIndex
== 2);
334 NewF
->setCallingConv(CallingConv::Fast
);
339 static void removeCoroEnds(coro::Shape
&Shape
) {
340 if (Shape
.CoroEnds
.empty())
343 LLVMContext
&Context
= Shape
.CoroEnds
.front()->getContext();
344 auto *False
= ConstantInt::getFalse(Context
);
346 for (CoroEndInst
*CE
: Shape
.CoroEnds
) {
347 CE
->replaceAllUsesWith(False
);
348 CE
->eraseFromParent();
352 static void replaceFrameSize(coro::Shape
&Shape
) {
353 if (Shape
.CoroSizes
.empty())
356 // In the same function all coro.sizes should have the same result type.
357 auto *SizeIntrin
= Shape
.CoroSizes
.back();
358 Module
*M
= SizeIntrin
->getModule();
359 const DataLayout
&DL
= M
->getDataLayout();
360 auto Size
= DL
.getTypeAllocSize(Shape
.FrameTy
);
361 auto *SizeConstant
= ConstantInt::get(SizeIntrin
->getType(), Size
);
363 for (CoroSizeInst
*CS
: Shape
.CoroSizes
) {
364 CS
->replaceAllUsesWith(SizeConstant
);
365 CS
->eraseFromParent();
369 // Create a global constant array containing pointers to functions provided and
370 // set Info parameter of CoroBegin to point at this constant. Example:
372 // @f.resumers = internal constant [2 x void(%f.frame*)*]
373 // [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
374 // define void @f() {
376 // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
377 // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
379 // Assumes that all the functions have the same signature.
380 static void setCoroInfo(Function
&F
, CoroBeginInst
*CoroBegin
,
381 std::initializer_list
<Function
*> Fns
) {
382 SmallVector
<Constant
*, 4> Args(Fns
.begin(), Fns
.end());
383 assert(!Args
.empty());
384 Function
*Part
= *Fns
.begin();
385 Module
*M
= Part
->getParent();
386 auto *ArrTy
= ArrayType::get(Part
->getType(), Args
.size());
388 auto *ConstVal
= ConstantArray::get(ArrTy
, Args
);
389 auto *GV
= new GlobalVariable(*M
, ConstVal
->getType(), /*isConstant=*/true,
390 GlobalVariable::PrivateLinkage
, ConstVal
,
391 F
.getName() + Twine(".resumers"));
393 // Update coro.begin instruction to refer to this constant.
394 LLVMContext
&C
= F
.getContext();
395 auto *BC
= ConstantExpr::getPointerCast(GV
, Type::getInt8PtrTy(C
));
396 CoroBegin
->getId()->setInfo(BC
);
// NOTE(review): whitespace-mangled block; the continuation lines carrying the
// trailing string-literal name arguments of the two GEPs (lines 405/419) are
// missing from view. Code left byte-identical; comments only.
399 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
400 static void updateCoroFrame(coro::Shape
&Shape
, Function
*ResumeFn
,
401 Function
*DestroyFn
, Function
*CleanupFn
) {
// Insert right after the frame pointer is established.
402 IRBuilder
<> Builder(Shape
.FramePtr
->getNextNode());
403 auto *ResumeAddr
= Builder
.CreateConstInBoundsGEP2_32(
404 Shape
.FrameTy
, Shape
.FramePtr
, 0, coro::Shape::ResumeField
,
406 Builder
.CreateStore(ResumeFn
, ResumeAddr
);
// By default the destroy slot holds DestroyFn; a dynamic select below may
// substitute CleanupFn when the frame allocation was elided.
408 Value
*DestroyOrCleanupFn
= DestroyFn
;
410 CoroIdInst
*CoroId
= Shape
.CoroBegin
->getId();
411 if (CoroAllocInst
*CA
= CoroId
->getCoroAlloc()) {
412 // If there is a CoroAlloc and it returns false (meaning we elide the
413 // allocation, use CleanupFn instead of DestroyFn).
414 DestroyOrCleanupFn
= Builder
.CreateSelect(CA
, DestroyFn
, CleanupFn
);
417 auto *DestroyAddr
= Builder
.CreateConstInBoundsGEP2_32(
418 Shape
.FrameTy
, Shape
.FramePtr
, 0, coro::Shape::DestroyField
,
420 Builder
.CreateStore(DestroyOrCleanupFn
, DestroyAddr
);
// NOTE(review): whitespace-mangled block; the line that actually runs the
// pipeline (presumably 'FPM.run(F);', original line 434) is missing from
// view. Code left byte-identical; comments only.
//
// Light scalar cleanup applied to each outlined clone: verify, then
// SCCP + CFG simplification + EarlyCSE + CFG simplification again.
423 static void postSplitCleanup(Function
&F
) {
424 removeUnreachableBlocks(F
);
425 legacy::FunctionPassManager
FPM(F
.getParent());
427 FPM
.add(createVerifierPass());
428 FPM
.add(createSCCPPass());
429 FPM
.add(createCFGSimplificationPass());
430 FPM
.add(createEarlyCSEPass());
431 FPM
.add(createCFGSimplificationPass());
433 FPM
.doInitialization();
435 FPM
.doFinalization();
// NOTE(review): whitespace-mangled block; the return-type line of the
// signature (original line 440) and the statement taken when the value is
// already in the map (original line 449, presumably 'V = VI->second;') are
// missing from view. Code left byte-identical; comments only.
438 // Assuming we arrived at the block NewBlock from Prev instruction, store
439 // PHI's incoming values in the ResolvedValues map.
441 scanPHIsAndUpdateValueMap(Instruction
*Prev
, BasicBlock
*NewBlock
,
442 DenseMap
<Value
*, Value
*> &ResolvedValues
) {
443 auto *PrevBB
= Prev
->getParent();
444 for (PHINode
&PN
: NewBlock
->phis()) {
445 auto V
= PN
.getIncomingValueForBlock(PrevBB
);
446 // See if we already resolved it.
447 auto VI
= ResolvedValues
.find(V
);
448 if (VI
!= ResolvedValues
.end())
450 // Remember the value.
451 ResolvedValues
[&PN
] = V
;
// NOTE(review): whitespace-mangled block; the 'return true;'/'continue;'
// statements and closing braces (original lines 466, 473-474, 479, 484-490)
// are missing from view. Code left byte-identical; comments only.
//
// Starting at InitialInst, follows unconditional branches and
// constant-foldable switches (tracking PHI values via
// scanPHIsAndUpdateValueMap); if the walk reaches a ret, the ret is cloned
// over InitialInst.
455 // Replace a sequence of branches leading to a ret, with a clone of a ret
456 // instruction. Suspend instruction represented by a switch, track the PHI
457 // values and select the correct case successor when possible.
458 static bool simplifyTerminatorLeadingToRet(Instruction
*InitialInst
) {
459 DenseMap
<Value
*, Value
*> ResolvedValues
;
461 Instruction
*I
= InitialInst
;
462 while (I
->isTerminator()) {
463 if (isa
<ReturnInst
>(I
)) {
464 if (I
!= InitialInst
)
465 ReplaceInstWithInst(InitialInst
, I
->clone());
468 if (auto *BR
= dyn_cast
<BranchInst
>(I
)) {
469 if (BR
->isUnconditional()) {
470 BasicBlock
*BB
= BR
->getSuccessor(0);
471 scanPHIsAndUpdateValueMap(I
, BB
, ResolvedValues
);
472 I
= BB
->getFirstNonPHIOrDbgOrLifetime();
475 } else if (auto *SI
= dyn_cast
<SwitchInst
>(I
)) {
// Resolve the switch condition through the PHI map before trying to
// fold it to a single case successor.
476 Value
*V
= SI
->getCondition();
477 auto it
= ResolvedValues
.find(V
);
478 if (it
!= ResolvedValues
.end())
480 if (ConstantInt
*Cond
= dyn_cast
<ConstantInt
>(V
)) {
481 BasicBlock
*BB
= SI
->findCaseValue(Cond
)->getCaseSuccessor();
482 scanPHIsAndUpdateValueMap(I
, BB
, ResolvedValues
);
483 I
= BB
->getFirstNonPHIOrDbgOrLifetime();
// NOTE(review): whitespace-mangled block; 'changed = true;' and the final
// 'if (changed)' guard before removeUnreachableBlocks (original lines
// 514-517) are missing from view. Code left byte-identical; comments only.
492 // Add musttail to any resume instructions that is immediately followed by a
493 // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
494 // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
495 // This transformation is done only in the resume part of the coroutine that has
496 // identical signature and calling convention as the coro.resume call.
497 static void addMustTailToCoroResumes(Function
&F
) {
498 bool changed
= false;
500 // Collect potential resume instructions.
501 SmallVector
<CallInst
*, 4> Resumes
;
502 for (auto &I
: instructions(F
))
503 if (auto *Call
= dyn_cast
<CallInst
>(&I
))
504 if (auto *CalledValue
= Call
->getCalledValue())
505 // CoroEarly pass replaced coro resumes with indirect calls to an
506 // address return by CoroSubFnInst intrinsic. See if it is one of those.
507 if (isa
<CoroSubFnInst
>(CalledValue
->stripPointerCasts()))
508 Resumes
.push_back(Call
);
510 // Set musttail on those that are followed by a ret instruction.
511 for (CallInst
*Call
: Resumes
)
512 if (simplifyTerminatorLeadingToRet(Call
->getNextNode())) {
513 Call
->setTailCallKind(CallInst::TCK_MustTail
);
// musttail requires the call to be immediately followed by ret; folding
// the branch chain above may leave dead blocks behind, hence the cleanup.
518 removeUnreachableBlocks(F
);
// NOTE(review): whitespace-mangled block; the 'if (AllocInst) {' / '} else {'
// control flow (original lines 527, 535, 537) is missing from view, so the
// alloca-replacement path and the getMem() path below are presumably the two
// arms of that conditional — confirm against the complete source. Code left
// byte-identical; comments only.
521 // Coroutine has no suspend points. Remove heap allocation for the coroutine
522 // frame if possible.
523 static void handleNoSuspendCoroutine(CoroBeginInst
*CoroBegin
, Type
*FrameTy
) {
524 auto *CoroId
= CoroBegin
->getId();
525 auto *AllocInst
= CoroId
->getCoroAlloc();
526 coro::replaceCoroFree(CoroId
, /*Elide=*/AllocInst
!= nullptr);
// With coro.alloc present: put the frame on the stack instead of the heap.
528 IRBuilder
<> Builder(AllocInst
);
529 // FIXME: Need to handle overaligned members.
530 auto *Frame
= Builder
.CreateAlloca(FrameTy
);
531 auto *VFrame
= Builder
.CreateBitCast(Frame
, Builder
.getInt8PtrTy());
// coro.alloc answers "should we allocate?" — false now that the alloca
// replaces the heap allocation.
532 AllocInst
->replaceAllUsesWith(Builder
.getFalse());
533 AllocInst
->eraseFromParent();
534 CoroBegin
->replaceAllUsesWith(VFrame
);
536 CoroBegin
->replaceAllUsesWith(CoroBegin
->getMem());
538 CoroBegin
->eraseFromParent();
// NOTE(review): whitespace-mangled block; the loop tail that classifies
// non-intrinsic instructions as calls and the final return (original lines
// 548-554) are missing from view. Code left byte-identical; comments only.
//
// Scans instructions [From, To) within one block; intrinsics are skipped per
// the comment below.
541 // SimplifySuspendPoint needs to check that there is no calls between
542 // coro_save and coro_suspend, since any of the calls may potentially resume
543 // the coroutine and if that is the case we cannot eliminate the suspend point.
544 static bool hasCallsInBlockBetween(Instruction
*From
, Instruction
*To
) {
545 for (Instruction
*I
= From
; I
!= To
; I
= I
->getNextNode()) {
546 // Assume that no intrinsic can resume the coroutine.
547 if (isa
<IntrinsicInst
>(I
))
// NOTE(review): whitespace-mangled block; the Set insert/erase bookkeeping,
// the second loop over the accumulated set, and the returns (original lines
// 559-560, 568, 572-578, 580-583) are missing from view. Code left
// byte-identical; comments only.
//
// Returns whether any block strictly between SaveBB and ResDesBB contains a
// call, by walking predecessors backwards from ResDesBB.
556 static bool hasCallsInBlocksBetween(BasicBlock
*SaveBB
, BasicBlock
*ResDesBB
) {
557 SmallPtrSet
<BasicBlock
*, 8> Set
;
558 SmallVector
<BasicBlock
*, 8> Worklist
;
561 Worklist
.push_back(ResDesBB
);
563 // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
564 // returns a token consumed by suspend instruction, all blocks in between
565 // will have to eventually hit SaveBB when going backwards from ResDesBB.
566 while (!Worklist
.empty()) {
567 auto *BB
= Worklist
.pop_back_val();
569 for (auto *Pred
: predecessors(BB
))
570 if (Set
.count(Pred
) == 0)
571 Worklist
.push_back(Pred
);
574 // SaveBB and ResDesBB are checked separately in hasCallsBetween.
// Whole-block scan: nullptr as 'To' runs to the end of the block.
579 if (hasCallsInBlockBetween(BB
->getFirstNonPHI(), nullptr))
// NOTE(review): whitespace-mangled block; the 'return true;'/'return false;'
// lines and one argument continuation (original lines 594-595, 598-600,
// 603-606) are missing from view. Code left byte-identical; comments only.
//
// Conservative call detection between Save and ResumeOrDestroy across block
// boundaries: same-block fast path, then Save's block tail, ResumeOrDestroy's
// block head, and finally all blocks in between.
585 static bool hasCallsBetween(Instruction
*Save
, Instruction
*ResumeOrDestroy
) {
586 auto *SaveBB
= Save
->getParent();
587 auto *ResumeOrDestroyBB
= ResumeOrDestroy
->getParent();
589 if (SaveBB
== ResumeOrDestroyBB
)
590 return hasCallsInBlockBetween(Save
->getNextNode(), ResumeOrDestroy
);
592 // Any calls from Save to the end of the block?
593 if (hasCallsInBlockBetween(Save
->getNextNode(), nullptr))
596 // Any calls from beginning of the block up to ResumeOrDestroy?
597 if (hasCallsInBlockBetween(ResumeOrDestroyBB
->getFirstNonPHI(),
601 // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
602 if (hasCallsInBlocksBetween(SaveBB
, ResumeOrDestroyBB
))
// NOTE(review): whitespace-mangled block; several lines are missing from
// view, including the null checks around Prev/Pred, the 'CallSite CS(Prev);'
// that CS below comes from, several 'return false;' statements, and the final
// 'return true;'. Code left byte-identical; comments only.
608 // If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
609 // suspend point and replace it with normal control flow.
610 static bool simplifySuspendPoint(CoroSuspendInst
*Suspend
,
611 CoroBeginInst
*CoroBegin
) {
612 Instruction
*Prev
= Suspend
->getPrevNode();
// When the suspend is first in its block, look at the terminator of the
// single predecessor instead.
614 auto *Pred
= Suspend
->getParent()->getSinglePredecessor();
617 Prev
= Pred
->getTerminator();
624 auto *CallInstr
= CS
.getInstruction();
626 auto *Callee
= CS
.getCalledValue()->stripPointerCasts();
628 // See if the callsite is for resumption or destruction of the coroutine.
629 auto *SubFn
= dyn_cast
<CoroSubFnInst
>(Callee
);
633 // Does not refer to the current coroutine, we cannot do anything with it.
634 if (SubFn
->getFrame() != CoroBegin
)
637 // See if the transformation is safe. Specifically, see if there are any
638 // calls in between Save and CallInstr. They can potentially resume the
639 // coroutine rendering this optimization unsafe.
640 auto *Save
= Suspend
->getCoroSave();
641 if (hasCallsBetween(Save
, CallInstr
))
644 // Replace llvm.coro.suspend with the value that results in resumption over
645 // the resume or cleanup path.
646 Suspend
->replaceAllUsesWith(SubFn
->getRawIndex());
647 Suspend
->eraseFromParent();
648 Save
->eraseFromParent();
650 // No longer need a call to coro.resume or coro.destroy.
651 if (auto *Invoke
= dyn_cast
<InvokeInst
>(CallInstr
)) {
// An invoke must be replaced by a branch to its normal destination
// before it can be erased.
652 BranchInst::Create(Invoke
->getNormalDest(), Invoke
);
655 // Grab the CalledValue from CS before erasing the CallInstr.
656 auto *CalledValue
= CS
.getCalledValue();
657 CallInstr
->eraseFromParent();
659 // If no more users remove it. Usually it is a bitcast of SubFn.
660 if (CalledValue
!= SubFn
&& CalledValue
->user_empty())
661 if (auto *I
= dyn_cast
<Instruction
>(CalledValue
))
662 I
->eraseFromParent();
664 // Now we are good to remove SubFn.
665 if (SubFn
->user_empty())
666 SubFn
->eraseFromParent();
// NOTE(review): whitespace-mangled block; the surrounding loop, the early
// return for N == 0, the decrement of N, and the final 'S.resize(N);'
// (original lines 675-677, 679-689) are missing from view. The visible
// swap-with-last pattern suggests simplified suspend points are compacted out
// of Shape.CoroSuspends — confirm against the complete source. Code left
// byte-identical; comments only.
671 // Remove suspend points that are simplified.
672 static void simplifySuspendPoints(coro::Shape
&Shape
) {
673 auto &S
= Shape
.CoroSuspends
;
674 size_t I
= 0, N
= S
.size();
678 if (simplifySuspendPoint(S
[I
], Shape
.CoroBegin
)) {
681 std::swap(S
[I
], S
[N
]);
// NOTE(review): whitespace-mangled block; the 'do {' opener, the
// 'Work.push_back(BB);' inside the if, and the final 'return RelocBlocks;'
// (original lines 695-696, 701-702, 704-705) are missing from view. Code left
// byte-identical; comments only.
//
// Collects, via a backwards worklist walk from coro.begin's block, every
// block that can reach coro.begin.
690 static SmallPtrSet
<BasicBlock
*, 4> getCoroBeginPredBlocks(CoroBeginInst
*CB
) {
691 // Collect all blocks that we need to look for instructions to relocate.
692 SmallPtrSet
<BasicBlock
*, 4> RelocBlocks
;
693 SmallVector
<BasicBlock
*, 4> Work
;
694 Work
.push_back(CB
->getParent());
697 BasicBlock
*Current
= Work
.pop_back_val();
698 for (BasicBlock
*BB
: predecessors(Current
))
699 if (RelocBlocks
.count(BB
) == 0) {
700 RelocBlocks
.insert(BB
);
703 } while (!Work
.empty());
// NOTE(review): whitespace-mangled block; missing from view are the 'do {'
// opener, the null/containment checks on operand I, and the
// 'Work.push_back(...)' lines that feed newly discovered instructions back
// into the worklist (original lines 713, 722-723, 729-731, 736, 741, 743-746,
// 748, 750-751). Code left byte-identical; comments only.
//
// Computes the transitive-operand closure of coro.begin (plus preceding-block
// terminators): everything the frame allocation depends on must stay put and
// is collected into DoNotRelocate.
707 static SmallPtrSet
<Instruction
*, 8>
708 getNotRelocatableInstructions(CoroBeginInst
*CoroBegin
,
709 SmallPtrSetImpl
<BasicBlock
*> &RelocBlocks
) {
710 SmallPtrSet
<Instruction
*, 8> DoNotRelocate
;
711 // Collect all instructions that we should not relocate
712 SmallVector
<Instruction
*, 8> Work
;
714 // Start with CoroBegin and terminators of all preceding blocks.
715 Work
.push_back(CoroBegin
);
716 BasicBlock
*CoroBeginBB
= CoroBegin
->getParent();
717 for (BasicBlock
*BB
: RelocBlocks
)
718 if (BB
!= CoroBeginBB
)
719 Work
.push_back(BB
->getTerminator());
721 // For every instruction in the Work list, place its operands in DoNotRelocate
724 Instruction
*Current
= Work
.pop_back_val();
725 LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current
<< "\n");
726 DoNotRelocate
.insert(Current
);
727 for (Value
*U
: Current
->operands()) {
728 auto *I
= dyn_cast
<Instruction
>(U
);
732 if (auto *A
= dyn_cast
<AllocaInst
>(I
)) {
733 // Stores to alloca instructions that occur before the coroutine frame
734 // is allocated should not be moved; the stored values may be used by
735 // the coroutine frame allocator. The operands to those stores must also
737 for (const auto &User
: A
->users())
738 if (auto *SI
= dyn_cast
<llvm::StoreInst
>(User
))
739 if (RelocBlocks
.count(SI
->getParent()) != 0 &&
740 DoNotRelocate
.count(SI
) == 0) {
742 DoNotRelocate
.insert(SI
);
747 if (DoNotRelocate
.count(I
) == 0) {
749 DoNotRelocate
.insert(I
);
752 } while (!Work
.empty());
753 return DoNotRelocate
;
// NOTE(review): whitespace-mangled block; the 'continue;' statements after
// the alloca / DoNotRelocateSet checks and the loop/function closers
// (original lines 770-772, 774, 776-777) are missing from view. Code left
// byte-identical; comments only.
//
// Moves entry-block instructions that the frame allocation does NOT depend on
// to just after coro.begin, so their results land in frame memory (which only
// exists after coro.begin). Allocas and not-relocatable instructions are
// skipped.
756 static void relocateInstructionBefore(CoroBeginInst
*CoroBegin
, Function
&F
) {
757 // Analyze which non-alloca instructions are needed for allocation and
758 // relocate the rest to after coro.begin. We need to do it, since some of the
759 // targets of those instructions may be placed into coroutine frame memory
760 // for which becomes available after coro.begin intrinsic.
762 auto BlockSet
= getCoroBeginPredBlocks(CoroBegin
);
763 auto DoNotRelocateSet
= getNotRelocatableInstructions(CoroBegin
, BlockSet
);
765 Instruction
*InsertPt
= CoroBegin
->getNextNode();
766 BasicBlock
&BB
= F
.getEntryBlock(); // TODO: Look at other blocks as well.
// Iterator advanced before the move so relocation doesn't invalidate it.
767 for (auto B
= BB
.begin(), E
= BB
.end(); B
!= E
;) {
768 Instruction
&I
= *B
++;
769 if (isa
<AllocaInst
>(&I
))
773 if (DoNotRelocateSet
.count(&I
))
775 I
.moveBefore(InsertPt
);
// NOTE(review): whitespace-mangled block; early 'return' lines (after the
// missing-CoroBegin check and after the no-suspend path) and closing braces
// are missing from view. Code left byte-identical; comments only.
//
// Driver for one coroutine: analyze shape, simplify/relocate, build the
// frame, then outline .resume/.destroy/.cleanup clones and wire them into the
// frame, the coro.begin info, and the call graph.
779 static void splitCoroutine(Function
&F
, CallGraph
&CG
, CallGraphSCC
&SCC
) {
780 EliminateUnreachableBlocks(F
);
782 coro::Shape
Shape(F
);
783 if (!Shape
.CoroBegin
)
786 simplifySuspendPoints(Shape
);
787 relocateInstructionBefore(Shape
.CoroBegin
, F
);
788 buildCoroutineFrame(F
, Shape
);
789 replaceFrameSize(Shape
);
791 // If there are no suspend points, no split required, just remove
792 // the allocation and deallocation blocks, they are not needed.
793 if (Shape
.CoroSuspends
.empty()) {
794 handleNoSuspendCoroutine(Shape
.CoroBegin
, Shape
.FrameTy
);
795 removeCoroEnds(Shape
);
797 coro::updateCallGraph(F
, {}, CG
, SCC
);
// Index argument: 0 = resume, 1 = destroy, 2 = cleanup (cleanup elides the
// frame deallocation; see createClone's coro.replaceCoroFree call).
801 auto *ResumeEntry
= createResumeEntryBlock(F
, Shape
);
802 auto ResumeClone
= createClone(F
, ".resume", Shape
, ResumeEntry
, 0);
803 auto DestroyClone
= createClone(F
, ".destroy", Shape
, ResumeEntry
, 1);
804 auto CleanupClone
= createClone(F
, ".cleanup", Shape
, ResumeEntry
, 2);
806 // We no longer need coro.end in F.
807 removeCoroEnds(Shape
);
810 postSplitCleanup(*ResumeClone
);
811 postSplitCleanup(*DestroyClone
);
812 postSplitCleanup(*CleanupClone
);
814 addMustTailToCoroResumes(*ResumeClone
);
816 // Store addresses resume/destroy/cleanup functions in the coroutine frame.
817 updateCoroFrame(Shape
, ResumeClone
, DestroyClone
, CleanupClone
);
819 // Create a constant array referring to resume/destroy/clone functions pointed
820 // by the last argument of @llvm.coro.info, so that CoroElide pass can
821 // determine the correct function to call.
822 setCoroInfo(F
, Shape
.CoroBegin
, {ResumeClone
, DestroyClone
, CleanupClone
});
824 // Update call graph and add the functions we created to the SCC.
825 coro::updateCallGraph(F
, {ResumeClone
, DestroyClone
, CleanupClone
}, CG
, SCC
);
// NOTE(review): whitespace-mangled block; the declaration assigned by
// Lowerer.makeSubFnCall(...) (original line 849, presumably
// 'auto *DevirtFnAddr =' given its use in the CallInst below) is missing from
// view. Code left byte-identical; comments only.
828 // When we see the coroutine the first time, we insert an indirect call to a
829 // devirt trigger function and mark the coroutine that it is now ready for
831 static void prepareForSplit(Function
&F
, CallGraph
&CG
) {
832 Module
&M
= *F
.getParent();
833 LLVMContext
&Context
= F
.getContext();
835 Function
*DevirtFn
= M
.getFunction(CORO_DEVIRT_TRIGGER_FN
);
836 assert(DevirtFn
&& "coro.devirt.trigger function not found");
// Mark the coroutine so the next CGSCC iteration actually splits it.
839 F
.addFnAttr(CORO_PRESPLIT_ATTR
, PREPARED_FOR_SPLIT
);
841 // Insert an indirect call sequence that will be devirtualized by CoroElide
843 // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
844 // %1 = bitcast i8* %0 to void(i8*)*
845 // call void %1(i8* null)
846 coro::LowererBase
Lowerer(M
);
847 Instruction
*InsertPt
= F
.getEntryBlock().getTerminator();
848 auto *Null
= ConstantPointerNull::get(Type::getInt8PtrTy(Context
));
850 Lowerer
.makeSubFnCall(Null
, CoroSubFnInst::RestartTrigger
, InsertPt
);
851 FunctionType
*FnTy
= FunctionType::get(Type::getVoidTy(Context
),
852 {Type::getInt8PtrTy(Context
)}, false);
853 auto *IndirectCall
= CallInst::Create(FnTy
, DevirtFnAddr
, Null
, "", InsertPt
);
855 // Update CG graph with an indirect call we just added.
856 CG
[&F
]->addCalledFunction(IndirectCall
, CG
.getCallsExternalNode());
// NOTE(review): whitespace-mangled block; the early 'return;' after the
// existence check and the 'Function *DevirtFn =' declaration that
// Function::Create on line 871 initializes (original lines 865, 870) are
// missing from view. Code left byte-identical; comments only.
859 // Make sure that there is a devirtualization trigger function that CoroSplit
860 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
861 // found, we will create one and add it to the current SCC.
862 static void createDevirtTriggerFunc(CallGraph
&CG
, CallGraphSCC
&SCC
) {
863 Module
&M
= CG
.getModule();
864 if (M
.getFunction(CORO_DEVIRT_TRIGGER_FN
))
867 LLVMContext
&C
= M
.getContext();
868 auto *FnTy
= FunctionType::get(Type::getVoidTy(C
), Type::getInt8PtrTy(C
),
869 /*IsVarArgs=*/false);
871 Function::Create(FnTy
, GlobalValue::LinkageTypes::PrivateLinkage
,
872 CORO_DEVIRT_TRIGGER_FN
, &M
);
// The trigger body is an empty function (entry + ret) marked always_inline.
873 DevirtFn
->addFnAttr(Attribute::AlwaysInline
);
874 auto *Entry
= BasicBlock::Create(C
, "entry", DevirtFn
);
875 ReturnInst::Create(C
, Entry
);
// Append the new node to the current SCC so the pipeline revisits it.
877 auto *Node
= CG
.getOrInsertFunction(DevirtFn
);
879 SmallVector
<CallGraphNode
*, 8> Nodes(SCC
.begin(), SCC
.end());
880 Nodes
.push_back(Node
);
881 SCC
.initialize(Nodes
);
// NOTE(review): whitespace-mangled block; missing from view are the anonymous
// namespace opener, the section title between the separators, the 'bool Run'
// member declaration implied by doInitialization, the early return in
// runOnSCC when !Run or Coroutines is empty, and several closing braces.
// Code left byte-identical; comments only.
884 //===----------------------------------------------------------------------===//
886 //===----------------------------------------------------------------------===//
// Legacy-PM CGSCC pass: each SCC pass, coroutines are either prepared
// (first sighting) or split (second sighting), driven by CORO_PRESPLIT_ATTR.
890 struct CoroSplit
: public CallGraphSCCPass
{
891 static char ID
; // Pass identification, replacement for typeid
893 CoroSplit() : CallGraphSCCPass(ID
) {
894 initializeCoroSplitPass(*PassRegistry::getPassRegistry());
899 // A coroutine is identified by the presence of coro.begin intrinsic, if
900 // we don't have any, this pass has nothing to do.
901 bool doInitialization(CallGraph
&CG
) override
{
902 Run
= coro::declaresIntrinsics(CG
.getModule(), {"llvm.coro.begin"});
903 return CallGraphSCCPass::doInitialization(CG
);
906 bool runOnSCC(CallGraphSCC
&SCC
) override
{
910 // Find coroutines for processing.
911 SmallVector
<Function
*, 4> Coroutines
;
912 for (CallGraphNode
*CGN
: SCC
)
913 if (auto *F
= CGN
->getFunction())
914 if (F
->hasFnAttribute(CORO_PRESPLIT_ATTR
))
915 Coroutines
.push_back(F
);
917 if (Coroutines
.empty())
920 CallGraph
&CG
= getAnalysis
<CallGraphWrapperPass
>().getCallGraph();
921 createDevirtTriggerFunc(CG
, SCC
);
923 for (Function
*F
: Coroutines
) {
// The attribute value encodes the per-coroutine state machine:
// unprepared -> prepare; otherwise strip the attribute and split.
924 Attribute Attr
= F
->getFnAttribute(CORO_PRESPLIT_ATTR
);
925 StringRef Value
= Attr
.getValueAsString();
926 LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F
->getName()
927 << "' state: " << Value
<< "\n");
928 if (Value
== UNPREPARED_FOR_SPLIT
) {
929 prepareForSplit(*F
, CG
);
932 F
->removeFnAttr(CORO_PRESPLIT_ATTR
);
933 splitCoroutine(*F
, CG
, SCC
);
938 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
939 CallGraphSCCPass::getAnalysisUsage(AU
);
942 StringRef
getPassName() const override
{ return "Coroutine Splitting"; }
945 } // end anonymous namespace
// NOTE(review): whitespace-mangled tail; the INITIALIZE_PASS macro invocation
// wrapping the "coro-split" name/description fragments (original lines
// 949-952) is partially missing from view. Code left byte-identical;
// comments only.
947 char CoroSplit::ID
= 0;
950 CoroSplit
, "coro-split",
951 "Split coroutine into a set of functions driving its state machine", false,
// Public factory entry point declared in llvm/Transforms/Scalar.h-adjacent
// coroutine headers.
954 Pass
*llvm::createCoroSplitPass() { return new CoroSplit(); }