//===- Tracker.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/SandboxIR/Tracker.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/StructuralHash.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include <string>
#include <utility>

using namespace llvm::sandboxir;
22 std::string
IRSnapshotChecker::dumpIR(const llvm::Function
&F
) const {
24 raw_string_ostream
SS(Result
);
25 F
.print(SS
, /*AssemblyAnnotationWriter=*/nullptr);
29 IRSnapshotChecker::ContextSnapshot
IRSnapshotChecker::takeSnapshot() const {
30 ContextSnapshot Result
;
31 for (const auto &Entry
: Ctx
.LLVMModuleToModuleMap
)
32 for (const auto &F
: *Entry
.first
) {
33 FunctionSnapshot Snapshot
;
34 Snapshot
.Hash
= StructuralHash(F
, /*DetailedHash=*/true);
35 Snapshot
.TextualIR
= dumpIR(F
);
36 Result
[&F
] = Snapshot
;
41 bool IRSnapshotChecker::diff(const ContextSnapshot
&Orig
,
42 const ContextSnapshot
&Curr
) const {
43 bool DifferenceFound
= false;
44 for (const auto &[F
, OrigFS
] : Orig
) {
45 auto CurrFSIt
= Curr
.find(F
);
46 if (CurrFSIt
== Curr
.end()) {
47 DifferenceFound
= true;
48 dbgs() << "Function " << F
->getName() << " not found in current IR.\n";
49 dbgs() << OrigFS
.TextualIR
<< "\n";
52 const FunctionSnapshot
&CurrFS
= CurrFSIt
->second
;
53 if (OrigFS
.Hash
!= CurrFS
.Hash
) {
54 DifferenceFound
= true;
55 dbgs() << "Found IR difference in Function " << F
->getName() << "\n";
56 dbgs() << "Original:\n" << OrigFS
.TextualIR
<< "\n";
57 dbgs() << "Current:\n" << CurrFS
.TextualIR
<< "\n";
60 // Check that Curr doesn't contain any new functions.
61 for (const auto &[F
, CurrFS
] : Curr
) {
62 if (!Orig
.contains(F
)) {
63 DifferenceFound
= true;
64 dbgs() << "Function " << F
->getName()
65 << " found in current IR but not in original snapshot.\n";
66 dbgs() << CurrFS
.TextualIR
<< "\n";
69 return DifferenceFound
;
72 void IRSnapshotChecker::save() { OrigContextSnapshot
= takeSnapshot(); }
74 void IRSnapshotChecker::expectNoDiff() {
75 ContextSnapshot CurrContextSnapshot
= takeSnapshot();
76 if (diff(OrigContextSnapshot
, CurrContextSnapshot
)) {
78 "Original and current IR differ! Probably a checkpointing bug.");
82 void UseSet::dump() const {
87 void UseSwap::dump() const {
93 PHIRemoveIncoming::PHIRemoveIncoming(PHINode
*PHI
, unsigned RemovedIdx
)
94 : PHI(PHI
), RemovedIdx(RemovedIdx
) {
95 RemovedV
= PHI
->getIncomingValue(RemovedIdx
);
96 RemovedBB
= PHI
->getIncomingBlock(RemovedIdx
);
99 void PHIRemoveIncoming::revert(Tracker
&Tracker
) {
100 // Special case: if the PHI is now empty, as we don't need to care about the
101 // order of the incoming values.
102 unsigned NumIncoming
= PHI
->getNumIncomingValues();
103 if (NumIncoming
== 0) {
104 PHI
->addIncoming(RemovedV
, RemovedBB
);
107 // Shift all incoming values by one starting from the end until `Idx`.
108 // Start by adding a copy of the last incoming values.
109 unsigned LastIdx
= NumIncoming
- 1;
110 PHI
->addIncoming(PHI
->getIncomingValue(LastIdx
),
111 PHI
->getIncomingBlock(LastIdx
));
112 for (unsigned Idx
= LastIdx
; Idx
> RemovedIdx
; --Idx
) {
113 auto *PrevV
= PHI
->getIncomingValue(Idx
- 1);
114 auto *PrevBB
= PHI
->getIncomingBlock(Idx
- 1);
115 PHI
->setIncomingValue(Idx
, PrevV
);
116 PHI
->setIncomingBlock(Idx
, PrevBB
);
118 PHI
->setIncomingValue(RemovedIdx
, RemovedV
);
119 PHI
->setIncomingBlock(RemovedIdx
, RemovedBB
);
123 void PHIRemoveIncoming::dump() const {
129 PHIAddIncoming::PHIAddIncoming(PHINode
*PHI
)
130 : PHI(PHI
), Idx(PHI
->getNumIncomingValues()) {}
132 void PHIAddIncoming::revert(Tracker
&Tracker
) { PHI
->removeIncomingValue(Idx
); }
135 void PHIAddIncoming::dump() const {
141 Tracker::~Tracker() {
142 assert(Changes
.empty() && "You must accept or revert changes!");
145 EraseFromParent::EraseFromParent(std::unique_ptr
<sandboxir::Value
> &&ErasedIPtr
)
146 : ErasedIPtr(std::move(ErasedIPtr
)) {
147 auto *I
= cast
<Instruction
>(this->ErasedIPtr
.get());
148 auto LLVMInstrs
= I
->getLLVMInstrs();
149 // Iterate in reverse program order.
150 for (auto *LLVMI
: reverse(LLVMInstrs
)) {
151 SmallVector
<llvm::Value
*> Operands
;
152 Operands
.reserve(LLVMI
->getNumOperands());
153 for (auto [OpNum
, Use
] : enumerate(LLVMI
->operands()))
154 Operands
.push_back(Use
.get());
155 InstrData
.push_back({Operands
, LLVMI
});
157 assert(is_sorted(InstrData
,
158 [](const auto &D0
, const auto &D1
) {
159 return D0
.LLVMI
->comesBefore(D1
.LLVMI
);
161 "Expected reverse program order!");
162 auto *BotLLVMI
= cast
<llvm::Instruction
>(I
->Val
);
163 if (BotLLVMI
->getNextNode() != nullptr)
164 NextLLVMIOrBB
= BotLLVMI
->getNextNode();
166 NextLLVMIOrBB
= BotLLVMI
->getParent();
169 void EraseFromParent::accept() {
170 for (const auto &IData
: InstrData
)
171 IData
.LLVMI
->deleteValue();
174 void EraseFromParent::revert(Tracker
&Tracker
) {
175 // Place the bottom-most instruction first.
176 auto [Operands
, BotLLVMI
] = InstrData
[0];
177 if (auto *NextLLVMI
= dyn_cast
<llvm::Instruction
*>(NextLLVMIOrBB
)) {
178 BotLLVMI
->insertBefore(NextLLVMI
);
180 auto *LLVMBB
= cast
<llvm::BasicBlock
*>(NextLLVMIOrBB
);
181 BotLLVMI
->insertInto(LLVMBB
, LLVMBB
->end());
183 for (auto [OpNum
, Op
] : enumerate(Operands
))
184 BotLLVMI
->setOperand(OpNum
, Op
);
186 // Go over the rest of the instructions and stack them on top.
187 for (auto [Operands
, LLVMI
] : drop_begin(InstrData
)) {
188 LLVMI
->insertBefore(BotLLVMI
);
189 for (auto [OpNum
, Op
] : enumerate(Operands
))
190 LLVMI
->setOperand(OpNum
, Op
);
193 Tracker
.getContext().registerValue(std::move(ErasedIPtr
));
197 void EraseFromParent::dump() const {
203 RemoveFromParent::RemoveFromParent(Instruction
*RemovedI
) : RemovedI(RemovedI
) {
204 if (auto *NextI
= RemovedI
->getNextNode())
205 NextInstrOrBB
= NextI
;
207 NextInstrOrBB
= RemovedI
->getParent();
210 void RemoveFromParent::revert(Tracker
&Tracker
) {
211 if (auto *NextI
= dyn_cast
<Instruction
*>(NextInstrOrBB
)) {
212 RemovedI
->insertBefore(NextI
);
214 auto *BB
= cast
<BasicBlock
*>(NextInstrOrBB
);
215 RemovedI
->insertInto(BB
, BB
->end());
220 void RemoveFromParent::dump() const {
226 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst
*CSI
)
227 : CSI(CSI
), HandlerIdx(CSI
->getNumHandlers()) {}
229 void CatchSwitchAddHandler::revert(Tracker
&Tracker
) {
230 // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
231 // once it gets implemented.
232 auto *LLVMCSI
= cast
<llvm::CatchSwitchInst
>(CSI
->Val
);
233 LLVMCSI
->removeHandler(LLVMCSI
->handler_begin() + HandlerIdx
);
236 SwitchRemoveCase::SwitchRemoveCase(SwitchInst
*Switch
) : Switch(Switch
) {
237 for (const auto &C
: Switch
->cases())
238 Cases
.push_back({C
.getCaseValue(), C
.getCaseSuccessor()});
241 void SwitchRemoveCase::revert(Tracker
&Tracker
) {
242 // SwitchInst::removeCase doesn't provide any guarantees about the order of
243 // cases after removal. In order to preserve the original ordering, we save
244 // all of them and, when reverting, clear them all then insert them in the
245 // desired order. This still relies on the fact that `addCase` will insert
246 // them at the end, but it is documented to invalidate `case_end()` so it's
248 unsigned NumCases
= Switch
->getNumCases();
249 for (unsigned I
= 0; I
< NumCases
; ++I
)
250 Switch
->removeCase(Switch
->case_begin());
251 for (auto &Case
: Cases
)
252 Switch
->addCase(Case
.Val
, Case
.Dest
);
256 void SwitchRemoveCase::dump() const {
262 void SwitchAddCase::revert(Tracker
&Tracker
) {
263 auto It
= Switch
->findCaseValue(Val
);
264 Switch
->removeCase(It
);
268 void SwitchAddCase::dump() const {
274 MoveInstr::MoveInstr(Instruction
*MovedI
) : MovedI(MovedI
) {
275 if (auto *NextI
= MovedI
->getNextNode())
276 NextInstrOrBB
= NextI
;
278 NextInstrOrBB
= MovedI
->getParent();
281 void MoveInstr::revert(Tracker
&Tracker
) {
282 if (auto *NextI
= dyn_cast
<Instruction
*>(NextInstrOrBB
)) {
283 MovedI
->moveBefore(NextI
);
285 auto *BB
= cast
<BasicBlock
*>(NextInstrOrBB
);
286 MovedI
->moveBefore(*BB
, BB
->end());
291 void MoveInstr::dump() const {
297 void InsertIntoBB::revert(Tracker
&Tracker
) { InsertedI
->removeFromParent(); }
299 InsertIntoBB::InsertIntoBB(Instruction
*InsertedI
) : InsertedI(InsertedI
) {}
302 void InsertIntoBB::dump() const {
308 void CreateAndInsertInst::revert(Tracker
&Tracker
) { NewI
->eraseFromParent(); }
311 void CreateAndInsertInst::dump() const {
317 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst
*SVI
)
318 : SVI(SVI
), PrevMask(SVI
->getShuffleMask()) {}
320 void ShuffleVectorSetMask::revert(Tracker
&Tracker
) {
321 SVI
->setShuffleMask(PrevMask
);
325 void ShuffleVectorSetMask::dump() const {
331 CmpSwapOperands::CmpSwapOperands(CmpInst
*Cmp
) : Cmp(Cmp
) {}
333 void CmpSwapOperands::revert(Tracker
&Tracker
) { Cmp
->swapOperands(); }
335 void CmpSwapOperands::dump() const {
341 void Tracker::save() {
342 State
= TrackerState::Record
;
343 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
344 SnapshotChecker
.save();
348 void Tracker::revert() {
349 assert(State
== TrackerState::Record
&& "Forgot to save()!");
350 State
= TrackerState::Disabled
;
351 for (auto &Change
: reverse(Changes
))
352 Change
->revert(*this);
354 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
355 SnapshotChecker
.expectNoDiff();
359 void Tracker::accept() {
360 assert(State
== TrackerState::Record
&& "Forgot to save()!");
361 State
= TrackerState::Disabled
;
362 for (auto &Change
: Changes
)
368 void Tracker::dump(raw_ostream
&OS
) const {
369 for (auto [Idx
, ChangePtr
] : enumerate(Changes
)) {
375 void Tracker::dump() const {