[InstCombine] Signed saturation tests. NFC
[llvm-complete.git] / lib / Transforms / Coroutines / CoroEarly.cpp
blob55993d33ee4e06ed498abbdb432c92896622a4e5
1 //===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass lowers coroutine intrinsics that hide the details of the exact
9 // calling convention for coroutine resume and destroy functions and details of
10 // the structure of the coroutine frame.
11 //===----------------------------------------------------------------------===//
13 #include "CoroInternal.h"
14 #include "llvm/IR/CallSite.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Pass.h"
20 using namespace llvm;
22 #define DEBUG_TYPE "coro-early"
24 namespace {
25 // Created on demand if CoroEarly pass has work to do.
26 class Lowerer : public coro::LowererBase {
27 IRBuilder<> Builder;
28 PointerType *const AnyResumeFnPtrTy;
29 Constant *NoopCoro = nullptr;
31 void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
32 void lowerCoroPromise(CoroPromiseInst *Intrin);
33 void lowerCoroDone(IntrinsicInst *II);
34 void lowerCoroNoop(IntrinsicInst *II);
36 public:
37 Lowerer(Module &M)
38 : LowererBase(M), Builder(Context),
39 AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
40 /*isVarArg=*/false)
41 ->getPointerTo()) {}
42 bool lowerEarlyIntrinsics(Function &F);
46 // Replace a direct call to coro.resume or coro.destroy with an indirect call to
47 // an address returned by coro.subfn.addr intrinsic. This is done so that
48 // CGPassManager recognizes devirtualization when CoroElide pass replaces a call
49 // to coro.subfn.addr with an appropriate function address.
50 void Lowerer::lowerResumeOrDestroy(CallSite CS,
51 CoroSubFnInst::ResumeKind Index) {
52 Value *ResumeAddr =
53 makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction());
54 CS.setCalledFunction(ResumeAddr);
55 CS.setCallingConv(CallingConv::Fast);
58 // Coroutine promise field is always at the fixed offset from the beginning of
59 // the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset
60 // to a passed pointer to move from coroutine frame to coroutine promise and
61 // vice versa. Since we don't know exactly which coroutine frame it is, we build
62 // a coroutine frame mock up starting with two function pointers, followed by a
63 // properly aligned coroutine promise field.
64 // TODO: Handle the case when coroutine promise alloca has align override.
65 void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
66 Value *Operand = Intrin->getArgOperand(0);
67 unsigned Alignement = Intrin->getAlignment();
68 Type *Int8Ty = Builder.getInt8Ty();
70 auto *SampleStruct =
71 StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
72 const DataLayout &DL = TheModule.getDataLayout();
73 int64_t Offset = alignTo(
74 DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement);
75 if (Intrin->isFromPromise())
76 Offset = -Offset;
78 Builder.SetInsertPoint(Intrin);
79 Value *Replacement =
80 Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
82 Intrin->replaceAllUsesWith(Replacement);
83 Intrin->eraseFromParent();
86 // When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
87 // the coroutine frame (it is UB to resume from a final suspend point).
88 // The llvm.coro.done intrinsic is used to check whether a coroutine is
89 // suspended at the final suspend point or not.
90 void Lowerer::lowerCoroDone(IntrinsicInst *II) {
91 Value *Operand = II->getArgOperand(0);
93 // ResumeFnAddr is the first pointer sized element of the coroutine frame.
94 static_assert(coro::Shape::SwitchFieldIndex::Resume == 0,
95 "resume function not at offset zero");
96 auto *FrameTy = Int8Ptr;
97 PointerType *FramePtrTy = FrameTy->getPointerTo();
99 Builder.SetInsertPoint(II);
100 auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
101 auto *Load = Builder.CreateLoad(BCI);
102 auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
104 II->replaceAllUsesWith(Cond);
105 II->eraseFromParent();
108 void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
109 if (!NoopCoro) {
110 LLVMContext &C = Builder.getContext();
111 Module &M = *II->getModule();
113 // Create a noop.frame struct type.
114 StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
115 auto *FramePtrTy = FrameTy->getPointerTo();
116 auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
117 /*isVarArg=*/false);
118 auto *FnPtrTy = FnTy->getPointerTo();
119 FrameTy->setBody({FnPtrTy, FnPtrTy});
121 // Create a Noop function that does nothing.
122 Function *NoopFn =
123 Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
124 "NoopCoro.ResumeDestroy", &M);
125 NoopFn->setCallingConv(CallingConv::Fast);
126 auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
127 ReturnInst::Create(C, Entry);
129 // Create a constant struct for the frame.
130 Constant* Values[] = {NoopFn, NoopFn};
131 Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
132 NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
133 GlobalVariable::PrivateLinkage, NoopCoroConst,
134 "NoopCoro.Frame.Const");
137 Builder.SetInsertPoint(II);
138 auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
139 II->replaceAllUsesWith(NoopCoroVoidPtr);
140 II->eraseFromParent();
143 // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
144 // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
145 // NoDuplicate attribute will be removed from coro.begin otherwise, it will
146 // interfere with inlining.
147 static void setCannotDuplicate(CoroIdInst *CoroId) {
148 for (User *U : CoroId->users())
149 if (auto *CB = dyn_cast<CoroBeginInst>(U))
150 CB->setCannotDuplicate();
153 bool Lowerer::lowerEarlyIntrinsics(Function &F) {
154 bool Changed = false;
155 CoroIdInst *CoroId = nullptr;
156 SmallVector<CoroFreeInst *, 4> CoroFrees;
157 for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
158 Instruction &I = *IB++;
159 if (auto CS = CallSite(&I)) {
160 switch (CS.getIntrinsicID()) {
161 default:
162 continue;
163 case Intrinsic::coro_free:
164 CoroFrees.push_back(cast<CoroFreeInst>(&I));
165 break;
166 case Intrinsic::coro_suspend:
167 // Make sure that final suspend point is not duplicated as CoroSplit
168 // pass expects that there is at most one final suspend point.
169 if (cast<CoroSuspendInst>(&I)->isFinal())
170 CS.setCannotDuplicate();
171 break;
172 case Intrinsic::coro_end:
173 // Make sure that fallthrough coro.end is not duplicated as CoroSplit
174 // pass expects that there is at most one fallthrough coro.end.
175 if (cast<CoroEndInst>(&I)->isFallthrough())
176 CS.setCannotDuplicate();
177 break;
178 case Intrinsic::coro_noop:
179 lowerCoroNoop(cast<IntrinsicInst>(&I));
180 break;
181 case Intrinsic::coro_id:
182 // Mark a function that comes out of the frontend that has a coro.id
183 // with a coroutine attribute.
184 if (auto *CII = cast<CoroIdInst>(&I)) {
185 if (CII->getInfo().isPreSplit()) {
186 F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
187 setCannotDuplicate(CII);
188 CII->setCoroutineSelf();
189 CoroId = cast<CoroIdInst>(&I);
192 break;
193 case Intrinsic::coro_id_retcon:
194 case Intrinsic::coro_id_retcon_once:
195 F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT);
196 break;
197 case Intrinsic::coro_resume:
198 lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
199 break;
200 case Intrinsic::coro_destroy:
201 lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
202 break;
203 case Intrinsic::coro_promise:
204 lowerCoroPromise(cast<CoroPromiseInst>(&I));
205 break;
206 case Intrinsic::coro_done:
207 lowerCoroDone(cast<IntrinsicInst>(&I));
208 break;
210 Changed = true;
213 // Make sure that all CoroFree reference the coro.id intrinsic.
214 // Token type is not exposed through coroutine C/C++ builtins to plain C, so
215 // we allow specifying none and fixing it up here.
216 if (CoroId)
217 for (CoroFreeInst *CF : CoroFrees)
218 CF->setArgOperand(0, CoroId);
219 return Changed;
222 //===----------------------------------------------------------------------===//
223 // Top Level Driver
224 //===----------------------------------------------------------------------===//
226 namespace {
228 struct CoroEarly : public FunctionPass {
229 static char ID; // Pass identification, replacement for typeid.
230 CoroEarly() : FunctionPass(ID) {
231 initializeCoroEarlyPass(*PassRegistry::getPassRegistry());
234 std::unique_ptr<Lowerer> L;
236 // This pass has work to do only if we find intrinsics we are going to lower
237 // in the module.
238 bool doInitialization(Module &M) override {
239 if (coro::declaresIntrinsics(M, {"llvm.coro.id",
240 "llvm.coro.id.retcon",
241 "llvm.coro.id.retcon.once",
242 "llvm.coro.destroy",
243 "llvm.coro.done",
244 "llvm.coro.end",
245 "llvm.coro.noop",
246 "llvm.coro.free",
247 "llvm.coro.promise",
248 "llvm.coro.resume",
249 "llvm.coro.suspend"}))
250 L = std::make_unique<Lowerer>(M);
251 return false;
254 bool runOnFunction(Function &F) override {
255 if (!L)
256 return false;
258 return L->lowerEarlyIntrinsics(F);
261 void getAnalysisUsage(AnalysisUsage &AU) const override {
262 AU.setPreservesCFG();
264 StringRef getPassName() const override {
265 return "Lower early coroutine intrinsics";
270 char CoroEarly::ID = 0;
271 INITIALIZE_PASS(CoroEarly, "coro-early", "Lower early coroutine intrinsics",
272 false, false)
274 Pass *llvm::createCoroEarlyPass() { return new CoroEarly(); }