[ARM] Rejig MVE load store tests. NFC
[llvm-core.git] / lib / Transforms / Coroutines / CoroEarly.cpp
blob692697d6f32e814bd88e97baa61d0d0a72071052
//===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass lowers coroutine intrinsics that hide the details of the exact
// calling convention for coroutine resume and destroy functions and details of
// the structure of the coroutine frame.
//===----------------------------------------------------------------------===//
13 #include "CoroInternal.h"
14 #include "llvm/IR/CallSite.h"
15 #include "llvm/IR/IRBuilder.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Pass.h"
20 using namespace llvm;
22 #define DEBUG_TYPE "coro-early"
24 namespace {
25 // Created on demand if CoroEarly pass has work to do.
26 class Lowerer : public coro::LowererBase {
27 IRBuilder<> Builder;
28 PointerType *const AnyResumeFnPtrTy;
29 Constant *NoopCoro = nullptr;
31 void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind);
32 void lowerCoroPromise(CoroPromiseInst *Intrin);
33 void lowerCoroDone(IntrinsicInst *II);
34 void lowerCoroNoop(IntrinsicInst *II);
36 public:
37 Lowerer(Module &M)
38 : LowererBase(M), Builder(Context),
39 AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr,
40 /*isVarArg=*/false)
41 ->getPointerTo()) {}
42 bool lowerEarlyIntrinsics(Function &F);
46 // Replace a direct call to coro.resume or coro.destroy with an indirect call to
47 // an address returned by coro.subfn.addr intrinsic. This is done so that
48 // CGPassManager recognizes devirtualization when CoroElide pass replaces a call
49 // to coro.subfn.addr with an appropriate function address.
50 void Lowerer::lowerResumeOrDestroy(CallSite CS,
51 CoroSubFnInst::ResumeKind Index) {
52 Value *ResumeAddr =
53 makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction());
54 CS.setCalledFunction(ResumeAddr);
55 CS.setCallingConv(CallingConv::Fast);
58 // Coroutine promise field is always at the fixed offset from the beginning of
59 // the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset
60 // to a passed pointer to move from coroutine frame to coroutine promise and
61 // vice versa. Since we don't know exactly which coroutine frame it is, we build
62 // a coroutine frame mock up starting with two function pointers, followed by a
63 // properly aligned coroutine promise field.
64 // TODO: Handle the case when coroutine promise alloca has align override.
65 void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) {
66 Value *Operand = Intrin->getArgOperand(0);
67 unsigned Alignement = Intrin->getAlignment();
68 Type *Int8Ty = Builder.getInt8Ty();
70 auto *SampleStruct =
71 StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty});
72 const DataLayout &DL = TheModule.getDataLayout();
73 int64_t Offset = alignTo(
74 DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement);
75 if (Intrin->isFromPromise())
76 Offset = -Offset;
78 Builder.SetInsertPoint(Intrin);
79 Value *Replacement =
80 Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset);
82 Intrin->replaceAllUsesWith(Replacement);
83 Intrin->eraseFromParent();
86 // When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in
87 // the coroutine frame (it is UB to resume from a final suspend point).
88 // The llvm.coro.done intrinsic is used to check whether a coroutine is
89 // suspended at the final suspend point or not.
90 void Lowerer::lowerCoroDone(IntrinsicInst *II) {
91 Value *Operand = II->getArgOperand(0);
93 // ResumeFnAddr is the first pointer sized element of the coroutine frame.
94 auto *FrameTy = Int8Ptr;
95 PointerType *FramePtrTy = FrameTy->getPointerTo();
97 Builder.SetInsertPoint(II);
98 auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy);
99 auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0);
100 auto *Load = Builder.CreateLoad(FrameTy, Gep);
101 auto *Cond = Builder.CreateICmpEQ(Load, NullPtr);
103 II->replaceAllUsesWith(Cond);
104 II->eraseFromParent();
107 void Lowerer::lowerCoroNoop(IntrinsicInst *II) {
108 if (!NoopCoro) {
109 LLVMContext &C = Builder.getContext();
110 Module &M = *II->getModule();
112 // Create a noop.frame struct type.
113 StructType *FrameTy = StructType::create(C, "NoopCoro.Frame");
114 auto *FramePtrTy = FrameTy->getPointerTo();
115 auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy,
116 /*isVarArg=*/false);
117 auto *FnPtrTy = FnTy->getPointerTo();
118 FrameTy->setBody({FnPtrTy, FnPtrTy});
120 // Create a Noop function that does nothing.
121 Function *NoopFn =
122 Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage,
123 "NoopCoro.ResumeDestroy", &M);
124 NoopFn->setCallingConv(CallingConv::Fast);
125 auto *Entry = BasicBlock::Create(C, "entry", NoopFn);
126 ReturnInst::Create(C, Entry);
128 // Create a constant struct for the frame.
129 Constant* Values[] = {NoopFn, NoopFn};
130 Constant* NoopCoroConst = ConstantStruct::get(FrameTy, Values);
131 NoopCoro = new GlobalVariable(M, NoopCoroConst->getType(), /*isConstant=*/true,
132 GlobalVariable::PrivateLinkage, NoopCoroConst,
133 "NoopCoro.Frame.Const");
136 Builder.SetInsertPoint(II);
137 auto *NoopCoroVoidPtr = Builder.CreateBitCast(NoopCoro, Int8Ptr);
138 II->replaceAllUsesWith(NoopCoroVoidPtr);
139 II->eraseFromParent();
142 // Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate,
143 // as CoroSplit assumes there is exactly one coro.begin. After CoroSplit,
144 // NoDuplicate attribute will be removed from coro.begin otherwise, it will
145 // interfere with inlining.
146 static void setCannotDuplicate(CoroIdInst *CoroId) {
147 for (User *U : CoroId->users())
148 if (auto *CB = dyn_cast<CoroBeginInst>(U))
149 CB->setCannotDuplicate();
152 bool Lowerer::lowerEarlyIntrinsics(Function &F) {
153 bool Changed = false;
154 CoroIdInst *CoroId = nullptr;
155 SmallVector<CoroFreeInst *, 4> CoroFrees;
156 for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) {
157 Instruction &I = *IB++;
158 if (auto CS = CallSite(&I)) {
159 switch (CS.getIntrinsicID()) {
160 default:
161 continue;
162 case Intrinsic::coro_free:
163 CoroFrees.push_back(cast<CoroFreeInst>(&I));
164 break;
165 case Intrinsic::coro_suspend:
166 // Make sure that final suspend point is not duplicated as CoroSplit
167 // pass expects that there is at most one final suspend point.
168 if (cast<CoroSuspendInst>(&I)->isFinal())
169 CS.setCannotDuplicate();
170 break;
171 case Intrinsic::coro_end:
172 // Make sure that fallthrough coro.end is not duplicated as CoroSplit
173 // pass expects that there is at most one fallthrough coro.end.
174 if (cast<CoroEndInst>(&I)->isFallthrough())
175 CS.setCannotDuplicate();
176 break;
177 case Intrinsic::coro_noop:
178 lowerCoroNoop(cast<IntrinsicInst>(&I));
179 break;
180 case Intrinsic::coro_id:
181 // Mark a function that comes out of the frontend that has a coro.id
182 // with a coroutine attribute.
183 if (auto *CII = cast<CoroIdInst>(&I)) {
184 if (CII->getInfo().isPreSplit()) {
185 F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT);
186 setCannotDuplicate(CII);
187 CII->setCoroutineSelf();
188 CoroId = cast<CoroIdInst>(&I);
191 break;
192 case Intrinsic::coro_resume:
193 lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex);
194 break;
195 case Intrinsic::coro_destroy:
196 lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex);
197 break;
198 case Intrinsic::coro_promise:
199 lowerCoroPromise(cast<CoroPromiseInst>(&I));
200 break;
201 case Intrinsic::coro_done:
202 lowerCoroDone(cast<IntrinsicInst>(&I));
203 break;
205 Changed = true;
208 // Make sure that all CoroFree reference the coro.id intrinsic.
209 // Token type is not exposed through coroutine C/C++ builtins to plain C, so
210 // we allow specifying none and fixing it up here.
211 if (CoroId)
212 for (CoroFreeInst *CF : CoroFrees)
213 CF->setArgOperand(0, CoroId);
214 return Changed;
217 //===----------------------------------------------------------------------===//
218 // Top Level Driver
219 //===----------------------------------------------------------------------===//
221 namespace {
223 struct CoroEarly : public FunctionPass {
224 static char ID; // Pass identification, replacement for typeid.
225 CoroEarly() : FunctionPass(ID) {
226 initializeCoroEarlyPass(*PassRegistry::getPassRegistry());
229 std::unique_ptr<Lowerer> L;
231 // This pass has work to do only if we find intrinsics we are going to lower
232 // in the module.
233 bool doInitialization(Module &M) override {
234 if (coro::declaresIntrinsics(
235 M, {"llvm.coro.id", "llvm.coro.destroy", "llvm.coro.done",
236 "llvm.coro.end", "llvm.coro.noop", "llvm.coro.free",
237 "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.suspend"}))
238 L = llvm::make_unique<Lowerer>(M);
239 return false;
242 bool runOnFunction(Function &F) override {
243 if (!L)
244 return false;
246 return L->lowerEarlyIntrinsics(F);
249 void getAnalysisUsage(AnalysisUsage &AU) const override {
250 AU.setPreservesCFG();
252 StringRef getPassName() const override {
253 return "Lower early coroutine intrinsics";
258 char CoroEarly::ID = 0;
259 INITIALIZE_PASS(CoroEarly, "coro-early", "Lower early coroutine intrinsics",
260 false, false)
262 Pass *llvm::createCoroEarlyPass() { return new CoroEarly(); }