//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split up the coroutine into parts
// corresponding to the initial, resume and destroy invocations of the coroutine,
// add them to the current SCC and restart the IPO pipeline to optimize the
// coroutine subfunctions we extracted before proceeding to the caller of the
// coroutine.
//===----------------------------------------------------------------------===//
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>

using namespace llvm;
// Tag used by LLVM_DEBUG output and -debug-only filtering in this file.
#define DEBUG_TYPE "coro-split"
74 // Create an entry block for a resume function with a switch that will jump to
76 static BasicBlock
*createResumeEntryBlock(Function
&F
, coro::Shape
&Shape
) {
77 LLVMContext
&C
= F
.getContext();
80 // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
82 // % index = load i32, i32* %index.addr
83 // switch i32 %index, label %unreachable [
84 // i32 0, label %resume.0
85 // i32 1, label %resume.1
89 auto *NewEntry
= BasicBlock::Create(C
, "resume.entry", &F
);
90 auto *UnreachBB
= BasicBlock::Create(C
, "unreachable", &F
);
92 IRBuilder
<> Builder(NewEntry
);
93 auto *FramePtr
= Shape
.FramePtr
;
94 auto *FrameTy
= Shape
.FrameTy
;
95 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(
96 FrameTy
, FramePtr
, 0, coro::Shape::IndexField
, "index.addr");
97 auto *Index
= Builder
.CreateLoad(GepIndex
, "index");
99 Builder
.CreateSwitch(Index
, UnreachBB
, Shape
.CoroSuspends
.size());
100 Shape
.ResumeSwitch
= Switch
;
102 size_t SuspendIndex
= 0;
103 for (CoroSuspendInst
*S
: Shape
.CoroSuspends
) {
104 ConstantInt
*IndexVal
= Shape
.getIndex(SuspendIndex
);
106 // Replace CoroSave with a store to Index:
107 // %index.addr = getelementptr %f.frame... (index field number)
108 // store i32 0, i32* %index.addr1
109 auto *Save
= S
->getCoroSave();
110 Builder
.SetInsertPoint(Save
);
112 // Final suspend point is represented by storing zero in ResumeFnAddr.
113 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(FrameTy
, FramePtr
, 0,
115 auto *NullPtr
= ConstantPointerNull::get(cast
<PointerType
>(
116 cast
<PointerType
>(GepIndex
->getType())->getElementType()));
117 Builder
.CreateStore(NullPtr
, GepIndex
);
119 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(
120 FrameTy
, FramePtr
, 0, coro::Shape::IndexField
, "index.addr");
121 Builder
.CreateStore(IndexVal
, GepIndex
);
123 Save
->replaceAllUsesWith(ConstantTokenNone::get(C
));
124 Save
->eraseFromParent();
126 // Split block before and after coro.suspend and add a jump from an entry
131 // %0 = call i8 @llvm.coro.suspend(token none, i1 false)
132 // switch i8 %0, label %suspend[i8 0, label %resume
133 // i8 1, label %cleanup]
138 // br label %resume.0.landing
140 // resume.0: ; <--- jump from the switch in the resume.entry
141 // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
142 // br label %resume.0.landing
145 // %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
146 // switch i8 % 1, label %suspend [i8 0, label %resume
147 // i8 1, label %cleanup]
149 auto *SuspendBB
= S
->getParent();
151 SuspendBB
->splitBasicBlock(S
, "resume." + Twine(SuspendIndex
));
152 auto *LandingBB
= ResumeBB
->splitBasicBlock(
153 S
->getNextNode(), ResumeBB
->getName() + Twine(".landing"));
154 Switch
->addCase(IndexVal
, ResumeBB
);
156 cast
<BranchInst
>(SuspendBB
->getTerminator())->setSuccessor(0, LandingBB
);
157 auto *PN
= PHINode::Create(Builder
.getInt8Ty(), 2, "", &LandingBB
->front());
158 S
->replaceAllUsesWith(PN
);
159 PN
->addIncoming(Builder
.getInt8(-1), SuspendBB
);
160 PN
->addIncoming(S
, ResumeBB
);
165 Builder
.SetInsertPoint(UnreachBB
);
166 Builder
.CreateUnreachable();
171 // In Resumers, we replace fallthrough coro.end with ret void and delete the
172 // rest of the block.
173 static void replaceFallthroughCoroEnd(IntrinsicInst
*End
,
174 ValueToValueMapTy
&VMap
) {
175 auto *NewE
= cast
<IntrinsicInst
>(VMap
[End
]);
176 ReturnInst::Create(NewE
->getContext(), nullptr, NewE
);
178 // Remove the rest of the block, by splitting it into an unreachable block.
179 auto *BB
= NewE
->getParent();
180 BB
->splitBasicBlock(NewE
);
181 BB
->getTerminator()->eraseFromParent();
184 // In Resumers, we replace unwind coro.end with True to force the immediate
186 static void replaceUnwindCoroEnds(coro::Shape
&Shape
, ValueToValueMapTy
&VMap
) {
187 if (Shape
.CoroEnds
.empty())
190 LLVMContext
&Context
= Shape
.CoroEnds
.front()->getContext();
191 auto *True
= ConstantInt::getTrue(Context
);
192 for (CoroEndInst
*CE
: Shape
.CoroEnds
) {
196 auto *NewCE
= cast
<IntrinsicInst
>(VMap
[CE
]);
198 // If coro.end has an associated bundle, add cleanupret instruction.
199 if (auto Bundle
= NewCE
->getOperandBundle(LLVMContext::OB_funclet
)) {
200 Value
*FromPad
= Bundle
->Inputs
[0];
201 auto *CleanupRet
= CleanupReturnInst::Create(FromPad
, nullptr, NewCE
);
202 NewCE
->getParent()->splitBasicBlock(NewCE
);
203 CleanupRet
->getParent()->getTerminator()->eraseFromParent();
206 NewCE
->replaceAllUsesWith(True
);
207 NewCE
->eraseFromParent();
211 // Rewrite final suspend point handling. We do not use suspend index to
212 // represent the final suspend point. Instead we zero-out ResumeFnAddr in the
213 // coroutine frame, since it is undefined behavior to resume a coroutine
214 // suspended at the final suspend point. Thus, in the resume function, we can
215 // simply remove the last case (when coro::Shape is built, the final suspend
216 // point (if present) is always the last element of CoroSuspends array).
217 // In the destroy function, we add a code sequence to check if ResumeFnAddress
218 // is Null, and if so, jump to the appropriate label to handle cleanup from the
219 // final suspend point.
220 static void handleFinalSuspend(IRBuilder
<> &Builder
, Value
*FramePtr
,
221 coro::Shape
&Shape
, SwitchInst
*Switch
,
223 assert(Shape
.HasFinalSuspend
);
224 auto FinalCaseIt
= std::prev(Switch
->case_end());
225 BasicBlock
*ResumeBB
= FinalCaseIt
->getCaseSuccessor();
226 Switch
->removeCase(FinalCaseIt
);
228 BasicBlock
*OldSwitchBB
= Switch
->getParent();
229 auto *NewSwitchBB
= OldSwitchBB
->splitBasicBlock(Switch
, "Switch");
230 Builder
.SetInsertPoint(OldSwitchBB
->getTerminator());
231 auto *GepIndex
= Builder
.CreateConstInBoundsGEP2_32(Shape
.FrameTy
, FramePtr
,
232 0, 0, "ResumeFn.addr");
233 auto *Load
= Builder
.CreateLoad(GepIndex
);
235 ConstantPointerNull::get(cast
<PointerType
>(Load
->getType()));
236 auto *Cond
= Builder
.CreateICmpEQ(Load
, NullPtr
);
237 Builder
.CreateCondBr(Cond
, ResumeBB
, NewSwitchBB
);
238 OldSwitchBB
->getTerminator()->eraseFromParent();
242 // Create a resume clone by cloning the body of the original function, setting
243 // new entry block and replacing coro.suspend an appropriate value to force
244 // resume or cleanup pass for every suspend point.
245 static Function
*createClone(Function
&F
, Twine Suffix
, coro::Shape
&Shape
,
246 BasicBlock
*ResumeEntry
, int8_t FnIndex
) {
247 Module
*M
= F
.getParent();
248 auto *FrameTy
= Shape
.FrameTy
;
249 auto *FnPtrTy
= cast
<PointerType
>(FrameTy
->getElementType(0));
250 auto *FnTy
= cast
<FunctionType
>(FnPtrTy
->getElementType());
253 Function::Create(FnTy
, GlobalValue::LinkageTypes::ExternalLinkage
,
254 F
.getName() + Suffix
, M
);
255 NewF
->addParamAttr(0, Attribute::NonNull
);
256 NewF
->addParamAttr(0, Attribute::NoAlias
);
258 ValueToValueMapTy VMap
;
259 // Replace all args with undefs. The buildCoroutineFrame algorithm already
260 // rewritten access to the args that occurs after suspend points with loads
261 // and stores to/from the coroutine frame.
262 for (Argument
&A
: F
.args())
263 VMap
[&A
] = UndefValue::get(A
.getType());
265 SmallVector
<ReturnInst
*, 4> Returns
;
267 CloneFunctionInto(NewF
, &F
, VMap
, /*ModuleLevelChanges=*/true, Returns
);
268 NewF
->setLinkage(GlobalValue::LinkageTypes::InternalLinkage
);
270 // Remove old returns.
271 for (ReturnInst
*Return
: Returns
)
272 changeToUnreachable(Return
, /*UseLLVMTrap=*/false);
274 // Remove old return attributes.
275 NewF
->removeAttributes(
276 AttributeList::ReturnIndex
,
277 AttributeFuncs::typeIncompatible(NewF
->getReturnType()));
279 // Make AllocaSpillBlock the new entry block.
280 auto *SwitchBB
= cast
<BasicBlock
>(VMap
[ResumeEntry
]);
281 auto *Entry
= cast
<BasicBlock
>(VMap
[Shape
.AllocaSpillBlock
]);
282 Entry
->moveBefore(&NewF
->getEntryBlock());
283 Entry
->getTerminator()->eraseFromParent();
284 BranchInst::Create(SwitchBB
, Entry
);
285 Entry
->setName("entry" + Suffix
);
287 // Clear all predecessors of the new entry block.
288 auto *Switch
= cast
<SwitchInst
>(VMap
[Shape
.ResumeSwitch
]);
289 Entry
->replaceAllUsesWith(Switch
->getDefaultDest());
291 IRBuilder
<> Builder(&NewF
->getEntryBlock().front());
293 // Remap frame pointer.
294 Argument
*NewFramePtr
= &*NewF
->arg_begin();
295 Value
*OldFramePtr
= cast
<Value
>(VMap
[Shape
.FramePtr
]);
296 NewFramePtr
->takeName(OldFramePtr
);
297 OldFramePtr
->replaceAllUsesWith(NewFramePtr
);
299 // Remap vFrame pointer.
300 auto *NewVFrame
= Builder
.CreateBitCast(
301 NewFramePtr
, Type::getInt8PtrTy(Builder
.getContext()), "vFrame");
302 Value
*OldVFrame
= cast
<Value
>(VMap
[Shape
.CoroBegin
]);
303 OldVFrame
->replaceAllUsesWith(NewVFrame
);
305 // Rewrite final suspend handling as it is not done via switch (allows to
306 // remove final case from the switch, since it is undefined behavior to resume
307 // the coroutine suspended at the final suspend point.
308 if (Shape
.HasFinalSuspend
) {
309 auto *Switch
= cast
<SwitchInst
>(VMap
[Shape
.ResumeSwitch
]);
310 bool IsDestroy
= FnIndex
!= 0;
311 handleFinalSuspend(Builder
, NewFramePtr
, Shape
, Switch
, IsDestroy
);
314 // Replace coro suspend with the appropriate resume index.
315 // Replacing coro.suspend with (0) will result in control flow proceeding to
316 // a resume label associated with a suspend point, replacing it with (1) will
317 // result in control flow proceeding to a cleanup label associated with this
319 auto *NewValue
= Builder
.getInt8(FnIndex
? 1 : 0);
320 for (CoroSuspendInst
*CS
: Shape
.CoroSuspends
) {
321 auto *MappedCS
= cast
<CoroSuspendInst
>(VMap
[CS
]);
322 MappedCS
->replaceAllUsesWith(NewValue
);
323 MappedCS
->eraseFromParent();
326 // Remove coro.end intrinsics.
327 replaceFallthroughCoroEnd(Shape
.CoroEnds
.front(), VMap
);
328 replaceUnwindCoroEnds(Shape
, VMap
);
329 // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
330 // to suppress deallocation code.
331 coro::replaceCoroFree(cast
<CoroIdInst
>(VMap
[Shape
.CoroBegin
->getId()]),
332 /*Elide=*/FnIndex
== 2);
334 NewF
->setCallingConv(CallingConv::Fast
);
339 static void removeCoroEnds(coro::Shape
&Shape
) {
340 if (Shape
.CoroEnds
.empty())
343 LLVMContext
&Context
= Shape
.CoroEnds
.front()->getContext();
344 auto *False
= ConstantInt::getFalse(Context
);
346 for (CoroEndInst
*CE
: Shape
.CoroEnds
) {
347 CE
->replaceAllUsesWith(False
);
348 CE
->eraseFromParent();
352 static void replaceFrameSize(coro::Shape
&Shape
) {
353 if (Shape
.CoroSizes
.empty())
356 // In the same function all coro.sizes should have the same result type.
357 auto *SizeIntrin
= Shape
.CoroSizes
.back();
358 Module
*M
= SizeIntrin
->getModule();
359 const DataLayout
&DL
= M
->getDataLayout();
360 auto Size
= DL
.getTypeAllocSize(Shape
.FrameTy
);
361 auto *SizeConstant
= ConstantInt::get(SizeIntrin
->getType(), Size
);
363 for (CoroSizeInst
*CS
: Shape
.CoroSizes
) {
364 CS
->replaceAllUsesWith(SizeConstant
);
365 CS
->eraseFromParent();
369 // Create a global constant array containing pointers to functions provided and
370 // set Info parameter of CoroBegin to point at this constant. Example:
372 // @f.resumers = internal constant [2 x void(%f.frame*)*]
373 // [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
374 // define void @f() {
376 // call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
377 // i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
379 // Assumes that all the functions have the same signature.
380 static void setCoroInfo(Function
&F
, CoroBeginInst
*CoroBegin
,
381 std::initializer_list
<Function
*> Fns
) {
382 SmallVector
<Constant
*, 4> Args(Fns
.begin(), Fns
.end());
383 assert(!Args
.empty());
384 Function
*Part
= *Fns
.begin();
385 Module
*M
= Part
->getParent();
386 auto *ArrTy
= ArrayType::get(Part
->getType(), Args
.size());
388 auto *ConstVal
= ConstantArray::get(ArrTy
, Args
);
389 auto *GV
= new GlobalVariable(*M
, ConstVal
->getType(), /*isConstant=*/true,
390 GlobalVariable::PrivateLinkage
, ConstVal
,
391 F
.getName() + Twine(".resumers"));
393 // Update coro.begin instruction to refer to this constant.
394 LLVMContext
&C
= F
.getContext();
395 auto *BC
= ConstantExpr::getPointerCast(GV
, Type::getInt8PtrTy(C
));
396 CoroBegin
->getId()->setInfo(BC
);
399 // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
400 static void updateCoroFrame(coro::Shape
&Shape
, Function
*ResumeFn
,
401 Function
*DestroyFn
, Function
*CleanupFn
) {
402 IRBuilder
<> Builder(Shape
.FramePtr
->getNextNode());
403 auto *ResumeAddr
= Builder
.CreateConstInBoundsGEP2_32(
404 Shape
.FrameTy
, Shape
.FramePtr
, 0, coro::Shape::ResumeField
,
406 Builder
.CreateStore(ResumeFn
, ResumeAddr
);
408 Value
*DestroyOrCleanupFn
= DestroyFn
;
410 CoroIdInst
*CoroId
= Shape
.CoroBegin
->getId();
411 if (CoroAllocInst
*CA
= CoroId
->getCoroAlloc()) {
412 // If there is a CoroAlloc and it returns false (meaning we elide the
413 // allocation, use CleanupFn instead of DestroyFn).
414 DestroyOrCleanupFn
= Builder
.CreateSelect(CA
, DestroyFn
, CleanupFn
);
417 auto *DestroyAddr
= Builder
.CreateConstInBoundsGEP2_32(
418 Shape
.FrameTy
, Shape
.FramePtr
, 0, coro::Shape::DestroyField
,
420 Builder
.CreateStore(DestroyOrCleanupFn
, DestroyAddr
);
423 static void postSplitCleanup(Function
&F
) {
424 removeUnreachableBlocks(F
);
425 legacy::FunctionPassManager
FPM(F
.getParent());
427 FPM
.add(createVerifierPass());
428 FPM
.add(createSCCPPass());
429 FPM
.add(createCFGSimplificationPass());
430 FPM
.add(createEarlyCSEPass());
431 FPM
.add(createCFGSimplificationPass());
433 FPM
.doInitialization();
435 FPM
.doFinalization();
438 // Assuming we arrived at the block NewBlock from Prev instruction, store
439 // PHI's incoming values in the ResolvedValues map.
441 scanPHIsAndUpdateValueMap(Instruction
*Prev
, BasicBlock
*NewBlock
,
442 DenseMap
<Value
*, Value
*> &ResolvedValues
) {
443 auto *PrevBB
= Prev
->getParent();
444 for (PHINode
&PN
: NewBlock
->phis()) {
445 auto V
= PN
.getIncomingValueForBlock(PrevBB
);
446 // See if we already resolved it.
447 auto VI
= ResolvedValues
.find(V
);
448 if (VI
!= ResolvedValues
.end())
450 // Remember the value.
451 ResolvedValues
[&PN
] = V
;
455 // Replace a sequence of branches leading to a ret, with a clone of a ret
456 // instruction. Suspend instruction represented by a switch, track the PHI
457 // values and select the correct case successor when possible.
458 static bool simplifyTerminatorLeadingToRet(Instruction
*InitialInst
) {
459 DenseMap
<Value
*, Value
*> ResolvedValues
;
461 Instruction
*I
= InitialInst
;
462 while (I
->isTerminator()) {
463 if (isa
<ReturnInst
>(I
)) {
464 if (I
!= InitialInst
)
465 ReplaceInstWithInst(InitialInst
, I
->clone());
468 if (auto *BR
= dyn_cast
<BranchInst
>(I
)) {
469 if (BR
->isUnconditional()) {
470 BasicBlock
*BB
= BR
->getSuccessor(0);
471 scanPHIsAndUpdateValueMap(I
, BB
, ResolvedValues
);
472 I
= BB
->getFirstNonPHIOrDbgOrLifetime();
475 } else if (auto *SI
= dyn_cast
<SwitchInst
>(I
)) {
476 Value
*V
= SI
->getCondition();
477 auto it
= ResolvedValues
.find(V
);
478 if (it
!= ResolvedValues
.end())
480 if (ConstantInt
*Cond
= dyn_cast
<ConstantInt
>(V
)) {
481 BasicBlock
*BB
= SI
->findCaseValue(Cond
)->getCaseSuccessor();
482 scanPHIsAndUpdateValueMap(I
, BB
, ResolvedValues
);
483 I
= BB
->getFirstNonPHIOrDbgOrLifetime();
492 // Add musttail to any resume instructions that is immediately followed by a
493 // suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call
494 // for symmetrical coroutine control transfer (C++ Coroutines TS extension).
495 // This transformation is done only in the resume part of the coroutine that has
496 // identical signature and calling convention as the coro.resume call.
497 static void addMustTailToCoroResumes(Function
&F
) {
498 bool changed
= false;
500 // Collect potential resume instructions.
501 SmallVector
<CallInst
*, 4> Resumes
;
502 for (auto &I
: instructions(F
))
503 if (auto *Call
= dyn_cast
<CallInst
>(&I
))
504 if (auto *CalledValue
= Call
->getCalledValue())
505 // CoroEarly pass replaced coro resumes with indirect calls to an
506 // address return by CoroSubFnInst intrinsic. See if it is one of those.
507 if (isa
<CoroSubFnInst
>(CalledValue
->stripPointerCasts()))
508 Resumes
.push_back(Call
);
510 // Set musttail on those that are followed by a ret instruction.
511 for (CallInst
*Call
: Resumes
)
512 if (simplifyTerminatorLeadingToRet(Call
->getNextNode())) {
513 Call
->setTailCallKind(CallInst::TCK_MustTail
);
518 removeUnreachableBlocks(F
);
521 // Coroutine has no suspend points. Remove heap allocation for the coroutine
522 // frame if possible.
523 static void handleNoSuspendCoroutine(CoroBeginInst
*CoroBegin
, Type
*FrameTy
) {
524 auto *CoroId
= CoroBegin
->getId();
525 auto *AllocInst
= CoroId
->getCoroAlloc();
526 coro::replaceCoroFree(CoroId
, /*Elide=*/AllocInst
!= nullptr);
528 IRBuilder
<> Builder(AllocInst
);
529 // FIXME: Need to handle overaligned members.
530 auto *Frame
= Builder
.CreateAlloca(FrameTy
);
531 auto *VFrame
= Builder
.CreateBitCast(Frame
, Builder
.getInt8PtrTy());
532 AllocInst
->replaceAllUsesWith(Builder
.getFalse());
533 AllocInst
->eraseFromParent();
534 CoroBegin
->replaceAllUsesWith(VFrame
);
536 CoroBegin
->replaceAllUsesWith(CoroBegin
->getMem());
538 CoroBegin
->eraseFromParent();
541 // look for a very simple pattern
544 // resume or destroy call
547 // If there are other calls between coro.save and coro.suspend, they can
548 // potentially resume or destroy the coroutine, so it is unsafe to eliminate a
550 static bool simplifySuspendPoint(CoroSuspendInst
*Suspend
,
551 CoroBeginInst
*CoroBegin
) {
552 auto *Save
= Suspend
->getCoroSave();
553 auto *BB
= Suspend
->getParent();
554 if (BB
!= Save
->getParent())
557 CallSite SingleCallSite
;
559 // Check that we have only one CallSite.
560 for (Instruction
*I
= Save
->getNextNode(); I
!= Suspend
;
561 I
= I
->getNextNode()) {
562 if (isa
<CoroFrameInst
>(I
))
564 if (isa
<CoroSubFnInst
>(I
))
566 if (CallSite CS
= CallSite(I
)) {
573 auto *CallInstr
= SingleCallSite
.getInstruction();
577 auto *Callee
= SingleCallSite
.getCalledValue()->stripPointerCasts();
579 // See if the callsite is for resumption or destruction of the coroutine.
580 auto *SubFn
= dyn_cast
<CoroSubFnInst
>(Callee
);
584 // Does not refer to the current coroutine, we cannot do anything with it.
585 if (SubFn
->getFrame() != CoroBegin
)
588 // Replace llvm.coro.suspend with the value that results in resumption over
589 // the resume or cleanup path.
590 Suspend
->replaceAllUsesWith(SubFn
->getRawIndex());
591 Suspend
->eraseFromParent();
592 Save
->eraseFromParent();
594 // No longer need a call to coro.resume or coro.destroy.
595 CallInstr
->eraseFromParent();
597 if (SubFn
->user_empty())
598 SubFn
->eraseFromParent();
603 // Remove suspend points that are simplified.
604 static void simplifySuspendPoints(coro::Shape
&Shape
) {
605 auto &S
= Shape
.CoroSuspends
;
606 size_t I
= 0, N
= S
.size();
610 if (simplifySuspendPoint(S
[I
], Shape
.CoroBegin
)) {
613 std::swap(S
[I
], S
[N
]);
622 static SmallPtrSet
<BasicBlock
*, 4> getCoroBeginPredBlocks(CoroBeginInst
*CB
) {
623 // Collect all blocks that we need to look for instructions to relocate.
624 SmallPtrSet
<BasicBlock
*, 4> RelocBlocks
;
625 SmallVector
<BasicBlock
*, 4> Work
;
626 Work
.push_back(CB
->getParent());
629 BasicBlock
*Current
= Work
.pop_back_val();
630 for (BasicBlock
*BB
: predecessors(Current
))
631 if (RelocBlocks
.count(BB
) == 0) {
632 RelocBlocks
.insert(BB
);
635 } while (!Work
.empty());
639 static SmallPtrSet
<Instruction
*, 8>
640 getNotRelocatableInstructions(CoroBeginInst
*CoroBegin
,
641 SmallPtrSetImpl
<BasicBlock
*> &RelocBlocks
) {
642 SmallPtrSet
<Instruction
*, 8> DoNotRelocate
;
643 // Collect all instructions that we should not relocate
644 SmallVector
<Instruction
*, 8> Work
;
646 // Start with CoroBegin and terminators of all preceding blocks.
647 Work
.push_back(CoroBegin
);
648 BasicBlock
*CoroBeginBB
= CoroBegin
->getParent();
649 for (BasicBlock
*BB
: RelocBlocks
)
650 if (BB
!= CoroBeginBB
)
651 Work
.push_back(BB
->getTerminator());
653 // For every instruction in the Work list, place its operands in DoNotRelocate
656 Instruction
*Current
= Work
.pop_back_val();
657 LLVM_DEBUG(dbgs() << "CoroSplit: Will not relocate: " << *Current
<< "\n");
658 DoNotRelocate
.insert(Current
);
659 for (Value
*U
: Current
->operands()) {
660 auto *I
= dyn_cast
<Instruction
>(U
);
664 if (auto *A
= dyn_cast
<AllocaInst
>(I
)) {
665 // Stores to alloca instructions that occur before the coroutine frame
666 // is allocated should not be moved; the stored values may be used by
667 // the coroutine frame allocator. The operands to those stores must also
669 for (const auto &User
: A
->users())
670 if (auto *SI
= dyn_cast
<llvm::StoreInst
>(User
))
671 if (RelocBlocks
.count(SI
->getParent()) != 0 &&
672 DoNotRelocate
.count(SI
) == 0) {
674 DoNotRelocate
.insert(SI
);
679 if (DoNotRelocate
.count(I
) == 0) {
681 DoNotRelocate
.insert(I
);
684 } while (!Work
.empty());
685 return DoNotRelocate
;
688 static void relocateInstructionBefore(CoroBeginInst
*CoroBegin
, Function
&F
) {
689 // Analyze which non-alloca instructions are needed for allocation and
690 // relocate the rest to after coro.begin. We need to do it, since some of the
691 // targets of those instructions may be placed into coroutine frame memory
692 // for which becomes available after coro.begin intrinsic.
694 auto BlockSet
= getCoroBeginPredBlocks(CoroBegin
);
695 auto DoNotRelocateSet
= getNotRelocatableInstructions(CoroBegin
, BlockSet
);
697 Instruction
*InsertPt
= CoroBegin
->getNextNode();
698 BasicBlock
&BB
= F
.getEntryBlock(); // TODO: Look at other blocks as well.
699 for (auto B
= BB
.begin(), E
= BB
.end(); B
!= E
;) {
700 Instruction
&I
= *B
++;
701 if (isa
<AllocaInst
>(&I
))
705 if (DoNotRelocateSet
.count(&I
))
707 I
.moveBefore(InsertPt
);
711 static void splitCoroutine(Function
&F
, CallGraph
&CG
, CallGraphSCC
&SCC
) {
712 coro::Shape
Shape(F
);
713 if (!Shape
.CoroBegin
)
716 simplifySuspendPoints(Shape
);
717 relocateInstructionBefore(Shape
.CoroBegin
, F
);
718 buildCoroutineFrame(F
, Shape
);
719 replaceFrameSize(Shape
);
721 // If there are no suspend points, no split required, just remove
722 // the allocation and deallocation blocks, they are not needed.
723 if (Shape
.CoroSuspends
.empty()) {
724 handleNoSuspendCoroutine(Shape
.CoroBegin
, Shape
.FrameTy
);
725 removeCoroEnds(Shape
);
727 coro::updateCallGraph(F
, {}, CG
, SCC
);
731 auto *ResumeEntry
= createResumeEntryBlock(F
, Shape
);
732 auto ResumeClone
= createClone(F
, ".resume", Shape
, ResumeEntry
, 0);
733 auto DestroyClone
= createClone(F
, ".destroy", Shape
, ResumeEntry
, 1);
734 auto CleanupClone
= createClone(F
, ".cleanup", Shape
, ResumeEntry
, 2);
736 // We no longer need coro.end in F.
737 removeCoroEnds(Shape
);
740 postSplitCleanup(*ResumeClone
);
741 postSplitCleanup(*DestroyClone
);
742 postSplitCleanup(*CleanupClone
);
744 addMustTailToCoroResumes(*ResumeClone
);
746 // Store addresses resume/destroy/cleanup functions in the coroutine frame.
747 updateCoroFrame(Shape
, ResumeClone
, DestroyClone
, CleanupClone
);
749 // Create a constant array referring to resume/destroy/clone functions pointed
750 // by the last argument of @llvm.coro.info, so that CoroElide pass can
751 // determined correct function to call.
752 setCoroInfo(F
, Shape
.CoroBegin
, {ResumeClone
, DestroyClone
, CleanupClone
});
754 // Update call graph and add the functions we created to the SCC.
755 coro::updateCallGraph(F
, {ResumeClone
, DestroyClone
, CleanupClone
}, CG
, SCC
);
758 // When we see the coroutine the first time, we insert an indirect call to a
759 // devirt trigger function and mark the coroutine that it is now ready for
761 static void prepareForSplit(Function
&F
, CallGraph
&CG
) {
762 Module
&M
= *F
.getParent();
764 Function
*DevirtFn
= M
.getFunction(CORO_DEVIRT_TRIGGER_FN
);
765 assert(DevirtFn
&& "coro.devirt.trigger function not found");
768 F
.addFnAttr(CORO_PRESPLIT_ATTR
, PREPARED_FOR_SPLIT
);
770 // Insert an indirect call sequence that will be devirtualized by CoroElide
772 // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
773 // %1 = bitcast i8* %0 to void(i8*)*
774 // call void %1(i8* null)
775 coro::LowererBase
Lowerer(M
);
776 Instruction
*InsertPt
= F
.getEntryBlock().getTerminator();
777 auto *Null
= ConstantPointerNull::get(Type::getInt8PtrTy(F
.getContext()));
779 Lowerer
.makeSubFnCall(Null
, CoroSubFnInst::RestartTrigger
, InsertPt
);
780 auto *IndirectCall
= CallInst::Create(DevirtFnAddr
, Null
, "", InsertPt
);
782 // Update CG graph with an indirect call we just added.
783 CG
[&F
]->addCalledFunction(IndirectCall
, CG
.getCallsExternalNode());
786 // Make sure that there is a devirtualization trigger function that CoroSplit
787 // pass uses the force restart CGSCC pipeline. If devirt trigger function is not
788 // found, we will create one and add it to the current SCC.
789 static void createDevirtTriggerFunc(CallGraph
&CG
, CallGraphSCC
&SCC
) {
790 Module
&M
= CG
.getModule();
791 if (M
.getFunction(CORO_DEVIRT_TRIGGER_FN
))
794 LLVMContext
&C
= M
.getContext();
795 auto *FnTy
= FunctionType::get(Type::getVoidTy(C
), Type::getInt8PtrTy(C
),
796 /*IsVarArgs=*/false);
798 Function::Create(FnTy
, GlobalValue::LinkageTypes::PrivateLinkage
,
799 CORO_DEVIRT_TRIGGER_FN
, &M
);
800 DevirtFn
->addFnAttr(Attribute::AlwaysInline
);
801 auto *Entry
= BasicBlock::Create(C
, "entry", DevirtFn
);
802 ReturnInst::Create(C
, Entry
);
804 auto *Node
= CG
.getOrInsertFunction(DevirtFn
);
806 SmallVector
<CallGraphNode
*, 8> Nodes(SCC
.begin(), SCC
.end());
807 Nodes
.push_back(Node
);
808 SCC
.initialize(Nodes
);
811 //===----------------------------------------------------------------------===//
813 //===----------------------------------------------------------------------===//
817 struct CoroSplit
: public CallGraphSCCPass
{
818 static char ID
; // Pass identification, replacement for typeid
820 CoroSplit() : CallGraphSCCPass(ID
) {
821 initializeCoroSplitPass(*PassRegistry::getPassRegistry());
826 // A coroutine is identified by the presence of coro.begin intrinsic, if
827 // we don't have any, this pass has nothing to do.
828 bool doInitialization(CallGraph
&CG
) override
{
829 Run
= coro::declaresIntrinsics(CG
.getModule(), {"llvm.coro.begin"});
830 return CallGraphSCCPass::doInitialization(CG
);
833 bool runOnSCC(CallGraphSCC
&SCC
) override
{
837 // Find coroutines for processing.
838 SmallVector
<Function
*, 4> Coroutines
;
839 for (CallGraphNode
*CGN
: SCC
)
840 if (auto *F
= CGN
->getFunction())
841 if (F
->hasFnAttribute(CORO_PRESPLIT_ATTR
))
842 Coroutines
.push_back(F
);
844 if (Coroutines
.empty())
847 CallGraph
&CG
= getAnalysis
<CallGraphWrapperPass
>().getCallGraph();
848 createDevirtTriggerFunc(CG
, SCC
);
850 for (Function
*F
: Coroutines
) {
851 Attribute Attr
= F
->getFnAttribute(CORO_PRESPLIT_ATTR
);
852 StringRef Value
= Attr
.getValueAsString();
853 LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F
->getName()
854 << "' state: " << Value
<< "\n");
855 if (Value
== UNPREPARED_FOR_SPLIT
) {
856 prepareForSplit(*F
, CG
);
859 F
->removeFnAttr(CORO_PRESPLIT_ATTR
);
860 splitCoroutine(*F
, CG
, SCC
);
865 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
866 CallGraphSCCPass::getAnalysisUsage(AU
);
869 StringRef
getPassName() const override
{ return "Coroutine Splitting"; }
872 } // end anonymous namespace
874 char CoroSplit::ID
= 0;
877 CoroSplit
, "coro-split",
878 "Split coroutine into a set of functions driving its state machine", false,
881 Pass
*llvm::createCoroSplitPass() { return new CoroSplit(); }