//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass builds the coroutine frame and outlines resume and destroy parts
// of the coroutine into separate functions.
//
// We present a coroutine to LLVM as an ordinary function with suspension
// points marked up with intrinsics. We let the optimizer party on the coroutine
// as a single function for as long as possible. Shortly before the coroutine is
// eligible to be inlined into its callers, we split up the coroutine into parts
// corresponding to the initial, resume and destroy invocations of the
// coroutine, add them to the current SCC and restart the IPO pipeline to
// optimize the coroutine subfunctions we extracted before proceeding to the
// caller of the coroutine.
//===----------------------------------------------------------------------===//
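//
// For illustration only (a hand-written sketch, not IR produced verbatim by
// this pass): under the switch lowering, a coroutine such as
//
//   define i8* @f(...) {
//     %hdl = call i8* @llvm.coro.begin(token %id, i8* %mem)
//     ...
//     %s = call i8 @llvm.coro.suspend(token none, i1 false)
//     ...
//   }
//
// is split into @f (the ramp function), @f.resume and @f.destroy, all of
// which communicate through the coroutine frame allocated in @f.
//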
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>

using namespace llvm;

#define DEBUG_TYPE "coro-split"

namespace {
/// A little helper class for building the coroutine clone functions.
class CoroCloner {
public:
  enum class Kind {
    /// The shared resume function for a switch lowering.
    SwitchResume,

    /// The shared unwind function for a switch lowering.
    SwitchUnwind,

    /// The shared cleanup function for a switch lowering.
    SwitchCleanup,

    /// An individual continuation function.
    Continuation,
  };

private:
  Function &OrigF;
  Function *NewF;
  const Twine &Suffix;
  coro::Shape &Shape;
  Kind FKind;
  ValueToValueMapTy VMap;
  IRBuilder<> Builder;
  Value *NewFramePtr = nullptr;
  Value *SwiftErrorSlot = nullptr;

  /// The active suspend instruction; meaningful only for continuation ABIs.
  AnyCoroSuspendInst *ActiveSuspend = nullptr;

public:
  /// Create a cloner for a switch lowering.
  CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
             Kind FKind)
      : OrigF(OrigF), NewF(nullptr), Suffix(Suffix), Shape(Shape),
        FKind(FKind), Builder(OrigF.getContext()) {
    assert(Shape.ABI == coro::ABI::Switch);
  }

  /// Create a cloner for a continuation lowering.
  CoroCloner(Function &OrigF, const Twine &Suffix, coro::Shape &Shape,
             Function *NewF, AnyCoroSuspendInst *ActiveSuspend)
      : OrigF(OrigF), NewF(NewF), Suffix(Suffix), Shape(Shape),
        FKind(Kind::Continuation), Builder(OrigF.getContext()),
        ActiveSuspend(ActiveSuspend) {
    assert(Shape.ABI == coro::ABI::Retcon ||
           Shape.ABI == coro::ABI::RetconOnce);
    assert(NewF && "need existing function for continuation");
    assert(ActiveSuspend && "need active suspend point for continuation");
  }

  Function *getFunction() const {
    assert(NewF != nullptr && "declaration not yet set");
    return NewF;
  }

  void create();

private:
  bool isSwitchDestroyFunction() {
    switch (FKind) {
    case Kind::Continuation:
    case Kind::SwitchResume:
      return false;
    case Kind::SwitchUnwind:
    case Kind::SwitchCleanup:
      return true;
    }
    llvm_unreachable("Unknown CoroCloner::Kind enum");
  }

  void createDeclaration();
  void replaceEntryBlock();
  Value *deriveNewFramePointer();
  void replaceRetconSuspendUses();
  void replaceCoroSuspends();
  void replaceCoroEnds();
  void replaceSwiftErrorOps();
  void handleFinalSuspend();
  void maybeFreeContinuationStorage();
};

} // end anonymous namespace
static void maybeFreeRetconStorage(IRBuilder<> &Builder, coro::Shape &Shape,
                                   Value *FramePtr, CallGraph *CG) {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);
  if (Shape.RetconLowering.IsFrameInlineInStorage)
    return;

  Shape.emitDealloc(Builder, FramePtr, CG);
}

/// Replace a non-unwind call to llvm.coro.end.
static void replaceFallthroughCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                                      Value *FramePtr, bool InResume,
                                      CallGraph *CG) {
  // Start inserting right before the coro.end.
  IRBuilder<> Builder(End);

  // Create the return instruction.
  switch (Shape.ABI) {
  // The cloned functions in switch-lowering always return void.
  case coro::ABI::Switch:
    // coro.end doesn't immediately end the coroutine in the main function
    // in this lowering, because we need to deallocate the coroutine.
    if (!InResume)
      return;
    Builder.CreateRetVoid();
    break;

  // In unique continuation lowering, the continuations always return void.
  // But we may have implicitly allocated storage.
  case coro::ABI::RetconOnce:
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    Builder.CreateRetVoid();
    break;

  // In non-unique continuation lowering, we signal completion by returning
  // a null continuation.
  case coro::ABI::Retcon: {
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    auto RetTy = Shape.getResumeFunctionType()->getReturnType();
    auto RetStructTy = dyn_cast<StructType>(RetTy);
    PointerType *ContinuationTy =
        cast<PointerType>(RetStructTy ? RetStructTy->getElementType(0) : RetTy);

    Value *ReturnValue = ConstantPointerNull::get(ContinuationTy);
    if (RetStructTy) {
      ReturnValue = Builder.CreateInsertValue(UndefValue::get(RetStructTy),
                                              ReturnValue, 0);
    }
    Builder.CreateRet(ReturnValue);
    break;
  }
  }

  // Remove the rest of the block, by splitting it into an unreachable block.
  auto *BB = End->getParent();
  BB->splitBasicBlock(End);
  BB->getTerminator()->eraseFromParent();
}
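
// An illustrative sketch of the Retcon case above (assumed IR, not taken
// from the original source): for a retcon coroutine whose resume type is
// { i8*, i32 }, the fallthrough coro.end becomes roughly
//
//   %ret = insertvalue { i8*, i32 } undef, i8* null, 0
//   ret { i8*, i32 } %ret
//
// where the null continuation pointer in the first slot tells the caller
// that the coroutine has completed.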
/// Replace an unwind call to llvm.coro.end.
static void replaceUnwindCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                                 Value *FramePtr, bool InResume,
                                 CallGraph *CG) {
  IRBuilder<> Builder(End);

  switch (Shape.ABI) {
  // In switch-lowering, this does nothing in the main function.
  case coro::ABI::Switch:
    if (!InResume)
      return;
    break;

  // In continuation-lowering, this frees the continuation storage.
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    maybeFreeRetconStorage(Builder, Shape, FramePtr, CG);
    break;
  }

  // If coro.end has an associated bundle, add cleanupret instruction.
  if (auto Bundle = End->getOperandBundle(LLVMContext::OB_funclet)) {
    auto *FromPad = cast<CleanupPadInst>(Bundle->Inputs[0]);
    auto *CleanupRet = Builder.CreateCleanupRet(FromPad, nullptr);
    End->getParent()->splitBasicBlock(End);
    CleanupRet->getParent()->getTerminator()->eraseFromParent();
  }
}

static void replaceCoroEnd(CoroEndInst *End, coro::Shape &Shape,
                           Value *FramePtr, bool InResume, CallGraph *CG) {
  if (End->isUnwind())
    replaceUnwindCoroEnd(End, Shape, FramePtr, InResume, CG);
  else
    replaceFallthroughCoroEnd(End, Shape, FramePtr, InResume, CG);

  auto &Context = End->getContext();
  End->replaceAllUsesWith(InResume ? ConstantInt::getTrue(Context)
                                   : ConstantInt::getFalse(Context));
  End->eraseFromParent();
}
// Create an entry block for a resume function with a switch that will jump to
// suspend points.
static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
  assert(Shape.ABI == coro::ABI::Switch);
  LLVMContext &C = F.getContext();

  // resume.entry:
  //  %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0,
  //                i32 2
  //  %index = load i32, i32* %index.addr
  //  switch i32 %index, label %unreachable [
  //    i32 0, label %resume.0
  //    i32 1, label %resume.1
  //    ...
  //  ]

  auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F);
  auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F);

  IRBuilder<> Builder(NewEntry);
  auto *FramePtr = Shape.FramePtr;
  auto *FrameTy = Shape.FrameTy;
  auto *GepIndex = Builder.CreateStructGEP(
      FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
  auto *Index = Builder.CreateLoad(Shape.getIndexType(), GepIndex, "index");
  auto *Switch =
      Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size());
  Shape.SwitchLowering.ResumeSwitch = Switch;

  size_t SuspendIndex = 0;
  for (auto *AnyS : Shape.CoroSuspends) {
    auto *S = cast<CoroSuspendInst>(AnyS);
    ConstantInt *IndexVal = Shape.getIndex(SuspendIndex);

    // Replace CoroSave with a store to Index:
    //    %index.addr = getelementptr %f.frame... (index field number)
    //    store i32 0, i32* %index.addr1
    auto *Save = S->getCoroSave();
    Builder.SetInsertPoint(Save);
    if (S->isFinal()) {
      // Final suspend point is represented by storing zero in ResumeFnAddr.
      auto *GepIndex = Builder.CreateStructGEP(FrameTy, FramePtr,
                                 coro::Shape::SwitchFieldIndex::Resume,
                                 "ResumeFn.addr");
      auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
          cast<PointerType>(GepIndex->getType())->getElementType()));
      Builder.CreateStore(NullPtr, GepIndex);
    } else {
      auto *GepIndex = Builder.CreateStructGEP(
          FrameTy, FramePtr, coro::Shape::SwitchFieldIndex::Index, "index.addr");
      Builder.CreateStore(IndexVal, GepIndex);
    }
    Save->replaceAllUsesWith(ConstantTokenNone::get(C));
    Save->eraseFromParent();

    // Split block before and after coro.suspend and add a jump from an entry
    // switch:
    //
    //  whateverBB:
    //    whatever
    //    %0 = call i8 @llvm.coro.suspend(token none, i1 false)
    //    switch i8 %0, label %suspend[i8 0, label %resume
    //                                 i8 1, label %cleanup]
    // becomes:
    //
    //  whateverBB:
    //    whatever
    //    br label %resume.0.landing
    //
    //  resume.0: ; <--- jump from the switch in the resume.entry
    //    %0 = tail call i8 @llvm.coro.suspend(token none, i1 false)
    //    br label %resume.0.landing
    //
    //  resume.0.landing:
    //    %1 = phi i8[-1, %whateverBB], [%0, %resume.0]
    //    switch i8 %1, label %suspend [i8 0, label %resume
    //                                  i8 1, label %cleanup]

    auto *SuspendBB = S->getParent();
    auto *ResumeBB =
        SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex));
    auto *LandingBB = ResumeBB->splitBasicBlock(
        S->getNextNode(), ResumeBB->getName() + Twine(".landing"));
    Switch->addCase(IndexVal, ResumeBB);

    cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB);
    auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front());
    S->replaceAllUsesWith(PN);
    PN->addIncoming(Builder.getInt8(-1), SuspendBB);
    PN->addIncoming(S, ResumeBB);

    ++SuspendIndex;
  }

  Builder.SetInsertPoint(UnreachBB);
  Builder.CreateUnreachable();

  Shape.SwitchLowering.ResumeEntryBlock = NewEntry;
}
// Rewrite final suspend point handling. We do not use the suspend index to
// represent the final suspend point. Instead we zero-out ResumeFnAddr in the
// coroutine frame, since it is undefined behavior to resume a coroutine
// suspended at the final suspend point. Thus, in the resume function, we can
// simply remove the last case (when coro::Shape is built, the final suspend
// point (if present) is always the last element of the CoroSuspends array).
// In the destroy function, we add a code sequence to check if ResumeFnAddress
// is Null, and if so, jump to the appropriate label to handle cleanup from the
// final suspend point.
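//
// For illustration (assumed IR shape with illustrative field indices, not
// from the original source), the check inserted into the destroy clone looks
// roughly like:
//
//   %fn.addr = getelementptr inbounds %f.Frame, %f.Frame* %frame, i32 0, i32 0
//   %fn = load void (%f.Frame*)*, void (%f.Frame*)** %fn.addr
//   %at.final = icmp eq void (%f.Frame*)* %fn, null
//   br i1 %at.final, label %final.cleanup, label %switch.on.index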
void CoroCloner::handleFinalSuspend() {
  assert(Shape.ABI == coro::ABI::Switch &&
         Shape.SwitchLowering.HasFinalSuspend);
  auto *Switch = cast<SwitchInst>(VMap[Shape.SwitchLowering.ResumeSwitch]);
  auto FinalCaseIt = std::prev(Switch->case_end());
  BasicBlock *ResumeBB = FinalCaseIt->getCaseSuccessor();
  Switch->removeCase(FinalCaseIt);
  if (isSwitchDestroyFunction()) {
    BasicBlock *OldSwitchBB = Switch->getParent();
    auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch");
    Builder.SetInsertPoint(OldSwitchBB->getTerminator());
    auto *GepIndex = Builder.CreateStructGEP(Shape.FrameTy, NewFramePtr,
                               coro::Shape::SwitchFieldIndex::Resume,
                               "ResumeFn.addr");
    auto *Load = Builder.CreateLoad(Shape.getSwitchResumePointerType(),
                                    GepIndex);
    auto *Cond = Builder.CreateIsNull(Load);
    Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB);
    OldSwitchBB->getTerminator()->eraseFromParent();
  }
}
static Function *createCloneDeclaration(Function &OrigF, coro::Shape &Shape,
                                        const Twine &Suffix,
                                        Module::iterator InsertBefore) {
  Module *M = OrigF.getParent();
  auto *FnTy = Shape.getResumeFunctionType();

  Function *NewF =
      Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage,
                       OrigF.getName() + Suffix);
  NewF->addParamAttr(0, Attribute::NonNull);
  NewF->addParamAttr(0, Attribute::NoAlias);

  M->getFunctionList().insert(InsertBefore, NewF);

  return NewF;
}
/// Replace uses of the active llvm.coro.suspend.retcon call with the
/// arguments to the continuation function.
///
/// This assumes that the builder has a meaningful insertion point.
void CoroCloner::replaceRetconSuspendUses() {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);

  auto NewS = VMap[ActiveSuspend];
  if (NewS->use_empty()) return;

  // Copy out all the continuation arguments after the buffer pointer into
  // an easily-indexed data structure for convenience.
  SmallVector<Value*, 8> Args;
  for (auto I = std::next(NewF->arg_begin()), E = NewF->arg_end(); I != E; ++I)
    Args.push_back(&*I);

  // If the suspend returns a single scalar value, we can just do a simple
  // replacement.
  if (!isa<StructType>(NewS->getType())) {
    assert(Args.size() == 1);
    NewS->replaceAllUsesWith(Args.front());
    return;
  }

  // Try to peephole extracts of an aggregate return.
  for (auto UI = NewS->use_begin(), UE = NewS->use_end(); UI != UE; ) {
    auto EVI = dyn_cast<ExtractValueInst>((UI++)->getUser());
    if (!EVI || EVI->getNumIndices() != 1)
      continue;

    EVI->replaceAllUsesWith(Args[EVI->getIndices().front()]);
    EVI->eraseFromParent();
  }

  // If we have no remaining uses, we're done.
  if (NewS->use_empty()) return;

  // Otherwise, we need to create an aggregate.
  Value *Agg = UndefValue::get(NewS->getType());
  for (size_t I = 0, E = Args.size(); I != E; ++I)
    Agg = Builder.CreateInsertValue(Agg, Args[I], I);

  NewS->replaceAllUsesWith(Agg);
}
void CoroCloner::replaceCoroSuspends() {
  Value *SuspendResult;

  switch (Shape.ABI) {
  // In switch lowering, replace coro.suspend with the appropriate value
  // for the type of function we're extracting.
  // Replacing coro.suspend with (0) will result in control flow proceeding to
  // a resume label associated with a suspend point, replacing it with (1) will
  // result in control flow proceeding to a cleanup label associated with this
  // suspend point.
  case coro::ABI::Switch:
    SuspendResult = Builder.getInt8(isSwitchDestroyFunction() ? 1 : 0);
    break;

  // In returned-continuation lowering, the arguments from earlier
  // continuations are theoretically arbitrary, and they should have been
  // spilled.
  case coro::ABI::RetconOnce:
  case coro::ABI::Retcon:
    return;
  }

  for (AnyCoroSuspendInst *CS : Shape.CoroSuspends) {
    // The active suspend was handled earlier.
    if (CS == ActiveSuspend) continue;

    auto *MappedCS = cast<AnyCoroSuspendInst>(VMap[CS]);
    MappedCS->replaceAllUsesWith(SuspendResult);
    MappedCS->eraseFromParent();
  }
}
void CoroCloner::replaceCoroEnds() {
  for (CoroEndInst *CE : Shape.CoroEnds) {
    // We use a null call graph because there's no call graph node for
    // the cloned function yet. We'll just be rebuilding that later.
    auto NewCE = cast<CoroEndInst>(VMap[CE]);
    replaceCoroEnd(NewCE, Shape, NewFramePtr, /*in resume*/ true, nullptr);
  }
}
static void replaceSwiftErrorOps(Function &F, coro::Shape &Shape,
                                 ValueToValueMapTy *VMap) {
  Value *CachedSlot = nullptr;
  auto getSwiftErrorSlot = [&](Type *ValueTy) -> Value * {
    if (CachedSlot) {
      assert(CachedSlot->getType()->getPointerElementType() == ValueTy &&
             "multiple swifterror slots in function with different types");
      return CachedSlot;
    }

    // Check if the function has a swifterror argument.
    for (auto &Arg : F.args()) {
      if (Arg.isSwiftError()) {
        CachedSlot = &Arg;
        assert(Arg.getType()->getPointerElementType() == ValueTy &&
               "swifterror argument does not have expected type");
        return &Arg;
      }
    }

    // Create a swifterror alloca.
    IRBuilder<> Builder(F.getEntryBlock().getFirstNonPHIOrDbg());
    auto Alloca = Builder.CreateAlloca(ValueTy);
    Alloca->setSwiftError(true);

    CachedSlot = Alloca;
    return Alloca;
  };

  for (CallInst *Op : Shape.SwiftErrorOps) {
    auto MappedOp = VMap ? cast<CallInst>((*VMap)[Op]) : Op;
    IRBuilder<> Builder(MappedOp);

    // If there are no arguments, this is a 'get' operation.
    Value *MappedResult;
    if (Op->getNumArgOperands() == 0) {
      auto ValueTy = Op->getType();
      auto Slot = getSwiftErrorSlot(ValueTy);
      MappedResult = Builder.CreateLoad(ValueTy, Slot);
    } else {
      assert(Op->getNumArgOperands() == 1);
      auto Value = MappedOp->getArgOperand(0);
      auto ValueTy = Value->getType();
      auto Slot = getSwiftErrorSlot(ValueTy);
      Builder.CreateStore(Value, Slot);
      MappedResult = Slot;
    }

    MappedOp->replaceAllUsesWith(MappedResult);
    MappedOp->eraseFromParent();
  }

  // If we're updating the original function, we've invalidated SwiftErrorOps.
  if (VMap == nullptr) {
    Shape.SwiftErrorOps.clear();
  }
}

void CoroCloner::replaceSwiftErrorOps() {
  ::replaceSwiftErrorOps(*NewF, Shape, &VMap);
}
void CoroCloner::replaceEntryBlock() {
  // In the original function, the AllocaSpillBlock is a block immediately
  // following the allocation of the frame object which defines GEPs for
  // all the allocas that have been moved into the frame, and it ends by
  // branching to the original beginning of the coroutine. Make this
  // the entry block of the cloned function.
  auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
  Entry->setName("entry" + Suffix);
  Entry->moveBefore(&NewF->getEntryBlock());
  Entry->getTerminator()->eraseFromParent();

  // Clear all predecessors of the new entry block. There should be
  // exactly one predecessor, which we created when splitting out
  // AllocaSpillBlock to begin with.
  assert(Entry->hasOneUse());
  auto BranchToEntry = cast<BranchInst>(Entry->user_back());
  assert(BranchToEntry->isUnconditional());
  Builder.SetInsertPoint(BranchToEntry);
  Builder.CreateUnreachable();
  BranchToEntry->eraseFromParent();

  // TODO: move any allocas into Entry that weren't moved into the frame.
  // (Currently we move all allocas into the frame.)

  // Branch from the entry to the appropriate place.
  Builder.SetInsertPoint(Entry);
  switch (Shape.ABI) {
  case coro::ABI::Switch: {
    // In switch-lowering, we built a resume-entry block in the original
    // function. Make the entry block branch to this.
    auto *SwitchBB =
        cast<BasicBlock>(VMap[Shape.SwitchLowering.ResumeEntryBlock]);
    Builder.CreateBr(SwitchBB);
    break;
  }

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce: {
    // In continuation ABIs, we want to branch to immediately after the
    // active suspend point. Earlier phases will have put the suspend in its
    // own basic block, so just thread our jump directly to its successor.
    auto MappedCS = cast<CoroSuspendRetconInst>(VMap[ActiveSuspend]);
    auto Branch = cast<BranchInst>(MappedCS->getNextNode());
    assert(Branch->isUnconditional());
    Builder.CreateBr(Branch->getSuccessor(0));
    break;
  }
  }
}
/// Derive the value of the new frame pointer.
Value *CoroCloner::deriveNewFramePointer() {
  // Builder should be inserting to the front of the new entry block.

  switch (Shape.ABI) {
  // In switch-lowering, the argument is the frame pointer.
  case coro::ABI::Switch:
    return &*NewF->arg_begin();

  // In continuation-lowering, the argument is the opaque storage.
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce: {
    Argument *NewStorage = &*NewF->arg_begin();
    auto FramePtrTy = Shape.FrameTy->getPointerTo();

    // If the storage is inline, just bitcast the storage to the frame type.
    if (Shape.RetconLowering.IsFrameInlineInStorage)
      return Builder.CreateBitCast(NewStorage, FramePtrTy);

    // Otherwise, load the real frame from the opaque storage.
    auto FramePtrPtr =
        Builder.CreateBitCast(NewStorage, FramePtrTy->getPointerTo());
    return Builder.CreateLoad(FramePtrPtr);
  }
  }
  llvm_unreachable("bad ABI");
}
/// Clone the body of the original function into a resume function of
/// some sort.
void CoroCloner::create() {
  // Create the new function if we don't already have one.
  if (!NewF) {
    NewF = createCloneDeclaration(OrigF, Shape, Suffix,
                                  OrigF.getParent()->end());
  }

  // Replace all args with undefs. The buildCoroutineFrame algorithm has
  // already rewritten accesses to the args that occur after suspend points
  // with loads and stores to/from the coroutine frame.
  for (Argument &A : OrigF.args())
    VMap[&A] = UndefValue::get(A.getType());

  SmallVector<ReturnInst *, 4> Returns;

  // Ignore attempts to change certain attributes of the function.
  // TODO: maybe there should be a way to suppress this during cloning?
  auto savedVisibility = NewF->getVisibility();
  auto savedUnnamedAddr = NewF->getUnnamedAddr();
  auto savedDLLStorageClass = NewF->getDLLStorageClass();

  // NewF's linkage (which CloneFunctionInto does *not* change) might not
  // be compatible with the visibility of OrigF (which it *does* change),
  // so protect against that.
  auto savedLinkage = NewF->getLinkage();
  NewF->setLinkage(llvm::GlobalValue::ExternalLinkage);

  CloneFunctionInto(NewF, &OrigF, VMap, /*ModuleLevelChanges=*/true, Returns);

  NewF->setLinkage(savedLinkage);
  NewF->setVisibility(savedVisibility);
  NewF->setUnnamedAddr(savedUnnamedAddr);
  NewF->setDLLStorageClass(savedDLLStorageClass);

  auto &Context = NewF->getContext();

  // Replace the attributes of the new function:
  auto OrigAttrs = NewF->getAttributes();
  auto NewAttrs = AttributeList();

  switch (Shape.ABI) {
  case coro::ABI::Switch:
    // Bootstrap attributes by copying function attributes from the
    // original function. This should include optimization settings and so on.
    NewAttrs = NewAttrs.addAttributes(Context, AttributeList::FunctionIndex,
                                      OrigAttrs.getFnAttributes());
    break;

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    // If we have a continuation prototype, just use its attributes,
    // full-stop.
    NewAttrs = Shape.RetconLowering.ResumePrototype->getAttributes();
    break;
  }

  // Make the frame parameter nonnull and noalias.
  NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NonNull);
  NewAttrs = NewAttrs.addParamAttribute(Context, 0, Attribute::NoAlias);

  switch (Shape.ABI) {
  // In these ABIs, the cloned functions always return 'void', and the
  // existing return sites are meaningless. Note that for unique
  // continuations, this includes the returns associated with suspends;
  // this is fine because we can't suspend twice.
  case coro::ABI::Switch:
  case coro::ABI::RetconOnce:
    // Remove old returns.
    for (ReturnInst *Return : Returns)
      changeToUnreachable(Return, /*UseLLVMTrap=*/false);
    break;

  // With multi-suspend continuations, we'll already have eliminated the
  // original returns and inserted returns before all the suspend points,
  // so we want to leave any returns in place.
  case coro::ABI::Retcon:
    break;
  }

  NewF->setAttributes(NewAttrs);
  NewF->setCallingConv(Shape.getResumeFunctionCC());

  // Set up the new entry block.
  replaceEntryBlock();

  Builder.SetInsertPoint(&NewF->getEntryBlock().front());
  NewFramePtr = deriveNewFramePointer();

  // Remap frame pointer.
  Value *OldFramePtr = VMap[Shape.FramePtr];
  NewFramePtr->takeName(OldFramePtr);
  OldFramePtr->replaceAllUsesWith(NewFramePtr);

  // Remap vFrame pointer.
  auto *NewVFrame = Builder.CreateBitCast(
      NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame");
  Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]);
  OldVFrame->replaceAllUsesWith(NewVFrame);

  switch (Shape.ABI) {
  case coro::ABI::Switch:
    // Rewrite final suspend handling as it is not done via switch (this
    // allows us to remove the final case from the switch, since it is
    // undefined behavior to resume the coroutine suspended at the final
    // suspend point).
    if (Shape.SwitchLowering.HasFinalSuspend)
      handleFinalSuspend();
    break;

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    // Replace uses of the active suspend with the corresponding
    // continuation-function arguments.
    assert(ActiveSuspend != nullptr &&
           "no active suspend when lowering a continuation-style coroutine");
    replaceRetconSuspendUses();
    break;
  }

  // Handle suspends.
  replaceCoroSuspends();

  // Handle swifterror.
  replaceSwiftErrorOps();

  // Remove coro.end intrinsics.
  replaceCoroEnds();

  // Eliminate coro.free from the clones, replacing it with 'null' in cleanup,
  // to suppress deallocation code.
  if (Shape.ABI == coro::ABI::Switch)
    coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]),
                          /*Elide=*/ FKind == CoroCloner::Kind::SwitchCleanup);
}
// Create a resume clone by cloning the body of the original function, setting
// a new entry block and replacing coro.suspend with an appropriate value to
// force the resume or cleanup path at every suspend point.
static Function *createClone(Function &F, const Twine &Suffix,
                             coro::Shape &Shape, CoroCloner::Kind FKind) {
  CoroCloner Cloner(F, Suffix, Shape, FKind);
  Cloner.create();
  return Cloner.getFunction();
}
/// Remove calls to llvm.coro.end in the original function.
static void removeCoroEnds(coro::Shape &Shape, CallGraph *CG) {
  for (auto End : Shape.CoroEnds) {
    replaceCoroEnd(End, Shape, Shape.FramePtr, /*in resume*/ false, CG);
  }
}
static void replaceFrameSize(coro::Shape &Shape) {
  if (Shape.CoroSizes.empty())
    return;

  // In the same function all coro.sizes should have the same result type.
  auto *SizeIntrin = Shape.CoroSizes.back();
  Module *M = SizeIntrin->getModule();
  const DataLayout &DL = M->getDataLayout();
  auto Size = DL.getTypeAllocSize(Shape.FrameTy);
  auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size);

  for (CoroSizeInst *CS : Shape.CoroSizes) {
    CS->replaceAllUsesWith(SizeConstant);
    CS->eraseFromParent();
  }
}
// Create a global constant array containing pointers to the functions
// provided and set the Info parameter of CoroBegin to point at this constant.
// Example:
//
//   @f.resumers = internal constant [2 x void(%f.frame*)*]
//                   [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy]
//   define void @f() {
//     ...
//     call i8* @llvm.coro.begin(i8* null, i32 0, i8* null,
//                    i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*))
//
// Assumes that all the functions have the same signature.
static void setCoroInfo(Function &F, coro::Shape &Shape,
                        ArrayRef<Function *> Fns) {
  // This only works under the switch-lowering ABI because coro elision
  // only works on the switch-lowering ABI.
  assert(Shape.ABI == coro::ABI::Switch);

  SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end());
  assert(!Args.empty());
  Function *Part = *Fns.begin();
  Module *M = Part->getParent();
  auto *ArrTy = ArrayType::get(Part->getType(), Args.size());

  auto *ConstVal = ConstantArray::get(ArrTy, Args);
  auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true,
                                GlobalVariable::PrivateLinkage, ConstVal,
                                F.getName() + Twine(".resumers"));

  // Update coro.begin instruction to refer to this constant.
  LLVMContext &C = F.getContext();
  auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C));
  Shape.getSwitchCoroId()->setInfo(BC);
}
// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame.
static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn,
                            Function *DestroyFn, Function *CleanupFn) {
  assert(Shape.ABI == coro::ABI::Switch);

  IRBuilder<> Builder(Shape.FramePtr->getNextNode());
  auto *ResumeAddr = Builder.CreateStructGEP(
      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Resume,
      "resume.addr");
  Builder.CreateStore(ResumeFn, ResumeAddr);

  Value *DestroyOrCleanupFn = DestroyFn;

  CoroIdInst *CoroId = Shape.getSwitchCoroId();
  if (CoroAllocInst *CA = CoroId->getCoroAlloc()) {
    // If there is a CoroAlloc and it returns false (meaning we elided the
    // allocation), use CleanupFn instead of DestroyFn.
    DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn);
  }

  auto *DestroyAddr = Builder.CreateStructGEP(
      Shape.FrameTy, Shape.FramePtr, coro::Shape::SwitchFieldIndex::Destroy,
      "destroy.addr");
  Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr);
}
static void postSplitCleanup(Function &F) {
  removeUnreachableBlocks(F);

  // For now, we do a mandatory verification step because we don't
  // entirely trust this pass. Note that we don't want to add a verifier
  // pass to FPM below because it will also verify all the global data.
  verifyFunction(F);

  legacy::FunctionPassManager FPM(F.getParent());

  FPM.add(createSCCPPass());
  FPM.add(createCFGSimplificationPass());
  FPM.add(createEarlyCSEPass());
  FPM.add(createCFGSimplificationPass());

  FPM.doInitialization();
  FPM.run(F);
  FPM.doFinalization();
}
// Assuming we arrived at the block NewBlock from Prev instruction, store
// PHI's incoming values in the ResolvedValues map.
static void
scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock,
                          DenseMap<Value *, Value *> &ResolvedValues) {
  auto *PrevBB = Prev->getParent();
  for (PHINode &PN : NewBlock->phis()) {
    auto V = PN.getIncomingValueForBlock(PrevBB);
    // See if we already resolved it.
    auto VI = ResolvedValues.find(V);
    if (VI != ResolvedValues.end())
      V = VI->second;
    // Remember the value.
    ResolvedValues[&PN] = V;
  }
}
// Replace a sequence of branches leading to a ret with a clone of that ret
// instruction. A suspend point is represented by a switch, so track the PHI
// values and select the correct case successor when possible.
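//
// For illustration (hypothetical IR, not taken from a test case): a chain
// such as
//
//     call fastcc void %fn(i8* %hdl)
//     br label %exit
//   exit:
//     ret void
//
// is collapsed so that a clone of 'ret void' directly follows the call,
// which is the shape addMustTailToCoroResumes below requires.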
static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) {
  DenseMap<Value *, Value *> ResolvedValues;

  Instruction *I = InitialInst;
  while (I->isTerminator()) {
    if (isa<ReturnInst>(I)) {
      if (I != InitialInst)
        ReplaceInstWithInst(InitialInst, I->clone());
      return true;
    }
    if (auto *BR = dyn_cast<BranchInst>(I)) {
      if (BR->isUnconditional()) {
        BasicBlock *BB = BR->getSuccessor(0);
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
      Value *V = SI->getCondition();
      auto it = ResolvedValues.find(V);
      if (it != ResolvedValues.end())
        V = it->second;
      if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
        BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor();
        scanPHIsAndUpdateValueMap(I, BB, ResolvedValues);
        I = BB->getFirstNonPHIOrDbgOrLifetime();
        continue;
      }
    }
    return false;
  }
  return false;
}
// Add musttail to any resume instructions that are immediately followed by a
// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail calls
// for symmetrical coroutine control transfer (C++ Coroutines TS extension).
// This transformation is done only in the resume part of the coroutine that
// has identical signature and calling convention as the coro.resume call.
static void addMustTailToCoroResumes(Function &F) {
  bool changed = false;

  // Collect potential resume instructions.
  SmallVector<CallInst *, 4> Resumes;
  for (auto &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (auto *CalledValue = Call->getCalledValue())
        // The CoroEarly pass replaced coro resumes with indirect calls to an
        // address returned by the CoroSubFnInst intrinsic. See if it is one
        // of those.
        if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts()))
          Resumes.push_back(Call);

  // Set musttail on those that are followed by a ret instruction.
  for (CallInst *Call : Resumes)
    if (simplifyTerminatorLeadingToRet(Call->getNextNode())) {
      Call->setTailCallKind(CallInst::TCK_MustTail);
      changed = true;
    }

  if (changed)
    removeUnreachableBlocks(F);
}
// Coroutine has no suspend points. Remove heap allocation for the coroutine
// frame if possible.
static void handleNoSuspendCoroutine(coro::Shape &Shape) {
  auto *CoroBegin = Shape.CoroBegin;
  auto *CoroId = CoroBegin->getId();
  auto *AllocInst = CoroId->getCoroAlloc();
  switch (Shape.ABI) {
  case coro::ABI::Switch: {
    auto SwitchId = cast<CoroIdInst>(CoroId);
    coro::replaceCoroFree(SwitchId, /*Elide=*/AllocInst != nullptr);
    if (AllocInst) {
      IRBuilder<> Builder(AllocInst);
      // FIXME: Need to handle overaligned members.
      auto *Frame = Builder.CreateAlloca(Shape.FrameTy);
      auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
      AllocInst->replaceAllUsesWith(Builder.getFalse());
      AllocInst->eraseFromParent();
      CoroBegin->replaceAllUsesWith(VFrame);
    } else {
      CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
    }
    break;
  }

  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    CoroBegin->replaceAllUsesWith(UndefValue::get(CoroBegin->getType()));
    break;
  }

  CoroBegin->eraseFromParent();
}
// SimplifySuspendPoint needs to check that there are no calls between
// coro_save and coro_suspend, since any of the calls may potentially resume
// the coroutine and if that is the case we cannot eliminate the suspend point.
static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) {
  for (Instruction *I = From; I != To; I = I->getNextNode()) {
    // Assume that no intrinsic can resume the coroutine.
    if (isa<IntrinsicInst>(I))
      continue;

    if (CallSite(I))
      return true;
  }
  return false;
}
static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) {
  SmallPtrSet<BasicBlock *, 8> Set;
  SmallVector<BasicBlock *, 8> Worklist;

  Set.insert(SaveBB);
  Worklist.push_back(ResDesBB);

  // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr
  // returns a token consumed by suspend instruction, all blocks in between
  // will have to eventually hit SaveBB when going backwards from ResDesBB.
  while (!Worklist.empty()) {
    auto *BB = Worklist.pop_back_val();
    Set.insert(BB);
    for (auto *Pred : predecessors(BB))
      if (Set.count(Pred) == 0)
        Worklist.push_back(Pred);
  }

  // SaveBB and ResDesBB are checked separately in hasCallsBetween.
  Set.erase(SaveBB);
  Set.erase(ResDesBB);

  for (auto *BB : Set)
    if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr))
      return true;

  return false;
}
static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) {
  auto *SaveBB = Save->getParent();
  auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent();

  if (SaveBB == ResumeOrDestroyBB)
    return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy);

  // Any calls from Save to the end of the block?
  if (hasCallsInBlockBetween(Save->getNextNode(), nullptr))
    return true;

  // Any calls from the beginning of the block up to ResumeOrDestroy?
  if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(),
                             ResumeOrDestroy))
    return true;

  // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB?
  if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB))
    return true;

  return false;
}
// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the
// suspend point and replace it with normal control flow.
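//
// For illustration (hypothetical IR reconstructed from the checks below):
//
//   %f = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) ; 0 selects resume
//   %c = bitcast i8* %f to void (i8*)*
//   call fastcc void %c(i8* %hdl)
//   %s = call i8 @llvm.coro.suspend(token %save, i1 false)
//
// Since the coroutine is resumed immediately after it suspends, the suspend
// is replaced by the subfn index (0 = resume path, 1 = destroy path) and the
// now-unneeded indirect call is deleted.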
static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
                                 CoroBeginInst *CoroBegin) {
  Instruction *Prev = Suspend->getPrevNode();
  if (!Prev) {
    auto *Pred = Suspend->getParent()->getSinglePredecessor();
    if (!Pred)
      return false;
    Prev = Pred->getTerminator();
  }

  CallSite CS{Prev};
  if (!CS)
    return false;

  auto *CallInstr = CS.getInstruction();

  auto *Callee = CS.getCalledValue()->stripPointerCasts();

  // See if the callsite is for resumption or destruction of the coroutine.
  auto *SubFn = dyn_cast<CoroSubFnInst>(Callee);
  if (!SubFn)
    return false;

  // Does not refer to the current coroutine, we cannot do anything with it.
  if (SubFn->getFrame() != CoroBegin)
    return false;

  // See if the transformation is safe. Specifically, see if there are any
  // calls in between Save and CallInstr. They can potentially resume the
  // coroutine rendering this optimization unsafe.
  auto *Save = Suspend->getCoroSave();
  if (hasCallsBetween(Save, CallInstr))
    return false;

  // Replace llvm.coro.suspend with the value that results in resumption over
  // the resume or cleanup path.
  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
  Suspend->eraseFromParent();
  Save->eraseFromParent();

  // No longer need a call to coro.resume or coro.destroy.
  if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) {
    BranchInst::Create(Invoke->getNormalDest(), Invoke);
  }

  // Grab the CalledValue from CS before erasing the CallInstr.
  auto *CalledValue = CS.getCalledValue();
  CallInstr->eraseFromParent();

  // If no more users remove it. Usually it is a bitcast of SubFn.
  if (CalledValue != SubFn && CalledValue->user_empty())
    if (auto *I = dyn_cast<Instruction>(CalledValue))
      I->eraseFromParent();

  // Now we are good to remove SubFn.
  if (SubFn->user_empty())
    SubFn->eraseFromParent();

  return true;
}
// Remove suspend points that are simplified.
static void simplifySuspendPoints(coro::Shape &Shape) {
  // Currently, the only simplification we do is switch-lowering-specific.
  if (Shape.ABI != coro::ABI::Switch)
    return;

  auto &S = Shape.CoroSuspends;
  size_t I = 0, N = S.size();
  if (N == 0)
    return;
  while (true) {
    if (simplifySuspendPoint(cast<CoroSuspendInst>(S[I]), Shape.CoroBegin)) {
      if (--N == I)
        break;
      std::swap(S[I], S[N]);
      continue;
    }
    if (++I == N)
      break;
  }
  S.resize(N);
}
static void splitSwitchCoroutine(Function &F, coro::Shape &Shape,
                                 SmallVectorImpl<Function *> &Clones) {
  assert(Shape.ABI == coro::ABI::Switch);

  createResumeEntryBlock(F, Shape);
  auto ResumeClone = createClone(F, ".resume", Shape,
                                 CoroCloner::Kind::SwitchResume);
  auto DestroyClone = createClone(F, ".destroy", Shape,
                                  CoroCloner::Kind::SwitchUnwind);
  auto CleanupClone = createClone(F, ".cleanup", Shape,
                                  CoroCloner::Kind::SwitchCleanup);

  postSplitCleanup(*ResumeClone);
  postSplitCleanup(*DestroyClone);
  postSplitCleanup(*CleanupClone);

  addMustTailToCoroResumes(*ResumeClone);

  // Store addresses of the resume/destroy/cleanup functions in the coroutine
  // frame.
  updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone);

  assert(Clones.empty());
  Clones.push_back(ResumeClone);
  Clones.push_back(DestroyClone);
  Clones.push_back(CleanupClone);

  // Create a constant array referring to the resume/destroy/cleanup functions
  // pointed to by the last argument of @llvm.coro.info, so that the CoroElide
  // pass can determine the correct function to call.
  setCoroInfo(F, Shape, Clones);
}
static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
                                 SmallVectorImpl<Function *> &Clones) {
  assert(Shape.ABI == coro::ABI::Retcon ||
         Shape.ABI == coro::ABI::RetconOnce);
  assert(Clones.empty());

  // Reset various things that the optimizer might have decided it
  // "knows" about the coroutine function due to not seeing a return.
  F.removeFnAttr(Attribute::NoReturn);
  F.removeAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
  F.removeAttribute(AttributeList::ReturnIndex, Attribute::NonNull);

  // Allocate the frame.
  auto *Id = cast<AnyCoroIdRetconInst>(Shape.CoroBegin->getId());
  Value *RawFramePtr;
  if (Shape.RetconLowering.IsFrameInlineInStorage) {
    RawFramePtr = Id->getStorage();
  } else {
    IRBuilder<> Builder(Id);

    // Determine the size of the frame.
    const DataLayout &DL = F.getParent()->getDataLayout();
    auto Size = DL.getTypeAllocSize(Shape.FrameTy);

    // Allocate. We don't need to update the call graph node because we're
    // going to recompute it from scratch after splitting.
    RawFramePtr = Shape.emitAlloc(Builder, Builder.getInt64(Size), nullptr);
    RawFramePtr =
        Builder.CreateBitCast(RawFramePtr, Shape.CoroBegin->getType());

    // Stash the allocated frame pointer in the continuation storage.
    auto Dest = Builder.CreateBitCast(Id->getStorage(),
                                      RawFramePtr->getType()->getPointerTo());
    Builder.CreateStore(RawFramePtr, Dest);
  }

  // Map all uses of llvm.coro.begin to the allocated frame pointer.
  {
    // Make sure we don't invalidate Shape.FramePtr.
    TrackingVH<Instruction> Handle(Shape.FramePtr);
    Shape.CoroBegin->replaceAllUsesWith(RawFramePtr);
    Shape.FramePtr = Handle.getValPtr();
  }

  // Create a unique return block.
  BasicBlock *ReturnBB = nullptr;
  SmallVector<PHINode *, 4> ReturnPHIs;

  // Create all the functions in order after the main function.
  auto NextF = std::next(F.getIterator());

  // Create a continuation function for each of the suspend points.
  Clones.reserve(Shape.CoroSuspends.size());
  for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
    auto Suspend = cast<CoroSuspendRetconInst>(Shape.CoroSuspends[i]);

    // Create the clone declaration.
    auto Continuation =
        createCloneDeclaration(F, Shape, ".resume." + Twine(i), NextF);
    Clones.push_back(Continuation);

    // Insert a branch to the unified return block immediately before
    // the suspend point.
    auto SuspendBB = Suspend->getParent();
    auto NewSuspendBB = SuspendBB->splitBasicBlock(Suspend);
    auto Branch = cast<BranchInst>(SuspendBB->getTerminator());

    // Create the unified return block.
    if (!ReturnBB) {
      // Place it before the first suspend.
      ReturnBB = BasicBlock::Create(F.getContext(), "coro.return", &F,
                                    NewSuspendBB);
      Shape.RetconLowering.ReturnBlock = ReturnBB;

      IRBuilder<> Builder(ReturnBB);

      // Create PHIs for all the return values.
      assert(ReturnPHIs.empty());

      // First, the continuation.
      ReturnPHIs.push_back(Builder.CreatePHI(Continuation->getType(),
                                             Shape.CoroSuspends.size()));

      // Next, all the directly-yielded values.
      for (auto ResultTy : Shape.getRetconResultTypes())
        ReturnPHIs.push_back(Builder.CreatePHI(ResultTy,
                                               Shape.CoroSuspends.size()));

      // Build the return value.
      auto RetTy = F.getReturnType();

      // Cast the continuation value if necessary.
      // We can't rely on the types matching up because that type would
      // have to be infinite.
      auto CastedContinuationTy =
          (ReturnPHIs.size() == 1 ? RetTy : RetTy->getStructElementType(0));
      auto *CastedContinuation =
          Builder.CreateBitCast(ReturnPHIs[0], CastedContinuationTy);

      Value *RetV;
      if (ReturnPHIs.size() == 1) {
        RetV = CastedContinuation;
      } else {
        RetV = UndefValue::get(RetTy);
        RetV = Builder.CreateInsertValue(RetV, CastedContinuation, 0);
        for (size_t I = 1, E = ReturnPHIs.size(); I != E; ++I)
          RetV = Builder.CreateInsertValue(RetV, ReturnPHIs[I], I);
      }

      Builder.CreateRet(RetV);
    }

    // Branch to the return block.
    Branch->setSuccessor(0, ReturnBB);
    ReturnPHIs[0]->addIncoming(Continuation, SuspendBB);
    size_t NextPHIIndex = 1;
    for (auto &VUse : Suspend->value_operands())
      ReturnPHIs[NextPHIIndex++]->addIncoming(&*VUse, SuspendBB);
    assert(NextPHIIndex == ReturnPHIs.size());
  }

  assert(Clones.size() == Shape.CoroSuspends.size());
  for (size_t i = 0, e = Shape.CoroSuspends.size(); i != e; ++i) {
    auto Suspend = Shape.CoroSuspends[i];
    auto Clone = Clones[i];

    CoroCloner(F, "resume." + Twine(i), Shape, Clone, Suspend).create();
  }
}
namespace {
class PrettyStackTraceFunction : public PrettyStackTraceEntry {
  Function &F;
public:
  PrettyStackTraceFunction(Function &F) : F(F) {}
  void print(raw_ostream &OS) const override {
    OS << "While splitting coroutine ";
    F.printAsOperand(OS, /*print type*/ false, F.getParent());
    OS << "\n";
  }
};
} // end anonymous namespace
static void splitCoroutine(Function &F, coro::Shape &Shape,
                           SmallVectorImpl<Function *> &Clones) {
  switch (Shape.ABI) {
  case coro::ABI::Switch:
    return splitSwitchCoroutine(F, Shape, Clones);
  case coro::ABI::Retcon:
  case coro::ABI::RetconOnce:
    return splitRetconCoroutine(F, Shape, Clones);
  }
  llvm_unreachable("bad ABI kind");
}
static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
  PrettyStackTraceFunction prettyStackTrace(F);

  // The suspend-crossing algorithm in buildCoroutineFrame gets tripped
  // up by uses in unreachable blocks, so remove them as a first pass.
  removeUnreachableBlocks(F);

  coro::Shape Shape(F);
  if (!Shape.CoroBegin)
    return;

  simplifySuspendPoints(Shape);
  buildCoroutineFrame(F, Shape);
  replaceFrameSize(Shape);

  SmallVector<Function*, 4> Clones;

  // If there are no suspend points, no split is required; just remove
  // the allocation and deallocation blocks, they are not needed.
  if (Shape.CoroSuspends.empty()) {
    handleNoSuspendCoroutine(Shape);
  } else {
    splitCoroutine(F, Shape, Clones);
  }

  // Replace all the swifterror operations in the original function.
  // This invalidates SwiftErrorOps in the Shape.
  replaceSwiftErrorOps(F, Shape, nullptr);

  removeCoroEnds(Shape, &CG);
  postSplitCleanup(F);

  // Update call graph and add the functions we created to the SCC.
  coro::updateCallGraph(F, Clones, CG, SCC);
}
// When we see the coroutine the first time, we insert an indirect call to a
// devirt trigger function and mark the coroutine as now ready for split.
static void prepareForSplit(Function &F, CallGraph &CG) {
  Module &M = *F.getParent();
  LLVMContext &Context = F.getContext();
#ifndef NDEBUG
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  assert(DevirtFn && "coro.devirt.trigger function not found");
#endif

  F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);

  // Insert an indirect call sequence that will be devirtualized by CoroElide
  // pass:
  //    %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1)
  //    %1 = bitcast i8* %0 to void(i8*)*
  //    call void %1(i8* null)
  coro::LowererBase Lowerer(M);
  Instruction *InsertPt = F.getEntryBlock().getTerminator();
  auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(Context));
  auto *DevirtFnAddr =
      Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt);
  FunctionType *FnTy = FunctionType::get(Type::getVoidTy(Context),
                                         {Type::getInt8PtrTy(Context)}, false);
  auto *IndirectCall = CallInst::Create(FnTy, DevirtFnAddr, Null, "", InsertPt);

  // Update the call graph with the indirect call we just added.
  CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode());
}
// Make sure that there is a devirtualization trigger function that the
// CoroSplit pass uses to force a restart of the CGSCC pipeline. If the devirt
// trigger function is not found, we will create one and add it to the current
// SCC.
static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) {
  Module &M = CG.getModule();
  if (M.getFunction(CORO_DEVIRT_TRIGGER_FN))
    return;

  LLVMContext &C = M.getContext();
  auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C),
                                 /*isVarArg=*/false);
  Function *DevirtFn =
      Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
                       CORO_DEVIRT_TRIGGER_FN, &M);
  DevirtFn->addFnAttr(Attribute::AlwaysInline);
  auto *Entry = BasicBlock::Create(C, "entry", DevirtFn);
  ReturnInst::Create(C, Entry);

  auto *Node = CG.getOrInsertFunction(DevirtFn);

  SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
  Nodes.push_back(Node);
  SCC.initialize(Nodes);
}
/// Replace a call to llvm.coro.prepare.retcon.
static void replacePrepare(CallInst *Prepare, CallGraph &CG) {
  auto CastFn = Prepare->getArgOperand(0); // as an i8*
  auto Fn = CastFn->stripPointerCasts(); // as its original type

  // Find call graph nodes for the preparation.
  CallGraphNode *PrepareUserNode = nullptr, *FnNode = nullptr;
  if (auto ConcreteFn = dyn_cast<Function>(Fn)) {
    PrepareUserNode = CG[Prepare->getFunction()];
    FnNode = CG[ConcreteFn];
  }

  // Attempt to peephole this pattern:
  //    %0 = bitcast [[TYPE]] @some_function to i8*
  //    %1 = call @llvm.coro.prepare.retcon(i8* %0)
  //    %2 = bitcast %1 to [[TYPE]]
  // ==>
  //    %2 = @some_function
  for (auto UI = Prepare->use_begin(), UE = Prepare->use_end();
         UI != UE; ) {
    // Look for bitcasts back to the original function type.
    auto *Cast = dyn_cast<BitCastInst>((UI++)->getUser());
    if (!Cast || Cast->getType() != Fn->getType()) continue;

    // Check whether the replacement will introduce new direct calls.
    // If so, we'll need to update the call graph.
    if (PrepareUserNode) {
      for (auto &Use : Cast->uses()) {
        if (auto *CB = dyn_cast<CallBase>(Use.getUser())) {
          if (!CB->isCallee(&Use))
            continue;
          PrepareUserNode->removeCallEdgeFor(*CB);
          PrepareUserNode->addCalledFunction(CB, FnNode);
        }
      }
    }

    // Replace and remove the cast.
    Cast->replaceAllUsesWith(Fn);
    Cast->eraseFromParent();
  }

  // Replace any remaining uses with the function as an i8*.
  // This can never directly be a callee, so we don't need to update CG.
  Prepare->replaceAllUsesWith(CastFn);
  Prepare->eraseFromParent();

  // Kill dead bitcasts.
  while (auto *Cast = dyn_cast<BitCastInst>(CastFn)) {
    if (!Cast->use_empty()) break;
    CastFn = Cast->getOperand(0);
    Cast->eraseFromParent();
  }
}
/// Remove calls to llvm.coro.prepare.retcon, a barrier meant to prevent
/// IPO from operating on calls to a retcon coroutine before it's been
/// split. This is only safe to do after we've split all retcon
/// coroutines in the module. We can do this in this pass because this
/// pass does promise to split all retcon coroutines (as opposed to
/// switch coroutines, which are lowered in multiple stages).
static bool replaceAllPrepares(Function *PrepareFn, CallGraph &CG) {
  bool Changed = false;
  for (auto PI = PrepareFn->use_begin(), PE = PrepareFn->use_end();
         PI != PE; ) {
    // Intrinsics can only be used in calls.
    auto *Prepare = cast<CallInst>((PI++)->getUser());
    replacePrepare(Prepare, CG);
    Changed = true;
  }

  return Changed;
}
//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {

struct CoroSplit : public CallGraphSCCPass {
  static char ID; // Pass identification, replacement for typeid

  CoroSplit() : CallGraphSCCPass(ID) {
    initializeCoroSplitPass(*PassRegistry::getPassRegistry());
  }

  bool Run = false;

  // A coroutine is identified by the presence of the coro.begin intrinsic; if
  // we don't have any, this pass has nothing to do.
  bool doInitialization(CallGraph &CG) override {
    Run = coro::declaresIntrinsics(CG.getModule(),
                                   {"llvm.coro.begin",
                                    "llvm.coro.prepare.retcon"});
    return CallGraphSCCPass::doInitialization(CG);
  }

  bool runOnSCC(CallGraphSCC &SCC) override {
    if (!Run)
      return false;

    // Check for uses of llvm.coro.prepare.retcon.
    auto PrepareFn =
        SCC.getCallGraph().getModule().getFunction("llvm.coro.prepare.retcon");
    if (PrepareFn && PrepareFn->use_empty())
      PrepareFn = nullptr;

    // Find coroutines for processing.
    SmallVector<Function *, 4> Coroutines;
    for (CallGraphNode *CGN : SCC)
      if (auto *F = CGN->getFunction())
        if (F->hasFnAttribute(CORO_PRESPLIT_ATTR))
          Coroutines.push_back(F);

    if (Coroutines.empty() && !PrepareFn)
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();

    if (Coroutines.empty())
      return replaceAllPrepares(PrepareFn, CG);

    createDevirtTriggerFunc(CG, SCC);

    // Split all the coroutines.
    for (Function *F : Coroutines) {
      Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR);
      StringRef Value = Attr.getValueAsString();
      LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName()
                        << "' state: " << Value << "\n");
      if (Value == UNPREPARED_FOR_SPLIT) {
        prepareForSplit(*F, CG);
        continue;
      }
      F->removeFnAttr(CORO_PRESPLIT_ATTR);
      splitCoroutine(*F, CG, SCC);
    }

    if (PrepareFn)
      replaceAllPrepares(PrepareFn, CG);

    return true;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return "Coroutine Splitting"; }
};

} // end anonymous namespace

char CoroSplit::ID = 0;

INITIALIZE_PASS_BEGIN(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(
    CoroSplit, "coro-split",
    "Split coroutine into a set of functions driving its state machine", false,
    false)

Pass *llvm::createCoroSplitPass() { return new CoroSplit(); }