[clang] Add test for CWG190 "Layout-compatible POD-struct types" (#121668)
[llvm-project.git] / llvm / lib / SandboxIR / Tracker.cpp
blob 27ed37aa9bdd37ea14406637b7559c93895c04c5
1 //===- Tracker.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "llvm/SandboxIR/Tracker.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/Instruction.h"
13 #include "llvm/IR/Module.h"
14 #include "llvm/IR/StructuralHash.h"
15 #include "llvm/SandboxIR/Instruction.h"
16 #include <sstream>
18 using namespace llvm::sandboxir;
20 #ifndef NDEBUG
22 std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const {
23 std::string Result;
24 raw_string_ostream SS(Result);
25 F.print(SS, /*AssemblyAnnotationWriter=*/nullptr);
26 return Result;
29 IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const {
30 ContextSnapshot Result;
31 for (const auto &Entry : Ctx.LLVMModuleToModuleMap)
32 for (const auto &F : *Entry.first) {
33 FunctionSnapshot Snapshot;
34 Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true);
35 Snapshot.TextualIR = dumpIR(F);
36 Result[&F] = Snapshot;
38 return Result;
41 bool IRSnapshotChecker::diff(const ContextSnapshot &Orig,
42 const ContextSnapshot &Curr) const {
43 bool DifferenceFound = false;
44 for (const auto &[F, OrigFS] : Orig) {
45 auto CurrFSIt = Curr.find(F);
46 if (CurrFSIt == Curr.end()) {
47 DifferenceFound = true;
48 dbgs() << "Function " << F->getName() << " not found in current IR.\n";
49 dbgs() << OrigFS.TextualIR << "\n";
50 continue;
52 const FunctionSnapshot &CurrFS = CurrFSIt->second;
53 if (OrigFS.Hash != CurrFS.Hash) {
54 DifferenceFound = true;
55 dbgs() << "Found IR difference in Function " << F->getName() << "\n";
56 dbgs() << "Original:\n" << OrigFS.TextualIR << "\n";
57 dbgs() << "Current:\n" << CurrFS.TextualIR << "\n";
60 // Check that Curr doesn't contain any new functions.
61 for (const auto &[F, CurrFS] : Curr) {
62 if (!Orig.contains(F)) {
63 DifferenceFound = true;
64 dbgs() << "Function " << F->getName()
65 << " found in current IR but not in original snapshot.\n";
66 dbgs() << CurrFS.TextualIR << "\n";
69 return DifferenceFound;
72 void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); }
74 void IRSnapshotChecker::expectNoDiff() {
75 ContextSnapshot CurrContextSnapshot = takeSnapshot();
76 if (diff(OrigContextSnapshot, CurrContextSnapshot)) {
77 llvm_unreachable(
78 "Original and current IR differ! Probably a checkpointing bug.");
82 void UseSet::dump() const {
83 dump(dbgs());
84 dbgs() << "\n";
87 void UseSwap::dump() const {
88 dump(dbgs());
89 dbgs() << "\n";
91 #endif // NDEBUG
93 PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx)
94 : PHI(PHI), RemovedIdx(RemovedIdx) {
95 RemovedV = PHI->getIncomingValue(RemovedIdx);
96 RemovedBB = PHI->getIncomingBlock(RemovedIdx);
99 void PHIRemoveIncoming::revert(Tracker &Tracker) {
100 // Special case: if the PHI is now empty, as we don't need to care about the
101 // order of the incoming values.
102 unsigned NumIncoming = PHI->getNumIncomingValues();
103 if (NumIncoming == 0) {
104 PHI->addIncoming(RemovedV, RemovedBB);
105 return;
107 // Shift all incoming values by one starting from the end until `Idx`.
108 // Start by adding a copy of the last incoming values.
109 unsigned LastIdx = NumIncoming - 1;
110 PHI->addIncoming(PHI->getIncomingValue(LastIdx),
111 PHI->getIncomingBlock(LastIdx));
112 for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) {
113 auto *PrevV = PHI->getIncomingValue(Idx - 1);
114 auto *PrevBB = PHI->getIncomingBlock(Idx - 1);
115 PHI->setIncomingValue(Idx, PrevV);
116 PHI->setIncomingBlock(Idx, PrevBB);
118 PHI->setIncomingValue(RemovedIdx, RemovedV);
119 PHI->setIncomingBlock(RemovedIdx, RemovedBB);
122 #ifndef NDEBUG
123 void PHIRemoveIncoming::dump() const {
124 dump(dbgs());
125 dbgs() << "\n";
127 #endif // NDEBUG
129 PHIAddIncoming::PHIAddIncoming(PHINode *PHI)
130 : PHI(PHI), Idx(PHI->getNumIncomingValues()) {}
132 void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); }
134 #ifndef NDEBUG
135 void PHIAddIncoming::dump() const {
136 dump(dbgs());
137 dbgs() << "\n";
139 #endif // NDEBUG
141 Tracker::~Tracker() {
142 assert(Changes.empty() && "You must accept or revert changes!");
145 EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr)
146 : ErasedIPtr(std::move(ErasedIPtr)) {
147 auto *I = cast<Instruction>(this->ErasedIPtr.get());
148 auto LLVMInstrs = I->getLLVMInstrs();
149 // Iterate in reverse program order.
150 for (auto *LLVMI : reverse(LLVMInstrs)) {
151 SmallVector<llvm::Value *> Operands;
152 Operands.reserve(LLVMI->getNumOperands());
153 for (auto [OpNum, Use] : enumerate(LLVMI->operands()))
154 Operands.push_back(Use.get());
155 InstrData.push_back({Operands, LLVMI});
157 assert(is_sorted(InstrData,
158 [](const auto &D0, const auto &D1) {
159 return D0.LLVMI->comesBefore(D1.LLVMI);
160 }) &&
161 "Expected reverse program order!");
162 auto *BotLLVMI = cast<llvm::Instruction>(I->Val);
163 if (BotLLVMI->getNextNode() != nullptr)
164 NextLLVMIOrBB = BotLLVMI->getNextNode();
165 else
166 NextLLVMIOrBB = BotLLVMI->getParent();
169 void EraseFromParent::accept() {
170 for (const auto &IData : InstrData)
171 IData.LLVMI->deleteValue();
174 void EraseFromParent::revert(Tracker &Tracker) {
175 // Place the bottom-most instruction first.
176 auto [Operands, BotLLVMI] = InstrData[0];
177 if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) {
178 BotLLVMI->insertBefore(NextLLVMI);
179 } else {
180 auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB);
181 BotLLVMI->insertInto(LLVMBB, LLVMBB->end());
183 for (auto [OpNum, Op] : enumerate(Operands))
184 BotLLVMI->setOperand(OpNum, Op);
186 // Go over the rest of the instructions and stack them on top.
187 for (auto [Operands, LLVMI] : drop_begin(InstrData)) {
188 LLVMI->insertBefore(BotLLVMI);
189 for (auto [OpNum, Op] : enumerate(Operands))
190 LLVMI->setOperand(OpNum, Op);
191 BotLLVMI = LLVMI;
193 Tracker.getContext().registerValue(std::move(ErasedIPtr));
196 #ifndef NDEBUG
197 void EraseFromParent::dump() const {
198 dump(dbgs());
199 dbgs() << "\n";
201 #endif // NDEBUG
203 RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) {
204 if (auto *NextI = RemovedI->getNextNode())
205 NextInstrOrBB = NextI;
206 else
207 NextInstrOrBB = RemovedI->getParent();
210 void RemoveFromParent::revert(Tracker &Tracker) {
211 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
212 RemovedI->insertBefore(NextI);
213 } else {
214 auto *BB = cast<BasicBlock *>(NextInstrOrBB);
215 RemovedI->insertInto(BB, BB->end());
219 #ifndef NDEBUG
220 void RemoveFromParent::dump() const {
221 dump(dbgs());
222 dbgs() << "\n";
224 #endif
226 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
227 : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
229 void CatchSwitchAddHandler::revert(Tracker &Tracker) {
230 // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
231 // once it gets implemented.
232 auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val);
233 LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
236 SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
237 for (const auto &C : Switch->cases())
238 Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()});
241 void SwitchRemoveCase::revert(Tracker &Tracker) {
242 // SwitchInst::removeCase doesn't provide any guarantees about the order of
243 // cases after removal. In order to preserve the original ordering, we save
244 // all of them and, when reverting, clear them all then insert them in the
245 // desired order. This still relies on the fact that `addCase` will insert
246 // them at the end, but it is documented to invalidate `case_end()` so it's
247 // probably okay.
248 unsigned NumCases = Switch->getNumCases();
249 for (unsigned I = 0; I < NumCases; ++I)
250 Switch->removeCase(Switch->case_begin());
251 for (auto &Case : Cases)
252 Switch->addCase(Case.Val, Case.Dest);
255 #ifndef NDEBUG
256 void SwitchRemoveCase::dump() const {
257 dump(dbgs());
258 dbgs() << "\n";
260 #endif // NDEBUG
262 void SwitchAddCase::revert(Tracker &Tracker) {
263 auto It = Switch->findCaseValue(Val);
264 Switch->removeCase(It);
267 #ifndef NDEBUG
268 void SwitchAddCase::dump() const {
269 dump(dbgs());
270 dbgs() << "\n";
272 #endif // NDEBUG
274 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
275 if (auto *NextI = MovedI->getNextNode())
276 NextInstrOrBB = NextI;
277 else
278 NextInstrOrBB = MovedI->getParent();
281 void MoveInstr::revert(Tracker &Tracker) {
282 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
283 MovedI->moveBefore(NextI);
284 } else {
285 auto *BB = cast<BasicBlock *>(NextInstrOrBB);
286 MovedI->moveBefore(*BB, BB->end());
290 #ifndef NDEBUG
291 void MoveInstr::dump() const {
292 dump(dbgs());
293 dbgs() << "\n";
295 #endif
297 void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); }
299 InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {}
301 #ifndef NDEBUG
302 void InsertIntoBB::dump() const {
303 dump(dbgs());
304 dbgs() << "\n";
306 #endif
308 void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); }
310 #ifndef NDEBUG
311 void CreateAndInsertInst::dump() const {
312 dump(dbgs());
313 dbgs() << "\n";
315 #endif
317 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI)
318 : SVI(SVI), PrevMask(SVI->getShuffleMask()) {}
320 void ShuffleVectorSetMask::revert(Tracker &Tracker) {
321 SVI->setShuffleMask(PrevMask);
324 #ifndef NDEBUG
325 void ShuffleVectorSetMask::dump() const {
326 dump(dbgs());
327 dbgs() << "\n";
329 #endif
331 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
333 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
334 #ifndef NDEBUG
335 void CmpSwapOperands::dump() const {
336 dump(dbgs());
337 dbgs() << "\n";
339 #endif
341 void Tracker::save() {
342 State = TrackerState::Record;
343 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
344 SnapshotChecker.save();
345 #endif
348 void Tracker::revert() {
349 assert(State == TrackerState::Record && "Forgot to save()!");
350 State = TrackerState::Disabled;
351 for (auto &Change : reverse(Changes))
352 Change->revert(*this);
353 Changes.clear();
354 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
355 SnapshotChecker.expectNoDiff();
356 #endif
359 void Tracker::accept() {
360 assert(State == TrackerState::Record && "Forgot to save()!");
361 State = TrackerState::Disabled;
362 for (auto &Change : Changes)
363 Change->accept();
364 Changes.clear();
367 #ifndef NDEBUG
368 void Tracker::dump(raw_ostream &OS) const {
369 for (auto [Idx, ChangePtr] : enumerate(Changes)) {
370 OS << Idx << ". ";
371 ChangePtr->dump(OS);
372 OS << "\n";
375 void Tracker::dump() const {
376 dump(dbgs());
377 dbgs() << "\n";
379 #endif // NDEBUG