[AMDGPU] prevent shrinking udiv/urem if either operand is in (SignedMax,UnsignedMax...
[llvm-project.git] / llvm / lib / SandboxIR / Context.cpp
blobb86ed5864c1ac1de18964253d13ab91766f7b279
1 //===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/SandboxIR/Context.h"
10 #include "llvm/SandboxIR/Function.h"
11 #include "llvm/SandboxIR/Instruction.h"
12 #include "llvm/SandboxIR/Module.h"
14 namespace llvm::sandboxir {
16 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
17 std::unique_ptr<Value> Erased;
18 auto It = LLVMValueToValueMap.find(V);
19 if (It != LLVMValueToValueMap.end()) {
20 auto *Val = It->second.release();
21 Erased = std::unique_ptr<Value>(Val);
22 LLVMValueToValueMap.erase(It);
24 return Erased;
27 std::unique_ptr<Value> Context::detach(Value *V) {
28 assert(V->getSubclassID() != Value::ClassID::Constant &&
29 "Can't detach a constant!");
30 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
31 return detachLLVMValue(V->Val);
34 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
35 assert(VPtr->getSubclassID() != Value::ClassID::User &&
36 "Can't register a user!");
38 Value *V = VPtr.get();
39 [[maybe_unused]] auto Pair =
40 LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
41 assert(Pair.second && "Already exists!");
43 // Track creation of instructions.
44 // Please note that we don't allow the creation of detached instructions,
45 // meaning that the instructions need to be inserted into a block upon
46 // creation. This is why the tracker class combines creation and insertion.
47 if (auto *I = dyn_cast<Instruction>(V)) {
48 getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
49 runCreateInstrCallbacks(I);
52 return V;
55 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
56 auto Pair = LLVMValueToValueMap.insert({LLVMV, nullptr});
57 auto It = Pair.first;
58 if (!Pair.second)
59 return It->second.get();
61 if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
62 switch (C->getValueID()) {
63 case llvm::Value::ConstantIntVal:
64 It->second = std::unique_ptr<ConstantInt>(
65 new ConstantInt(cast<llvm::ConstantInt>(C), *this));
66 return It->second.get();
67 case llvm::Value::ConstantFPVal:
68 It->second = std::unique_ptr<ConstantFP>(
69 new ConstantFP(cast<llvm::ConstantFP>(C), *this));
70 return It->second.get();
71 case llvm::Value::BlockAddressVal:
72 It->second = std::unique_ptr<BlockAddress>(
73 new BlockAddress(cast<llvm::BlockAddress>(C), *this));
74 return It->second.get();
75 case llvm::Value::ConstantTokenNoneVal:
76 It->second = std::unique_ptr<ConstantTokenNone>(
77 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(C), *this));
78 return It->second.get();
79 case llvm::Value::ConstantAggregateZeroVal: {
80 auto *CAZ = cast<llvm::ConstantAggregateZero>(C);
81 It->second = std::unique_ptr<ConstantAggregateZero>(
82 new ConstantAggregateZero(CAZ, *this));
83 auto *Ret = It->second.get();
84 // Must create sandboxir for elements.
85 auto EC = CAZ->getElementCount();
86 if (EC.isFixed()) {
87 for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue()))
88 getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
90 return Ret;
92 case llvm::Value::ConstantPointerNullVal:
93 It->second = std::unique_ptr<ConstantPointerNull>(
94 new ConstantPointerNull(cast<llvm::ConstantPointerNull>(C), *this));
95 return It->second.get();
96 case llvm::Value::PoisonValueVal:
97 It->second = std::unique_ptr<PoisonValue>(
98 new PoisonValue(cast<llvm::PoisonValue>(C), *this));
99 return It->second.get();
100 case llvm::Value::UndefValueVal:
101 It->second = std::unique_ptr<UndefValue>(
102 new UndefValue(cast<llvm::UndefValue>(C), *this));
103 return It->second.get();
104 case llvm::Value::DSOLocalEquivalentVal: {
105 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(C);
106 It->second = std::unique_ptr<DSOLocalEquivalent>(
107 new DSOLocalEquivalent(DSOLE, *this));
108 auto *Ret = It->second.get();
109 getOrCreateValueInternal(DSOLE->getGlobalValue(), DSOLE);
110 return Ret;
112 case llvm::Value::ConstantArrayVal:
113 It->second = std::unique_ptr<ConstantArray>(
114 new ConstantArray(cast<llvm::ConstantArray>(C), *this));
115 break;
116 case llvm::Value::ConstantStructVal:
117 It->second = std::unique_ptr<ConstantStruct>(
118 new ConstantStruct(cast<llvm::ConstantStruct>(C), *this));
119 break;
120 case llvm::Value::ConstantVectorVal:
121 It->second = std::unique_ptr<ConstantVector>(
122 new ConstantVector(cast<llvm::ConstantVector>(C), *this));
123 break;
124 case llvm::Value::FunctionVal:
125 It->second = std::unique_ptr<Function>(
126 new Function(cast<llvm::Function>(C), *this));
127 break;
128 case llvm::Value::GlobalIFuncVal:
129 It->second = std::unique_ptr<GlobalIFunc>(
130 new GlobalIFunc(cast<llvm::GlobalIFunc>(C), *this));
131 break;
132 case llvm::Value::GlobalVariableVal:
133 It->second = std::unique_ptr<GlobalVariable>(
134 new GlobalVariable(cast<llvm::GlobalVariable>(C), *this));
135 break;
136 case llvm::Value::GlobalAliasVal:
137 It->second = std::unique_ptr<GlobalAlias>(
138 new GlobalAlias(cast<llvm::GlobalAlias>(C), *this));
139 break;
140 case llvm::Value::NoCFIValueVal:
141 It->second = std::unique_ptr<NoCFIValue>(
142 new NoCFIValue(cast<llvm::NoCFIValue>(C), *this));
143 break;
144 case llvm::Value::ConstantPtrAuthVal:
145 It->second = std::unique_ptr<ConstantPtrAuth>(
146 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(C), *this));
147 break;
148 case llvm::Value::ConstantExprVal:
149 It->second = std::unique_ptr<ConstantExpr>(
150 new ConstantExpr(cast<llvm::ConstantExpr>(C), *this));
151 break;
152 default:
153 It->second = std::unique_ptr<Constant>(new Constant(C, *this));
154 break;
156 auto *NewC = It->second.get();
157 for (llvm::Value *COp : C->operands())
158 getOrCreateValueInternal(COp, C);
159 return NewC;
161 if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
162 It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
163 return It->second.get();
165 if (auto *BB = dyn_cast<llvm::BasicBlock>(LLVMV)) {
166 assert(isa<llvm::BlockAddress>(U) &&
167 "This won't create a SBBB, don't call this function directly!");
168 if (auto *SBBB = getValue(BB))
169 return SBBB;
170 return nullptr;
172 assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
174 switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
175 case llvm::Instruction::VAArg: {
176 auto *LLVMVAArg = cast<llvm::VAArgInst>(LLVMV);
177 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
178 return It->second.get();
180 case llvm::Instruction::Freeze: {
181 auto *LLVMFreeze = cast<llvm::FreezeInst>(LLVMV);
182 It->second = std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
183 return It->second.get();
185 case llvm::Instruction::Fence: {
186 auto *LLVMFence = cast<llvm::FenceInst>(LLVMV);
187 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
188 return It->second.get();
190 case llvm::Instruction::Select: {
191 auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
192 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
193 return It->second.get();
195 case llvm::Instruction::ExtractElement: {
196 auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
197 It->second = std::unique_ptr<ExtractElementInst>(
198 new ExtractElementInst(LLVMIns, *this));
199 return It->second.get();
201 case llvm::Instruction::InsertElement: {
202 auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
203 It->second = std::unique_ptr<InsertElementInst>(
204 new InsertElementInst(LLVMIns, *this));
205 return It->second.get();
207 case llvm::Instruction::ShuffleVector: {
208 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV);
209 It->second = std::unique_ptr<ShuffleVectorInst>(
210 new ShuffleVectorInst(LLVMIns, *this));
211 return It->second.get();
213 case llvm::Instruction::ExtractValue: {
214 auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV);
215 It->second =
216 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(LLVMIns, *this));
217 return It->second.get();
219 case llvm::Instruction::InsertValue: {
220 auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
221 It->second =
222 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
223 return It->second.get();
225 case llvm::Instruction::Br: {
226 auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
227 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
228 return It->second.get();
230 case llvm::Instruction::Load: {
231 auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
232 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
233 return It->second.get();
235 case llvm::Instruction::Store: {
236 auto *LLVMSt = cast<llvm::StoreInst>(LLVMV);
237 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
238 return It->second.get();
240 case llvm::Instruction::Ret: {
241 auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
242 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
243 return It->second.get();
245 case llvm::Instruction::Call: {
246 auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
247 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
248 return It->second.get();
250 case llvm::Instruction::Invoke: {
251 auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
252 It->second = std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
253 return It->second.get();
255 case llvm::Instruction::CallBr: {
256 auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV);
257 It->second = std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
258 return It->second.get();
260 case llvm::Instruction::LandingPad: {
261 auto *LLVMLPad = cast<llvm::LandingPadInst>(LLVMV);
262 It->second =
263 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
264 return It->second.get();
266 case llvm::Instruction::CatchPad: {
267 auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV);
268 It->second =
269 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
270 return It->second.get();
272 case llvm::Instruction::CleanupPad: {
273 auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV);
274 It->second =
275 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
276 return It->second.get();
278 case llvm::Instruction::CatchRet: {
279 auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV);
280 It->second =
281 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
282 return It->second.get();
284 case llvm::Instruction::CleanupRet: {
285 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(LLVMV);
286 It->second = std::unique_ptr<CleanupReturnInst>(
287 new CleanupReturnInst(LLVMCRI, *this));
288 return It->second.get();
290 case llvm::Instruction::GetElementPtr: {
291 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
292 It->second = std::unique_ptr<GetElementPtrInst>(
293 new GetElementPtrInst(LLVMGEP, *this));
294 return It->second.get();
296 case llvm::Instruction::CatchSwitch: {
297 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV);
298 It->second = std::unique_ptr<CatchSwitchInst>(
299 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
300 return It->second.get();
302 case llvm::Instruction::Resume: {
303 auto *LLVMResumeInst = cast<llvm::ResumeInst>(LLVMV);
304 It->second =
305 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
306 return It->second.get();
308 case llvm::Instruction::Switch: {
309 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
310 It->second =
311 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
312 return It->second.get();
314 case llvm::Instruction::FNeg: {
315 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
316 It->second = std::unique_ptr<UnaryOperator>(
317 new UnaryOperator(LLVMUnaryOperator, *this));
318 return It->second.get();
320 case llvm::Instruction::Add:
321 case llvm::Instruction::FAdd:
322 case llvm::Instruction::Sub:
323 case llvm::Instruction::FSub:
324 case llvm::Instruction::Mul:
325 case llvm::Instruction::FMul:
326 case llvm::Instruction::UDiv:
327 case llvm::Instruction::SDiv:
328 case llvm::Instruction::FDiv:
329 case llvm::Instruction::URem:
330 case llvm::Instruction::SRem:
331 case llvm::Instruction::FRem:
332 case llvm::Instruction::Shl:
333 case llvm::Instruction::LShr:
334 case llvm::Instruction::AShr:
335 case llvm::Instruction::And:
336 case llvm::Instruction::Or:
337 case llvm::Instruction::Xor: {
338 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
339 It->second = std::unique_ptr<BinaryOperator>(
340 new BinaryOperator(LLVMBinaryOperator, *this));
341 return It->second.get();
343 case llvm::Instruction::AtomicRMW: {
344 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
345 It->second =
346 std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(LLVMAtomicRMW, *this));
347 return It->second.get();
349 case llvm::Instruction::AtomicCmpXchg: {
350 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
351 It->second = std::unique_ptr<AtomicCmpXchgInst>(
352 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
353 return It->second.get();
355 case llvm::Instruction::Alloca: {
356 auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV);
357 It->second = std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
358 return It->second.get();
360 case llvm::Instruction::ZExt:
361 case llvm::Instruction::SExt:
362 case llvm::Instruction::FPToUI:
363 case llvm::Instruction::FPToSI:
364 case llvm::Instruction::FPExt:
365 case llvm::Instruction::PtrToInt:
366 case llvm::Instruction::IntToPtr:
367 case llvm::Instruction::SIToFP:
368 case llvm::Instruction::UIToFP:
369 case llvm::Instruction::Trunc:
370 case llvm::Instruction::FPTrunc:
371 case llvm::Instruction::BitCast:
372 case llvm::Instruction::AddrSpaceCast: {
373 auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
374 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
375 return It->second.get();
377 case llvm::Instruction::PHI: {
378 auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
379 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
380 return It->second.get();
382 case llvm::Instruction::ICmp: {
383 auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
384 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
385 return It->second.get();
387 case llvm::Instruction::FCmp: {
388 auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
389 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
390 return It->second.get();
392 case llvm::Instruction::Unreachable: {
393 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
394 It->second = std::unique_ptr<UnreachableInst>(
395 new UnreachableInst(LLVMUnreachable, *this));
396 return It->second.get();
398 default:
399 break;
402 It->second = std::unique_ptr<OpaqueInst>(
403 new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
404 return It->second.get();
407 Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
408 auto Pair = LLVMValueToValueMap.insert({LLVMArg, nullptr});
409 auto It = Pair.first;
410 if (Pair.second) {
411 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
412 return cast<Argument>(It->second.get());
414 return cast<Argument>(It->second.get());
417 Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
418 return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
421 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
422 assert(getValue(LLVMBB) == nullptr && "Already exists!");
423 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
424 auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr)));
425 // Create SandboxIR for BB's body.
426 BB->buildBasicBlockFromLLVMIR(LLVMBB);
427 return BB;
430 VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
431 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
432 return cast<VAArgInst>(registerValue(std::move(NewPtr)));
435 FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
436 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
437 return cast<FreezeInst>(registerValue(std::move(NewPtr)));
440 FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
441 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
442 return cast<FenceInst>(registerValue(std::move(NewPtr)));
445 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
446 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
447 return cast<SelectInst>(registerValue(std::move(NewPtr)));
450 ExtractElementInst *
451 Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
452 auto NewPtr =
453 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
454 return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
457 InsertElementInst *
458 Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
459 auto NewPtr =
460 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
461 return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
464 ShuffleVectorInst *
465 Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
466 auto NewPtr =
467 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
468 return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
471 ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
472 auto NewPtr =
473 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
474 return cast<ExtractValueInst>(registerValue(std::move(NewPtr)));
477 InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
478 auto NewPtr =
479 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
480 return cast<InsertValueInst>(registerValue(std::move(NewPtr)));
483 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
484 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
485 return cast<BranchInst>(registerValue(std::move(NewPtr)));
488 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
489 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
490 return cast<LoadInst>(registerValue(std::move(NewPtr)));
493 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
494 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
495 return cast<StoreInst>(registerValue(std::move(NewPtr)));
498 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
499 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
500 return cast<ReturnInst>(registerValue(std::move(NewPtr)));
503 CallInst *Context::createCallInst(llvm::CallInst *I) {
504 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
505 return cast<CallInst>(registerValue(std::move(NewPtr)));
508 InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
509 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
510 return cast<InvokeInst>(registerValue(std::move(NewPtr)));
513 CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
514 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
515 return cast<CallBrInst>(registerValue(std::move(NewPtr)));
518 UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
519 auto NewPtr =
520 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
521 return cast<UnreachableInst>(registerValue(std::move(NewPtr)));
523 LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
524 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
525 return cast<LandingPadInst>(registerValue(std::move(NewPtr)));
527 CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
528 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
529 return cast<CatchPadInst>(registerValue(std::move(NewPtr)));
531 CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
532 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
533 return cast<CleanupPadInst>(registerValue(std::move(NewPtr)));
535 CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
536 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
537 return cast<CatchReturnInst>(registerValue(std::move(NewPtr)));
539 CleanupReturnInst *
540 Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
541 auto NewPtr =
542 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
543 return cast<CleanupReturnInst>(registerValue(std::move(NewPtr)));
545 GetElementPtrInst *
546 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
547 auto NewPtr =
548 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
549 return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
551 CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
552 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
553 return cast<CatchSwitchInst>(registerValue(std::move(NewPtr)));
555 ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
556 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
557 return cast<ResumeInst>(registerValue(std::move(NewPtr)));
559 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
560 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
561 return cast<SwitchInst>(registerValue(std::move(NewPtr)));
563 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
564 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
565 return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
567 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
568 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
569 return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
571 AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
572 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
573 return cast<AtomicRMWInst>(registerValue(std::move(NewPtr)));
575 AtomicCmpXchgInst *
576 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
577 auto NewPtr =
578 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
579 return cast<AtomicCmpXchgInst>(registerValue(std::move(NewPtr)));
581 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
582 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
583 return cast<AllocaInst>(registerValue(std::move(NewPtr)));
585 CastInst *Context::createCastInst(llvm::CastInst *I) {
586 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
587 return cast<CastInst>(registerValue(std::move(NewPtr)));
589 PHINode *Context::createPHINode(llvm::PHINode *I) {
590 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
591 return cast<PHINode>(registerValue(std::move(NewPtr)));
593 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
594 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
595 return cast<ICmpInst>(registerValue(std::move(NewPtr)));
597 FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
598 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
599 return cast<FCmpInst>(registerValue(std::move(NewPtr)));
601 Value *Context::getValue(llvm::Value *V) const {
602 auto It = LLVMValueToValueMap.find(V);
603 if (It != LLVMValueToValueMap.end())
604 return It->second.get();
605 return nullptr;
608 Context::Context(LLVMContext &LLVMCtx)
609 : LLVMCtx(LLVMCtx), IRTracker(*this),
610 LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
612 Context::~Context() {}
614 Module *Context::getModule(llvm::Module *LLVMM) const {
615 auto It = LLVMModuleToModuleMap.find(LLVMM);
616 if (It != LLVMModuleToModuleMap.end())
617 return It->second.get();
618 return nullptr;
621 Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
622 auto Pair = LLVMModuleToModuleMap.insert({LLVMM, nullptr});
623 auto It = Pair.first;
624 if (!Pair.second)
625 return It->second.get();
626 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
627 return It->second.get();
630 Function *Context::createFunction(llvm::Function *F) {
631 assert(getValue(F) == nullptr && "Already exists!");
632 // Create the module if needed before we create the new sandboxir::Function.
633 // Note: this won't fully populate the module. The only globals that will be
634 // available will be the ones being used within the function.
635 getOrCreateModule(F->getParent());
637 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
638 auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
639 // Create arguments.
640 for (auto &Arg : F->args())
641 getOrCreateArgument(&Arg);
642 // Create BBs.
643 for (auto &BB : *F)
644 createBasicBlock(&BB);
645 return SBF;
648 Module *Context::createModule(llvm::Module *LLVMM) {
649 auto *M = getOrCreateModule(LLVMM);
650 // Create the functions.
651 for (auto &LLVMF : *LLVMM)
652 createFunction(&LLVMF);
653 // Create globals.
654 for (auto &Global : LLVMM->globals())
655 getOrCreateValue(&Global);
656 // Create aliases.
657 for (auto &Alias : LLVMM->aliases())
658 getOrCreateValue(&Alias);
659 // Create ifuncs.
660 for (auto &IFunc : LLVMM->ifuncs())
661 getOrCreateValue(&IFunc);
663 return M;
666 void Context::runEraseInstrCallbacks(Instruction *I) {
667 for (const auto &CBEntry : EraseInstrCallbacks)
668 CBEntry.second(I);
671 void Context::runCreateInstrCallbacks(Instruction *I) {
672 for (auto &CBEntry : CreateInstrCallbacks)
673 CBEntry.second(I);
676 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
677 for (auto &CBEntry : MoveInstrCallbacks)
678 CBEntry.second(I, WhereIt);
// Arbitrary cap on the number of callbacks of each kind, used by assertions in
// the register*InstrCallback() functions to catch accidental misuse. Only a
// small number of simultaneously registered callbacks is expected; raise this
// if a legitimate use needs more.
[[maybe_unused]] static constexpr int MaxRegisteredCallbacks{16};
686 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
687 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
688 "EraseInstrCallbacks size limit exceeded");
689 CallbackID ID = NextCallbackID++;
690 EraseInstrCallbacks[ID] = CB;
691 return ID;
693 void Context::unregisterEraseInstrCallback(CallbackID ID) {
694 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
695 assert(Erased &&
696 "Callback ID not found in EraseInstrCallbacks during deregistration");
699 Context::CallbackID
700 Context::registerCreateInstrCallback(CreateInstrCallback CB) {
701 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
702 "CreateInstrCallbacks size limit exceeded");
703 CallbackID ID = NextCallbackID++;
704 CreateInstrCallbacks[ID] = CB;
705 return ID;
707 void Context::unregisterCreateInstrCallback(CallbackID ID) {
708 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
709 assert(Erased &&
710 "Callback ID not found in CreateInstrCallbacks during deregistration");
713 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
714 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
715 "MoveInstrCallbacks size limit exceeded");
716 CallbackID ID = NextCallbackID++;
717 MoveInstrCallbacks[ID] = CB;
718 return ID;
720 void Context::unregisterMoveInstrCallback(CallbackID ID) {
721 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
722 assert(Erased &&
723 "Callback ID not found in MoveInstrCallbacks during deregistration");
726 } // namespace llvm::sandboxir