//===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass replaces dynamic allocation of coroutine frame with alloca and
// replaces calls to llvm.coro.resume and llvm.coro.destroy with direct calls
// to coroutine sub-functions.
//===----------------------------------------------------------------------===//
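
// For example (illustrative IR; the coroutine sub-function names @f.resume,
// @f.destroy and @f.cleanup are hypothetical), after CoroSplit a destroy call
// site is an indirect call through llvm.coro.subfn.addr:
//
//   %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
//   %fn = bitcast i8* %addr to void (i8*)*
//   call void %fn(i8* %hdl)
//
// This pass turns such calls into direct calls to @f.destroy (or @f.cleanup
// when the frame allocation is elided) and, when it is safe, replaces the
// heap-allocated coroutine frame with an alloca in the caller.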

#include "CoroInternal.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"

using namespace llvm;

#define DEBUG_TYPE "coro-elide"

namespace {
// Created on demand if CoroElide pass has work to do.
struct Lowerer : coro::LowererBase {
  SmallVector<CoroIdInst *, 4> CoroIds;
  SmallVector<CoroBeginInst *, 1> CoroBegins;
  SmallVector<CoroAllocInst *, 1> CoroAllocs;
  SmallVector<CoroSubFnInst *, 4> ResumeAddr;
  SmallVector<CoroSubFnInst *, 4> DestroyAddr;
  SmallVector<CoroFreeInst *, 1> CoroFrees;

  Lowerer(Module &M) : LowererBase(M) {}

  void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA);
  bool shouldElide(Function *F, DominatorTree &DT) const;
  bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT);
};
} // end anonymous namespace

// Go through the list of coro.subfn.addr intrinsics and replace them with the
// provided constant.
static void replaceWithConstant(Constant *Value,
                                SmallVectorImpl<CoroSubFnInst *> &Users) {
  if (Users.empty())
    return;

  // See if we need to bitcast the constant to match the type of the intrinsic
  // being replaced. Note: All coro.subfn.addr intrinsics return the same type,
  // so we only need to examine the type of the first one in the list.
  Type *IntrTy = Users.front()->getType();
  Type *ValueTy = Value->getType();
  if (ValueTy != IntrTy) {
    // May need to tweak the function type to match the type expected at the
    // use site.
    assert(ValueTy->isPointerTy() && IntrTy->isPointerTy());
    Value = ConstantExpr::getBitCast(Value, IntrTy);
  }

  // Now the value type matches the type of the intrinsic. Replace them all!
  for (CoroSubFnInst *I : Users)
    replaceAndRecursivelySimplify(I, Value);
}

// See if any operand of the call instruction references the coroutine frame.
static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) {
  for (Value *Op : CI->operand_values())
    if (AA.alias(Op, Frame) != NoAlias)
      return true;
  return false;
}

// Look for any tail calls referencing the coroutine frame and remove the tail
// attribute from them, since the coroutine frame now resides on the stack and
// a tail call implies that the function does not reference anything on the
// stack of its caller.
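//
// For example (illustrative IR; the frame type and callee are hypothetical):
//   %p = getelementptr inbounds %f.Frame, %f.Frame* %frame, i32 0, i32 2
//   tail call void @consume(i32* %p)   ; the 'tail' marker must be dropped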
static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) {
  Function &F = *Frame->getFunction();
  for (Instruction &I : instructions(F))
    if (auto *Call = dyn_cast<CallInst>(&I))
      if (Call->isTailCall() && operandReferences(Call, Frame, AA)) {
        // FIXME: If we ever hit this check, evaluate whether it is more
        // appropriate to retain musttail and allow the code to compile.
        if (Call->isMustTailCall())
          report_fatal_error("Call referring to the coroutine frame cannot be "
                             "marked as musttail");
        Call->setTailCall(false);
      }
}

// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type.
static Type *getFrameType(Function *Resume) {
  auto *ArgType = Resume->arg_begin()->getType();
  return cast<PointerType>(ArgType)->getElementType();
}

// Finds the first non-alloca instruction in the entry block of a function.
static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) {
  for (Instruction &I : F->getEntryBlock())
    if (!isa<AllocaInst>(&I))
      return &I;
  llvm_unreachable("no terminator in the entry block");
}

// To elide heap allocations we need to suppress code blocks guarded by
// llvm.coro.alloc and llvm.coro.free instructions.
void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) {
  LLVMContext &C = FrameTy->getContext();
  auto *InsertPt =
      getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction());

  // Replacing llvm.coro.alloc with false will suppress dynamic allocation, as
  // the frontend is expected to generate code of the form:
  //   mem = coro.alloc(id) ? malloc(coro.size()) : 0;
  //   coro.begin(id, mem)
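  //
  // With coro.alloc folded to false, the allocation arm of that conditional
  // becomes dead and is expected to be cleaned up by later simplification
  // passes (illustrative):
  //   mem = false ? malloc(coro.size()) : 0;   // malloc is never executed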
  auto *False = ConstantInt::getFalse(C);
  for (auto *CA : CoroAllocs) {
    CA->replaceAllUsesWith(False);
    CA->eraseFromParent();
  }

  // FIXME: Design how to transmit alignment information for every alloca that
  // is spilled into the coroutine frame and recreate the alignment information
  // here. Possibly we will need to do a mini SROA here and break the coroutine
  // frame into individual AllocaInst recreating the original alignment.
  const DataLayout &DL = F->getParent()->getDataLayout();
  auto *Frame = new AllocaInst(FrameTy, DL.getAllocaAddrSpace(), "", InsertPt);
  auto *FrameVoidPtr =
      new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt);

  for (auto *CB : CoroBegins) {
    CB->replaceAllUsesWith(FrameVoidPtr);
    CB->eraseFromParent();
  }

  // Since the coroutine frame now lives on the stack, we need to make sure
  // that any tail call referencing it is made a non-tail call.
  removeTailCallAttribute(Frame, AA);
}

bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const {
  // If there are no CoroAllocs, we cannot suppress allocation, so elision is
  // not possible.
  if (CoroAllocs.empty())
    return false;

  // Check that for every coro.begin there is a coro.destroy directly
  // referencing the SSA value of that coro.begin along a non-exceptional path.
  // If the value escaped, then coro.destroy would have been referencing a
  // memory location storing that value and not the virtual register.
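  //
  // For example (illustrative IR; value names are hypothetical):
  //   elidable:     %addr = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
  //                 (where %hdl is the result of coro.begin)
  //   not elidable: %h = load i8*, i8** %slot
  //                 %addr = call i8* @llvm.coro.subfn.addr(i8* %h, i8 1)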

  // First gather all of the non-exceptional terminators for the function.
  SmallPtrSet<Instruction *, 8> Terminators;
  for (BasicBlock &B : *F) {
    auto *TI = B.getTerminator();
    if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
        !isa<UnreachableInst>(TI))
      Terminators.insert(TI);
  }

  // Filter out the coro.destroy calls that lie along exceptional paths.
  SmallPtrSet<CoroSubFnInst *, 4> DAs;
  for (CoroSubFnInst *DA : DestroyAddr) {
    for (Instruction *TI : Terminators) {
      if (DT.dominates(DA, TI)) {
        DAs.insert(DA);
        break;
      }
    }
  }

  // Find all the coro.begin referenced by coro.destroy along happy paths.
  SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins;
  for (CoroSubFnInst *DA : DAs) {
    if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame()))
      ReferencedCoroBegins.insert(CB);
    else
      return false;
  }

  // If the size of the set is the same as the total number of coro.begin, that
  // means we found a coro.free or coro.destroy referencing each coro.begin, so
  // we can perform heap elision.
  return ReferencedCoroBegins.size() == CoroBegins.size();
}

bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
                            DominatorTree &DT) {
  CoroBegins.clear();
  CoroAllocs.clear();
  CoroFrees.clear();
  ResumeAddr.clear();
  DestroyAddr.clear();

  // Collect all coro.begin and coro.allocs associated with this coro.id.
  for (User *U : CoroId->users()) {
    if (auto *CB = dyn_cast<CoroBeginInst>(U))
      CoroBegins.push_back(CB);
    else if (auto *CA = dyn_cast<CoroAllocInst>(U))
      CoroAllocs.push_back(CA);
    else if (auto *CF = dyn_cast<CoroFreeInst>(U))
      CoroFrees.push_back(CF);
  }

  // Collect all coro.subfn.addrs associated with coro.begin.
  // Note: we only devirtualize the calls if their coro.subfn.addr refers to
  // coro.begin directly. If we run into cases where this check is too
  // conservative, we can consider relaxing the check.
  for (CoroBeginInst *CB : CoroBegins) {
    for (User *U : CB->users())
      if (auto *II = dyn_cast<CoroSubFnInst>(U))
        switch (II->getIndex()) {
        case CoroSubFnInst::ResumeIndex:
          ResumeAddr.push_back(II);
          break;
        case CoroSubFnInst::DestroyIndex:
          DestroyAddr.push_back(II);
          break;
        default:
          llvm_unreachable("unexpected coro.subfn.addr constant");
        }
  }

  // PostSplit coro.id refers to an array of subfunctions in its Info
  // argument.
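  //
  // For example (illustrative IR; the array and function names are
  // hypothetical), the Info argument typically points at something like:
  //   @f.resumers = private constant [3 x void (%f.Frame*)*]
  //       [void (%f.Frame*)* @f.resume, void (%f.Frame*)* @f.destroy,
  //        void (%f.Frame*)* @f.cleanup]
  // indexed by ResumeIndex, DestroyIndex and CleanupIndex.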
  ConstantArray *Resumers = CoroId->getInfo().Resumers;
  assert(Resumers && "PostSplit coro.id Info argument must refer to an array "
                     "of coroutine subfunctions");
  auto *ResumeAddrConstant =
      ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex);

  replaceWithConstant(ResumeAddrConstant, ResumeAddr);

  bool ShouldElide = shouldElide(CoroId->getFunction(), DT);

  auto *DestroyAddrConstant = ConstantExpr::getExtractValue(
      Resumers,
      ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);

  replaceWithConstant(DestroyAddrConstant, DestroyAddr);

  if (ShouldElide) {
    auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant));
    elideHeapAllocations(CoroId->getFunction(), FrameTy, AA);
    coro::replaceCoroFree(CoroId, /*Elide=*/true);
  }

  return true;
}

// See if there are any coro.subfn.addr instructions referring to the
// coro.devirt trigger and, if so, replace them with a direct call to the
// devirt trigger function.
static bool replaceDevirtTrigger(Function &F) {
  SmallVector<CoroSubFnInst *, 1> DevirtAddr;
  for (auto &I : instructions(F))
    if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I))
      if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger)
        DevirtAddr.push_back(SubFn);

  if (DevirtAddr.empty())
    return false;

  Module &M = *F.getParent();
  Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN);
  assert(DevirtFn && "coro.devirt.fn not found");
  replaceWithConstant(DevirtFn, DevirtAddr);

  return true;
}

//===----------------------------------------------------------------------===//
//                              Top Level Driver
//===----------------------------------------------------------------------===//

namespace {
struct CoroElide : FunctionPass {
  static char ID; // Pass identification, replacement for typeid
  CoroElide() : FunctionPass(ID) {
    initializeCoroElidePass(*PassRegistry::getPassRegistry());
  }

  std::unique_ptr<Lowerer> L;

  bool doInitialization(Module &M) override {
    if (coro::declaresIntrinsics(M, {"llvm.coro.id"}))
      L = std::make_unique<Lowerer>(M);
    return false;
  }

  bool runOnFunction(Function &F) override {
    if (!L)
      return false;

    bool Changed = false;

    if (F.hasFnAttribute(CORO_PRESPLIT_ATTR))
      Changed = replaceDevirtTrigger(F);

    // Collect all PostSplit coro.ids.
    L->CoroIds.clear();
    for (auto &I : instructions(F))
      if (auto *CII = dyn_cast<CoroIdInst>(&I))
        if (CII->getInfo().isPostSplit())
          // If it is the coroutine itself, don't touch it.
          if (CII->getCoroutine() != CII->getFunction())
            L->CoroIds.push_back(CII);

    // If we did not find any coro.id, there is nothing to do.
    if (L->CoroIds.empty())
      return Changed;

    AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
    DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

    for (auto *CII : L->CoroIds)
      Changed |= L->processCoroId(CII, AA, DT);

    return Changed;
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
  }

  StringRef getPassName() const override { return "Coroutine Elision"; }
};
} // end anonymous namespace

char CoroElide::ID = 0;

INITIALIZE_PASS_BEGIN(
    CoroElide, "coro-elide",
    "Coroutine frame allocation elision and indirect calls replacement", false,
    false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(
    CoroElide, "coro-elide",
    "Coroutine frame allocation elision and indirect calls replacement", false,
    false)

Pass *llvm::createCoroElidePass() { return new CoroElide(); }