[llvm-exegesis] [NFC] Fixing typo.
[llvm-complete.git] / lib / Transforms / Coroutines / Coroutines.cpp
blob7bd87fd29a630c51a913a10c2eaef2db9892b92f
1 //===- Coroutines.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the common infrastructure for Coroutine Passes.
11 //===----------------------------------------------------------------------===//
13 #include "llvm/Transforms/Coroutines.h"
14 #include "llvm-c/Transforms/Coroutines.h"
15 #include "CoroInstr.h"
16 #include "CoroInternal.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/CallGraph.h"
20 #include "llvm/Analysis/CallGraphSCCPass.h"
21 #include "llvm/Transforms/Utils/Local.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/CallSite.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/InstIterator.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/IntrinsicInst.h"
30 #include "llvm/IR/Intrinsics.h"
31 #include "llvm/IR/LegacyPassManager.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/Support/Casting.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Transforms/IPO.h"
37 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
38 #include <cassert>
39 #include <cstddef>
40 #include <utility>
42 using namespace llvm;
44 void llvm::initializeCoroutines(PassRegistry &Registry) {
45 initializeCoroEarlyPass(Registry);
46 initializeCoroSplitPass(Registry);
47 initializeCoroElidePass(Registry);
48 initializeCoroCleanupPass(Registry);
51 static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder,
52 legacy::PassManagerBase &PM) {
53 PM.add(createCoroSplitPass());
54 PM.add(createCoroElidePass());
56 PM.add(createBarrierNoopPass());
57 PM.add(createCoroCleanupPass());
60 static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder,
61 legacy::PassManagerBase &PM) {
62 PM.add(createCoroEarlyPass());
65 static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder,
66 legacy::PassManagerBase &PM) {
67 PM.add(createCoroElidePass());
70 static void addCoroutineSCCPasses(const PassManagerBuilder &Builder,
71 legacy::PassManagerBase &PM) {
72 PM.add(createCoroSplitPass());
75 static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder,
76 legacy::PassManagerBase &PM) {
77 PM.add(createCoroCleanupPass());
80 void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) {
81 Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible,
82 addCoroutineEarlyPasses);
83 Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
84 addCoroutineOpt0Passes);
85 Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate,
86 addCoroutineSCCPasses);
87 Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate,
88 addCoroutineScalarOptimizerPasses);
89 Builder.addExtension(PassManagerBuilder::EP_OptimizerLast,
90 addCoroutineOptimizerLastPasses);
93 // Construct the lowerer base class and initialize its members.
94 coro::LowererBase::LowererBase(Module &M)
95 : TheModule(M), Context(M.getContext()),
96 Int8Ptr(Type::getInt8PtrTy(Context)),
97 ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
98 /*isVarArg=*/false)),
99 NullPtr(ConstantPointerNull::get(Int8Ptr)) {}
101 // Creates a sequence of instructions to obtain a resume function address using
102 // llvm.coro.subfn.addr. It generates the following sequence:
104 // call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index)
105 // bitcast i8* %2 to void(i8*)*
107 Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index,
108 Instruction *InsertPt) {
109 auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index);
110 auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr);
112 assert(Index >= CoroSubFnInst::IndexFirst &&
113 Index < CoroSubFnInst::IndexLast &&
114 "makeSubFnCall: Index value out of range");
115 auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt);
117 auto *Bitcast =
118 new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt);
119 return Bitcast;
122 #ifndef NDEBUG
123 static bool isCoroutineIntrinsicName(StringRef Name) {
124 // NOTE: Must be sorted!
125 static const char *const CoroIntrinsics[] = {
126 "llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.destroy",
127 "llvm.coro.done", "llvm.coro.end", "llvm.coro.frame",
128 "llvm.coro.free", "llvm.coro.id", "llvm.coro.noop",
129 "llvm.coro.param", "llvm.coro.promise", "llvm.coro.resume",
130 "llvm.coro.save", "llvm.coro.size", "llvm.coro.subfn.addr",
131 "llvm.coro.suspend",
133 return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1;
135 #endif
137 // Verifies if a module has named values listed. Also, in debug mode verifies
138 // that names are intrinsic names.
139 bool coro::declaresIntrinsics(Module &M,
140 std::initializer_list<StringRef> List) {
141 for (StringRef Name : List) {
142 assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic");
143 if (M.getNamedValue(Name))
144 return true;
147 return false;
150 // Replace all coro.frees associated with the provided CoroId either with 'null'
151 // if Elide is true and with its frame parameter otherwise.
152 void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) {
153 SmallVector<CoroFreeInst *, 4> CoroFrees;
154 for (User *U : CoroId->users())
155 if (auto CF = dyn_cast<CoroFreeInst>(U))
156 CoroFrees.push_back(CF);
158 if (CoroFrees.empty())
159 return;
161 Value *Replacement =
162 Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext()))
163 : CoroFrees.front()->getFrame();
165 for (CoroFreeInst *CF : CoroFrees) {
166 CF->replaceAllUsesWith(Replacement);
167 CF->eraseFromParent();
171 // FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which
172 // happens to be private. It is better for this functionality exposed by the
173 // CallGraph.
174 static void buildCGN(CallGraph &CG, CallGraphNode *Node) {
175 Function *F = Node->getFunction();
177 // Look for calls by this function.
178 for (Instruction &I : instructions(F))
179 if (CallSite CS = CallSite(cast<Value>(&I))) {
180 const Function *Callee = CS.getCalledFunction();
181 if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID()))
182 // Indirect calls of intrinsics are not allowed so no need to check.
183 // We can be more precise here by using TargetArg returned by
184 // Intrinsic::isLeaf.
185 Node->addCalledFunction(CS, CG.getCallsExternalNode());
186 else if (!Callee->isIntrinsic())
187 Node->addCalledFunction(CS, CG.getOrInsertFunction(Callee));
191 // Rebuild CGN after we extracted parts of the code from ParentFunc into
192 // NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC.
193 void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs,
194 CallGraph &CG, CallGraphSCC &SCC) {
195 // Rebuild CGN from scratch for the ParentFunc
196 auto *ParentNode = CG[&ParentFunc];
197 ParentNode->removeAllCalledFunctions();
198 buildCGN(CG, ParentNode);
200 SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end());
202 for (Function *F : NewFuncs) {
203 CallGraphNode *Callee = CG.getOrInsertFunction(F);
204 Nodes.push_back(Callee);
205 buildCGN(CG, Callee);
208 SCC.initialize(Nodes);
211 static void clear(coro::Shape &Shape) {
212 Shape.CoroBegin = nullptr;
213 Shape.CoroEnds.clear();
214 Shape.CoroSizes.clear();
215 Shape.CoroSuspends.clear();
217 Shape.FrameTy = nullptr;
218 Shape.FramePtr = nullptr;
219 Shape.AllocaSpillBlock = nullptr;
220 Shape.ResumeSwitch = nullptr;
221 Shape.PromiseAlloca = nullptr;
222 Shape.HasFinalSuspend = false;
225 static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin,
226 CoroSuspendInst *SuspendInst) {
227 Module *M = SuspendInst->getModule();
228 auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save);
229 auto *SaveInst =
230 cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst));
231 assert(!SuspendInst->getCoroSave());
232 SuspendInst->setArgOperand(0, SaveInst);
233 return SaveInst;
236 // Collect "interesting" coroutine intrinsics.
237 void coro::Shape::buildFrom(Function &F) {
238 size_t FinalSuspendIndex = 0;
239 clear(*this);
240 SmallVector<CoroFrameInst *, 8> CoroFrames;
241 SmallVector<CoroSaveInst *, 2> UnusedCoroSaves;
243 for (Instruction &I : instructions(F)) {
244 if (auto II = dyn_cast<IntrinsicInst>(&I)) {
245 switch (II->getIntrinsicID()) {
246 default:
247 continue;
248 case Intrinsic::coro_size:
249 CoroSizes.push_back(cast<CoroSizeInst>(II));
250 break;
251 case Intrinsic::coro_frame:
252 CoroFrames.push_back(cast<CoroFrameInst>(II));
253 break;
254 case Intrinsic::coro_save:
255 // After optimizations, coro_suspends using this coro_save might have
256 // been removed, remember orphaned coro_saves to remove them later.
257 if (II->use_empty())
258 UnusedCoroSaves.push_back(cast<CoroSaveInst>(II));
259 break;
260 case Intrinsic::coro_suspend:
261 CoroSuspends.push_back(cast<CoroSuspendInst>(II));
262 if (CoroSuspends.back()->isFinal()) {
263 if (HasFinalSuspend)
264 report_fatal_error(
265 "Only one suspend point can be marked as final");
266 HasFinalSuspend = true;
267 FinalSuspendIndex = CoroSuspends.size() - 1;
269 break;
270 case Intrinsic::coro_begin: {
271 auto CB = cast<CoroBeginInst>(II);
272 if (CB->getId()->getInfo().isPreSplit()) {
273 if (CoroBegin)
274 report_fatal_error(
275 "coroutine should have exactly one defining @llvm.coro.begin");
276 CB->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
277 CB->addAttribute(AttributeList::ReturnIndex, Attribute::NoAlias);
278 CB->removeAttribute(AttributeList::FunctionIndex,
279 Attribute::NoDuplicate);
280 CoroBegin = CB;
282 break;
284 case Intrinsic::coro_end:
285 CoroEnds.push_back(cast<CoroEndInst>(II));
286 if (CoroEnds.back()->isFallthrough()) {
287 // Make sure that the fallthrough coro.end is the first element in the
288 // CoroEnds vector.
289 if (CoroEnds.size() > 1) {
290 if (CoroEnds.front()->isFallthrough())
291 report_fatal_error(
292 "Only one coro.end can be marked as fallthrough");
293 std::swap(CoroEnds.front(), CoroEnds.back());
296 break;
301 // If for some reason, we were not able to find coro.begin, bailout.
302 if (!CoroBegin) {
303 // Replace coro.frame which are supposed to be lowered to the result of
304 // coro.begin with undef.
305 auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext()));
306 for (CoroFrameInst *CF : CoroFrames) {
307 CF->replaceAllUsesWith(Undef);
308 CF->eraseFromParent();
311 // Replace all coro.suspend with undef and remove related coro.saves if
312 // present.
313 for (CoroSuspendInst *CS : CoroSuspends) {
314 CS->replaceAllUsesWith(UndefValue::get(CS->getType()));
315 CS->eraseFromParent();
316 if (auto *CoroSave = CS->getCoroSave())
317 CoroSave->eraseFromParent();
320 // Replace all coro.ends with unreachable instruction.
321 for (CoroEndInst *CE : CoroEnds)
322 changeToUnreachable(CE, /*UseLLVMTrap=*/false);
324 return;
327 // The coro.free intrinsic is always lowered to the result of coro.begin.
328 for (CoroFrameInst *CF : CoroFrames) {
329 CF->replaceAllUsesWith(CoroBegin);
330 CF->eraseFromParent();
333 // Canonicalize coro.suspend by inserting a coro.save if needed.
334 for (CoroSuspendInst *CS : CoroSuspends)
335 if (!CS->getCoroSave())
336 createCoroSave(CoroBegin, CS);
338 // Move final suspend to be the last element in the CoroSuspends vector.
339 if (HasFinalSuspend &&
340 FinalSuspendIndex != CoroSuspends.size() - 1)
341 std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back());
343 // Remove orphaned coro.saves.
344 for (CoroSaveInst *CoroSave : UnusedCoroSaves)
345 CoroSave->eraseFromParent();
348 void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) {
349 unwrap(PM)->add(createCoroEarlyPass());
352 void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) {
353 unwrap(PM)->add(createCoroSplitPass());
356 void LLVMAddCoroElidePass(LLVMPassManagerRef PM) {
357 unwrap(PM)->add(createCoroElidePass());
360 void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) {
361 unwrap(PM)->add(createCoroCleanupPass());