1 //==-- X86LoadValueInjectionLoadHardening.cpp - LVI load hardening for x86 --=//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 /// Description: This pass finds Load Value Injection (LVI) gadgets consisting
10 /// of a load from memory (i.e., SOURCE), and any operation that may transmit
11 /// the value loaded from memory over a covert channel, or use the value loaded
12 /// from memory to determine a branch/call target (i.e., SINK). After finding
13 /// all such gadgets in a given function, the pass minimally inserts LFENCE
14 /// instructions in such a manner that the following property is satisfied: for
15 /// all SOURCE+SINK pairs, all paths in the CFG from SOURCE to SINK contain at
16 /// least one LFENCE instruction. The algorithm that implements this minimal
17 /// insertion is influenced by an academic paper that minimally inserts memory
18 /// fences for high-performance concurrent programs:
19 /// http://www.cs.ucr.edu/~lesani/companion/oopsla15/OOPSLA15.pdf
20 /// The algorithm implemented in this pass is as follows:
21 /// 1. Build a condensed CFG (i.e., a GadgetGraph) consisting only of the
22 /// following components:
23 /// - SOURCE instructions (also includes function arguments)
24 /// - SINK instructions
25 /// - Basic block entry points
26 /// - Basic block terminators
27 /// - LFENCE instructions
28 /// 2. Analyze the GadgetGraph to determine which SOURCE+SINK pairs (i.e.,
29 /// gadgets) are already mitigated by existing LFENCEs. If all gadgets have been
30 /// mitigated, go to step 6.
31 /// 3. Use a heuristic or plugin to approximate minimal LFENCE insertion.
32 /// 4. Insert one LFENCE along each CFG edge that was cut in step 3.
33 /// 5. Go to step 2.
34 /// 6. If any LFENCEs were inserted, return `true` from runOnMachineFunction()
35 /// to tell LLVM that the function was modified.
37 //===----------------------------------------------------------------------===//
39 #include "ImmutableGraph.h"
41 #include "X86Subtarget.h"
42 #include "X86TargetMachine.h"
43 #include "llvm/ADT/DenseMap.h"
44 #include "llvm/ADT/STLExtras.h"
45 #include "llvm/ADT/SmallSet.h"
46 #include "llvm/ADT/Statistic.h"
47 #include "llvm/ADT/StringRef.h"
48 #include "llvm/CodeGen/MachineBasicBlock.h"
49 #include "llvm/CodeGen/MachineDominanceFrontier.h"
50 #include "llvm/CodeGen/MachineDominators.h"
51 #include "llvm/CodeGen/MachineFunction.h"
52 #include "llvm/CodeGen/MachineFunctionPass.h"
53 #include "llvm/CodeGen/MachineInstr.h"
54 #include "llvm/CodeGen/MachineInstrBuilder.h"
55 #include "llvm/CodeGen/MachineLoopInfo.h"
56 #include "llvm/CodeGen/RDFGraph.h"
57 #include "llvm/CodeGen/RDFLiveness.h"
58 #include "llvm/InitializePasses.h"
59 #include "llvm/Support/CommandLine.h"
60 #include "llvm/Support/DOTGraphTraits.h"
61 #include "llvm/Support/Debug.h"
62 #include "llvm/Support/DynamicLibrary.h"
63 #include "llvm/Support/GraphWriter.h"
64 #include "llvm/Support/raw_ostream.h"
// NOTE(review): in this copy each line begins with an integer that looks like
// an original line number, and tokens are split across lines; as written the
// text does not compile. Confirm against the upstream source.
//
// PASS_KEY is the debug type and also the prefix of every command-line option
// defined in this file (e.g. PASS_KEY "-opt-plugin").
68 #define PASS_KEY "x86-lvi-load"
69 #define DEBUG_TYPE PASS_KEY
// Statistics counters. NumFences, NumFunctionsConsidered and
// NumFunctionsMitigated are updated in runOnMachineFunction(); NumGadgets is
// updated in getGadgetGraph().
71 STATISTIC(NumFences
, "Number of LFENCEs inserted for LVI mitigation");
72 STATISTIC(NumFunctionsConsidered
, "Number of functions analyzed");
73 STATISTIC(NumFunctionsMitigated
, "Number of functions for which mitigations "
// NOTE(review): the rest of this description and the closing ");" (line 74)
// are not present here, so this STATISTIC invocation is unterminated.
75 STATISTIC(NumGadgets
, "Number of LVI gadgets detected during analysis");
// Path to an optional shared-library plugin. When non-empty,
// runOnMachineFunction() loads it and uses its `optimize_cut` symbol instead
// of the built-in heuristic to choose which CFG edges to cut.
77 static cl::opt
<std::string
> OptimizePluginPath(
78 PASS_KEY
"-opt-plugin",
79 cl::desc("Specify a plugin to optimize LFENCE insertion"), cl::Hidden
);
// When set, getGadgetGraph() does not treat conditional-branch uses of a
// loaded value as transmitters; only memory-address uses are checked.
81 static cl::opt
<bool> NoConditionalBranches(
82 PASS_KEY
"-no-cbranch",
83 cl::desc("Don't treat conditional branches as disclosure gadgets. This "
84 "may improve performance, at the cost of security."),
85 cl::init(false), cl::Hidden
);
// Write each function's gadget graph in DOT format to a file whose name
// starts with "lvi." followed by the function name (see runOnMachineFunction()).
// NOTE(review): the option name and the opening `cl::desc(` (lines 88-89) are
// missing here, so the description string below is dangling as shown.
87 static cl::opt
<bool> EmitDot(
90 "For each function, emit a dot graph depicting potential LVI gadgets"),
91 cl::init(false), cl::Hidden
);
// Like EmitDot but, per its description, without inserting any fences.
// NOTE(review): the option name (line 94) is missing here.
93 static cl::opt
<bool> EmitDotOnly(
95 cl::desc("For each function, emit a dot graph depicting potential LVI "
96 "gadgets, and do not insert any fences"),
97 cl::init(false), cl::Hidden
);
// Test-only option: dump the gadget graph to stdout (see the
// writeGadgetGraph(outs(), ...) call in runOnMachineFunction()).
99 static cl::opt
<bool> EmitDotVerify(
100 PASS_KEY
"-dot-verify",
101 cl::desc("For each function, emit a dot graph to stdout depicting "
102 "potential LVI gadgets, used for testing purposes only"),
103 cl::init(false), cl::Hidden
);
105 static llvm::sys::DynamicLibrary OptimizeDL
;
106 typedef int (*OptimizeCutT
)(unsigned int *Nodes
, unsigned int NodesSize
,
107 unsigned int *Edges
, int *EdgeValues
,
108 int *CutEdges
/* out */, unsigned int EdgesSize
);
109 static OptimizeCutT OptimizeCut
= nullptr;
// Condensed CFG over MachineInstr* nodes, used to find and mitigate LVI
// gadgets. Every edge carries an int value. CFG edges carry a weight
// (getGadgetGraph() passes loop depths). Gadget edges (SOURCE -> SINK) carry
// GadgetEdgeSentinel.
// NOTE(review): the constructor initializes NumFences/NumGadgets, and other
// code reads them (e.g. G.NumFences, Graph->NumGadgets). Their member
// declarations are not present in this text, nor are the closing braces of
// isCFGEdge/isGadgetEdge and of the struct (numbering skips 128 and 131-134).
113 struct MachineGadgetGraph
: ImmutableGraph
<MachineInstr
*, int> {
// Edge value marking a gadget (SOURCE -> SINK) edge rather than a CFG edge.
114 static constexpr int GadgetEdgeSentinel
= -1;
// Node value standing in for the function's incoming arguments.
115 static constexpr MachineInstr
*const ArgNodeSentinel
= nullptr;
117 using GraphT
= ImmutableGraph
<MachineInstr
*, int>;
118 using Node
= typename
GraphT::Node
;
119 using Edge
= typename
GraphT::Edge
;
120 using size_type
= typename
GraphT::size_type
;
// Takes ownership of the node/edge arrays and records the fence and gadget
// counts alongside the graph.
121 MachineGadgetGraph(std::unique_ptr
<Node
[]> Nodes
,
122 std::unique_ptr
<Edge
[]> Edges
, size_type NodesSize
,
123 size_type EdgesSize
, int NumFences
= 0, int NumGadgets
= 0)
124 : GraphT(std::move(Nodes
), std::move(Edges
), NodesSize
, EdgesSize
),
125 NumFences(NumFences
), NumGadgets(NumGadgets
) {}
// True for any edge that is not a gadget edge.
126 static inline bool isCFGEdge(const Edge
&E
) {
127 return E
.getValue() != GadgetEdgeSentinel
;
// True for SOURCE -> SINK gadget edges.
129 static inline bool isGadgetEdge(const Edge
&E
) {
130 return E
.getValue() == GadgetEdgeSentinel
;
// Machine-function pass that builds a gadget graph for each function and
// inserts LFENCEs so that gadgets (SOURCE -> SINK pairs) are mitigated.
// NOTE(review): numbering skips lines 137, 139, 142, 145-148 and 179-181.
// Several things appear to be missing from this text:
//   - presumably the access specifiers;
//   - the declaration of `ID` (used by the constructor, defined out of line);
//   - several closing braces.
136 class X86LoadValueInjectionLoadHardeningPass
: public MachineFunctionPass
{
138 X86LoadValueInjectionLoadHardeningPass() : MachineFunctionPass(ID
) {}
140 StringRef
getPassName() const override
{
141 return "X86 Load Value Injection (LVI) Load Hardening";
143 void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
144 bool runOnMachineFunction(MachineFunction
&MF
) override
;
149 using GraphBuilder
= ImmutableGraphBuilder
<MachineGadgetGraph
>;
150 using Edge
= MachineGadgetGraph::Edge
;
151 using Node
= MachineGadgetGraph::Node
;
152 using EdgeSet
= MachineGadgetGraph::EdgeSet
;
153 using NodeSet
= MachineGadgetGraph::NodeSet
;
// Cached per-function target info, assigned in runOnMachineFunction().
155 const X86Subtarget
*STI
= nullptr;
156 const TargetInstrInfo
*TII
= nullptr;
157 const TargetRegisterInfo
*TRI
= nullptr;
// Builds the gadget graph for MF. runOnMachineFunction() treats a nullptr
// result as "no gadgets found".
159 std::unique_ptr
<MachineGadgetGraph
>
160 getGadgetGraph(MachineFunction
&MF
, const MachineLoopInfo
&MLI
,
161 const MachineDominatorTree
&MDT
,
162 const MachineDominanceFrontier
&MDF
) const;
// Each returns the number of LFENCEs inserted into MF.
163 int hardenLoadsWithPlugin(MachineFunction
&MF
,
164 std::unique_ptr
<MachineGadgetGraph
> Graph
) const;
165 int hardenLoadsWithHeuristic(MachineFunction
&MF
,
166 std::unique_ptr
<MachineGadgetGraph
> Graph
) const;
// Accumulates already-mitigated edges/nodes into ElimEdges/ElimNodes and
// returns the number of gadget edges that remain unmitigated.
167 int elimMitigatedEdgesAndNodes(MachineGadgetGraph
&G
,
168 EdgeSet
&ElimEdges
/* in, out */,
169 NodeSet
&ElimNodes
/* in, out */) const;
// Returns Graph with the mitigated edges and nodes removed.
170 std::unique_ptr
<MachineGadgetGraph
>
171 trimMitigatedEdges(std::unique_ptr
<MachineGadgetGraph
> Graph
) const;
// Inserts LFENCEs for the edges in CutEdges; returns the number inserted.
172 int insertFences(MachineFunction
&MF
, MachineGadgetGraph
&G
,
173 EdgeSet
&CutEdges
/* in, out */) const;
// True if Reg overlaps a base or index register of I's memory operand.
174 bool instrUsesRegToAccessMemory(const MachineInstr
&I
, unsigned Reg
) const;
// True if I is a conditional branch that uses Reg.
175 bool instrUsesRegToBranch(const MachineInstr
&I
, unsigned Reg
) const;
// An LFENCE always counts as a fence. When the subtarget enables LVI
// control-flow integrity, any call also counts as one.
176 inline bool isFence(const MachineInstr
*MI
) const {
177 return MI
&& (MI
->getOpcode() == X86::LFENCE
||
178 (STI
->useLVIControlFlowIntegrity() && MI
->isCall()));
182 } // end anonymous namespace
// NOTE(review): the enclosing `namespace llvm {` and the `template <>` headers
// of these two specializations are not present in this text (numbering skips
// 183-186 and 189-190).
// Graph traits for MachineGadgetGraph* reuse those of its ImmutableGraph base.
187 struct GraphTraits
<MachineGadgetGraph
*>
188 : GraphTraits
<ImmutableGraph
<MachineInstr
*, int> *> {};
// DOT rendering used by writeGadgetGraph():
//   - nodes are labeled with the printed MachineInstr;
//   - argument and LFENCE nodes are colored;
//   - edges with non-negative values are labeled with that value;
//   - gadget edges (value -1) are drawn red and dashed.
191 struct DOTGraphTraits
<MachineGadgetGraph
*> : DefaultDOTGraphTraits
{
192 using GraphType
= MachineGadgetGraph
;
193 using Traits
= llvm::GraphTraits
<GraphType
*>;
194 using NodeRef
= typename
Traits::NodeRef
;
195 using EdgeRef
= typename
Traits::EdgeRef
;
196 using ChildIteratorType
= typename
Traits::ChildIteratorType
;
197 using ChildEdgeIteratorType
= typename
Traits::ChildEdgeIteratorType
;
199 DOTGraphTraits(bool IsSimple
= false) : DefaultDOTGraphTraits(IsSimple
) {}
201 std::string
getNodeLabel(NodeRef Node
, GraphType
*) {
202 if (Node
->getValue() == MachineGadgetGraph::ArgNodeSentinel
)
// NOTE(review): the statement guarded by this `if` and the declaration of
// `Str` (lines 203-205) are missing from this text.
206 raw_string_ostream
OS(Str
);
207 OS
<< *Node
->getValue();
// NOTE(review): lines 208-210 (presumably returning the label) are missing.
211 static std::string
getNodeAttributes(NodeRef Node
, GraphType
*) {
212 MachineInstr
*MI
= Node
->getValue();
// Function-argument pseudo node.
213 if (MI
== MachineGadgetGraph::ArgNodeSentinel
)
214 return "color = blue";
215 if (MI
->getOpcode() == X86::LFENCE
)
216 return "color = green";
// NOTE(review): the fallback return for other instructions and the closing
// brace (lines 217-219) are missing.
220 static std::string
getEdgeAttributes(NodeRef
, ChildIteratorType E
,
// NOTE(review): the remaining parameter (line 221) is missing here.
222 int EdgeVal
= (*E
.getCurrent()).getValue();
// Non-negative values are CFG weights; GadgetEdgeSentinel (-1) marks a
// gadget edge.
223 return EdgeVal
>= 0 ? "label = " + std::to_string(EdgeVal
)
224 : "color = red, style = \"dashed\"";
228 } // end namespace llvm
230 constexpr MachineInstr
*MachineGadgetGraph::ArgNodeSentinel
;
231 constexpr int MachineGadgetGraph::GadgetEdgeSentinel
;
233 char X86LoadValueInjectionLoadHardeningPass::ID
= 0;
235 void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
236 AnalysisUsage
&AU
) const {
237 MachineFunctionPass::getAnalysisUsage(AU
);
238 AU
.addRequired
<MachineLoopInfoWrapperPass
>();
239 AU
.addRequired
<MachineDominatorTreeWrapperPass
>();
240 AU
.addRequired
<MachineDominanceFrontier
>();
241 AU
.setPreservesCFG();
244 static void writeGadgetGraph(raw_ostream
&OS
, MachineFunction
&MF
,
245 MachineGadgetGraph
*G
) {
246 WriteGraph(OS
, G
, /*ShortNames*/ false,
247 "Speculative gadgets for \"" + MF
.getName() + "\" function");
// Pass entry point, in order:
//   1. Build the gadget graph for MF.
//   2. Optionally dump the graph in DOT format.
//   3. Insert LFENCEs, using the optimization plugin if PASS_KEY "-opt-plugin"
//      names one, otherwise the built-in greedy heuristic.
// Updates the pass statistics. Returns true iff at least one fence was
// inserted.
250 bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
251 MachineFunction
&MF
) {
252 LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF
.getName()
// NOTE(review): the rest of this LLVM_DEBUG statement (line 253) is missing.
254 STI
= &MF
.getSubtarget
<X86Subtarget
>();
255 if (!STI
->useLVILoadHardening())
// NOTE(review): the statement this `if` guards (line 256) is missing;
// presumably an early return when LVI load hardening is disabled — confirm.
258 // FIXME: support 32-bit
// NOTE(review): numbering skips 259, so the condition for the fatal error below
// is missing; as shown it would be reached unconditionally. Given the FIXME
// and the message, it is presumably a 64-bit-only check — confirm.
260 report_fatal_error("LVI load hardening is only supported on 64-bit", false);
262 // Don't skip functions with the "optnone" attr but participate in opt-bisect.
263 const Function
&F
= MF
.getFunction();
264 if (!F
.hasOptNone() && skipFunction(F
))
// NOTE(review): the body of this `if` (line 265) is missing.
267 ++NumFunctionsConsidered
;
268 TII
= STI
->getInstrInfo();
269 TRI
= STI
->getRegisterInfo();
270 LLVM_DEBUG(dbgs() << "Building gadget graph...\n");
271 const auto &MLI
= getAnalysis
<MachineLoopInfoWrapperPass
>().getLI();
272 const auto &MDT
= getAnalysis
<MachineDominatorTreeWrapperPass
>().getDomTree();
273 const auto &MDF
= getAnalysis
<MachineDominanceFrontier
>();
274 std::unique_ptr
<MachineGadgetGraph
> Graph
= getGadgetGraph(MF
, MLI
, MDT
, MDF
);
275 LLVM_DEBUG(dbgs() << "Building gadget graph... Done\n");
276 if (Graph
== nullptr)
277 return false; // didn't find any gadgets
// NOTE(review): lines 278-279 and 281-283 are missing. Per EmitDotVerify's
// description (stdout, testing only), this write to outs() is presumably
// guarded by that option — confirm.
280 writeGadgetGraph(outs(), MF
, Graph
.get());
// File dump for EmitDot / EmitDotOnly: "lvi." followed by the function name.
284 if (EmitDot
|| EmitDotOnly
) {
285 LLVM_DEBUG(dbgs() << "Emitting gadget graph...\n");
286 std::error_code FileError
;
287 std::string FileName
= "lvi.";
288 FileName
+= MF
.getName();
// NOTE(review): line 289 is missing (between building FileName and opening
// the file).
290 raw_fd_ostream
FileOut(FileName
, FileError
);
// NOTE(review): line 291 is missing; this error print is presumably
// conditional on FileError — confirm.
292 errs() << FileError
.message();
293 writeGadgetGraph(FileOut
, MF
, Graph
.get());
295 LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n");
// NOTE(review): lines 294 and 296-300 are missing. This includes the
// declaration of FencesInserted, which is used below.
// Plugin path: load the library at most once (while OptimizeDL is invalid)
// and resolve its `optimize_cut` entry point.
301 if (!OptimizePluginPath
.empty()) {
302 if (!OptimizeDL
.isValid()) {
303 std::string ErrorMsg
;
304 OptimizeDL
= llvm::sys::DynamicLibrary::getPermanentLibrary(
305 OptimizePluginPath
.c_str(), &ErrorMsg
);
306 if (!ErrorMsg
.empty())
307 report_fatal_error(Twine("Failed to load opt plugin: \"") + ErrorMsg
+
// NOTE(review): the rest of this message (line 308) is missing.
309 OptimizeCut
= (OptimizeCutT
)OptimizeDL
.getAddressOfSymbol("optimize_cut");
// NOTE(review): the condition for this error (line 310) is missing; it is
// presumably a check that `optimize_cut` was found — confirm.
311 report_fatal_error("Invalid optimization plugin");
313 FencesInserted
= hardenLoadsWithPlugin(MF
, std::move(Graph
));
314 } else { // Use the default greedy heuristic
315 FencesInserted
= hardenLoadsWithHeuristic(MF
, std::move(Graph
));
318 if (FencesInserted
> 0)
319 ++NumFunctionsMitigated
;
320 NumFences
+= FencesInserted
;
// Report a modification only if at least one LFENCE was added.
321 return (FencesInserted
> 0);
// Builds the gadget graph for MF:
//   1. Construct the RDF data-flow graph and liveness information for MF.
//   2. For each def produced by a load instruction, and for each
//      function-argument phi def, walk the def-use chains to find
//      "transmitters". A transmitter is a use that forms a memory address or,
//      unless NoConditionalBranches is set, feeds a conditional branch.
//      Each SOURCE -> transmitter pair becomes a gadget edge.
//   3. Walk the CFG from the ArgNodeSentinel pseudo node and add CFG edges,
//      weighted by loop depth. The edges link block entries, the already-added
//      gadget/fence nodes inside a block, and block terminators.
// runOnMachineFunction() treats a nullptr result as "no gadgets".
324 std::unique_ptr
<MachineGadgetGraph
>
325 X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
326 MachineFunction
&MF
, const MachineLoopInfo
&MLI
,
327 const MachineDominatorTree
&MDT
,
328 const MachineDominanceFrontier
&MDF
) const {
331 // Build the Register Dataflow Graph using the RDF framework
332 DataFlowGraph DFG
{MF
, *TII
, *TRI
, MDT
, MDF
};
334 Liveness L
{MF
.getRegInfo(), DFG
};
337 GraphBuilder Builder
;
338 using GraphIter
= typename
GraphBuilder::BuilderNodeRef
;
// Maps each instruction already added to the graph to its builder node.
339 DenseMap
<MachineInstr
*, GraphIter
> NodeMap
;
340 int FenceCount
= 0, GadgetCount
= 0;
// Returns the graph node for MI, creating it on first use. The bool is true
// iff a new node was created.
341 auto MaybeAddNode
= [&NodeMap
, &Builder
](MachineInstr
*MI
) {
342 auto Ref
= NodeMap
.find(MI
);
343 if (Ref
== NodeMap
.end()) {
344 auto I
= Builder
.addVertex(MI
);
// NOTE(review): line 345 is missing. Nothing visible here records the new
// vertex in NodeMap, so as shown every call would add a fresh vertex — confirm.
346 return std::pair
<GraphIter
, bool>{I
, true};
348 return std::pair
<GraphIter
, bool>{Ref
->getSecond(), false};
351 // The `Transmitters` map memoizes transmitters found for each def. If a def
352 // has not yet been analyzed, then it will not appear in the map. If a def
353 // has been analyzed and was determined not to have any transmitters, then
354 // its list of transmitters will be empty.
355 DenseMap
<NodeId
, std::vector
<NodeId
>> Transmitters
;
357 // Analyze all machine instructions to find gadgets and LFENCEs, adding
358 // each interesting value to `Nodes`
// AnalyzeDef does two things:
//   - computes, and memoizes in Transmitters, every transmitter reachable
//     from SourceDef through def-use chains;
//   - adds one gadget edge from SourceDef's instruction to each transmitter.
//     ArgNodeSentinel stands in for the instruction when SourceDef is a phi
//     (argument) def.
359 auto AnalyzeDef
= [&](NodeAddr
<DefNode
*> SourceDef
) {
// Visited sets are local to one SourceDef, i.e. reset on every AnalyzeDef call.
360 SmallSet
<NodeId
, 8> UsesVisited
, DefsVisited
;
361 std::function
<void(NodeAddr
<DefNode
*>)> AnalyzeDefUseChain
=
362 [&](NodeAddr
<DefNode
*> Def
) {
363 if (Transmitters
.contains(Def
.Id
))
364 return; // Already analyzed `Def`
366 // Use RDF to find all the uses of `Def`
// NOTE(review): line 367 is missing; `Uses` (filled and iterated below) is
// not declared anywhere in the visible text.
368 RegisterRef DefReg
= Def
.Addr
->getRegRef(DFG
);
369 for (auto UseID
: L
.getAllReachedUses(DefReg
, Def
)) {
370 auto Use
= DFG
.addr
<UseNode
*>(UseID
);
// A use by a phi is replaced by the phi's real uses of registers that alias
// DefReg.
371 if (Use
.Addr
->getFlags() & NodeAttrs::PhiRef
) { // phi node
372 NodeAddr
<PhiNode
*> Phi
= Use
.Addr
->getOwner(DFG
);
373 for (const auto& I
: L
.getRealUses(Phi
.Id
)) {
374 if (DFG
.getPRI().alias(RegisterRef(I
.first
), DefReg
)) {
375 for (const auto &UA
: I
.second
)
376 Uses
.emplace(UA
.first
);
379 } else { // not a phi node
// NOTE(review): the body of this else branch and the closing braces
// (lines 377-378 and 380-383) are missing.
384 // For each use of `Def`, we want to know whether:
385 // (1) The use can leak the Def'ed value,
386 // (2) The use can further propagate the Def'ed value to more defs
387 for (auto UseID
: Uses
) {
388 if (!UsesVisited
.insert(UseID
).second
)
389 continue; // Already visited this use of `Def`
391 auto Use
= DFG
.addr
<UseNode
*>(UseID
);
392 assert(!(Use
.Addr
->getFlags() & NodeAttrs::PhiRef
));
393 MachineOperand
&UseMO
= Use
.Addr
->getOp();
394 MachineInstr
&UseMI
= *UseMO
.getParent();
395 assert(UseMO
.isReg());
397 // We naively assume that an instruction propagates any loaded
398 // uses to all defs unless the instruction is a call, in which
399 // case all arguments will be treated as gadget sources during
400 // analysis of the callee function.
// NOTE(review): lines 401-403 are missing; any call-specific handling that
// the comment above implies is not visible here.
404 // Check whether this use can transmit (leak) its value.
405 if (instrUsesRegToAccessMemory(UseMI
, UseMO
.getReg()) ||
406 (!NoConditionalBranches
&&
407 instrUsesRegToBranch(UseMI
, UseMO
.getReg()))) {
408 Transmitters
[Def
.Id
].push_back(Use
.Addr
->getOwner(DFG
).Id
);
410 continue; // Found a transmitting load -- no need to continue
411 // traversing its defs (i.e., this load will become
412 // a new gadget source anyways).
415 // Check whether the use propagates to more defs.
416 NodeAddr
<InstrNode
*> Owner
{Use
.Addr
->getOwner(DFG
)};
// NOTE(review): AnalyzedChildDefs is not used anywhere in the visible code.
417 rdf::NodeList AnalyzedChildDefs
;
418 for (const auto &ChildDef
:
419 Owner
.Addr
->members_if(DataFlowGraph::IsDef
, DFG
)) {
420 if (!DefsVisited
.insert(ChildDef
.Id
).second
)
421 continue; // Already visited this def
422 if (Def
.Addr
->getAttrs() & NodeAttrs::Dead
)
// NOTE(review): the statement guarded by this Dead check (line 423) is
// missing.
424 if (Def
.Id
== ChildDef
.Id
)
425 continue; // `Def` uses itself (e.g., increment loop counter)
427 AnalyzeDefUseChain(ChildDef
);
429 // `Def` inherits all of its child defs' transmitters.
430 for (auto TransmitterId
: Transmitters
[ChildDef
.Id
])
431 Transmitters
[Def
.Id
].push_back(TransmitterId
);
435 // Note that this statement adds `Def.Id` to the map if no
436 // transmitters were found for `Def`.
437 auto &DefTransmitters
= Transmitters
[Def
.Id
];
439 // Remove duplicate transmitters
440 llvm::sort(DefTransmitters
);
441 DefTransmitters
.erase(llvm::unique(DefTransmitters
),
442 DefTransmitters
.end());
445 // Find all of the transmitters
446 AnalyzeDefUseChain(SourceDef
);
447 auto &SourceDefTransmitters
= Transmitters
[SourceDef
.Id
];
448 if (SourceDefTransmitters
.empty())
449 return; // No transmitters for `SourceDef`
// Phi defs (function arguments) map to the ArgNodeSentinel pseudo node.
451 MachineInstr
*Source
= SourceDef
.Addr
->getFlags() & NodeAttrs::PhiRef
452 ? MachineGadgetGraph::ArgNodeSentinel
453 : SourceDef
.Addr
->getOp().getParent();
454 auto GadgetSource
= MaybeAddNode(Source
);
455 // Each transmitter is a sink for `SourceDef`.
456 for (auto TransmitterId
: SourceDefTransmitters
) {
457 MachineInstr
*Sink
= DFG
.addr
<StmtNode
*>(TransmitterId
).Addr
->getCode();
458 auto GadgetSink
= MaybeAddNode(Sink
);
459 // Add the gadget edge to the graph.
460 Builder
.addEdge(MachineGadgetGraph::GadgetEdgeSentinel
,
461 GadgetSource
.first
, GadgetSink
.first
);
// NOTE(review): lines 462-465 are missing. GadgetCount is not incremented
// anywhere in the visible code, yet it is checked and reported below.
466 LLVM_DEBUG(dbgs() << "Analyzing def-use chains to find gadgets\n");
467 // Analyze function arguments
468 NodeAddr
<BlockNode
*> EntryBlock
= DFG
.getFunc().Addr
->getEntryBlock(DFG
);
469 for (NodeAddr
<PhiNode
*> ArgPhi
:
470 EntryBlock
.Addr
->members_if(DataFlowGraph::IsPhi
, DFG
)) {
471 NodeList Defs
= ArgPhi
.Addr
->members_if(DataFlowGraph::IsDef
, DFG
);
472 llvm::for_each(Defs
, AnalyzeDef
);
474 // Analyze every instruction in MF
475 for (NodeAddr
<BlockNode
*> BA
: DFG
.getFunc().Addr
->members(DFG
)) {
476 for (NodeAddr
<StmtNode
*> SA
:
477 BA
.Addr
->members_if(DataFlowGraph::IsCode
<NodeAttrs::Stmt
>, DFG
)) {
478 MachineInstr
*MI
= SA
.Addr
->getCode();
// NOTE(review): lines 479-481 are missing. The `if` branch preceding this
// `else if` is not shown; it is the only visible place that could add fence
// nodes or increment FenceCount.
482 } else if (MI
->mayLoad()) {
483 NodeList Defs
= SA
.Addr
->members_if(DataFlowGraph::IsDef
, DFG
);
484 llvm::for_each(Defs
, AnalyzeDef
);
488 LLVM_DEBUG(dbgs() << "Found " << FenceCount
<< " fences\n");
489 LLVM_DEBUG(dbgs() << "Found " << GadgetCount
<< " gadgets\n");
490 if (GadgetCount
== 0)
// NOTE(review): the statement guarded by this `if` (line 491) is missing;
// runOnMachineFunction() expects a nullptr graph when there are no gadgets.
492 NumGadgets
+= GadgetCount
;
494 // Traverse CFG to build the rest of the graph
495 SmallSet
<MachineBasicBlock
*, 8> BlocksVisited
;
// TraverseCFG is a DFS over basic blocks. GI is the graph node that precedes
// MBB. Edge weights:
//   - the edge into MBB's first instruction uses the parent's loop depth;
//   - edges within MBB use MBB's own loop depth.
// A block's body is expanded only on its first visit.
496 std::function
<void(MachineBasicBlock
*, GraphIter
, unsigned)> TraverseCFG
=
497 [&](MachineBasicBlock
*MBB
, GraphIter GI
, unsigned ParentDepth
) {
498 unsigned LoopDepth
= MLI
.getLoopDepth(MBB
);
// NOTE(review): line 499 is missing. `*MBB->begin()` below is only valid
// for a non-empty block — confirm an emptiness guard exists.
500 // Always add the first instruction in each block
501 auto NI
= MBB
->begin();
502 auto BeginBB
= MaybeAddNode(&*NI
);
503 Builder
.addEdge(ParentDepth
, GI
, BeginBB
.first
);
504 if (!BlocksVisited
.insert(MBB
).second
)
// NOTE(review): the statement guarded by this `if` (line 505) and line 508
// are missing. In the visible code, GI is not advanced to BeginBB before the
// loop below.
507 // Add any instructions within the block that are gadget components
509 while (++NI
!= MBB
->end()) {
510 auto Ref
= NodeMap
.find(&*NI
);
511 if (Ref
!= NodeMap
.end()) {
512 Builder
.addEdge(LoopDepth
, GI
, Ref
->getSecond());
513 GI
= Ref
->getSecond();
517 // Always add the terminator instruction, if one exists
518 auto T
= MBB
->getFirstTerminator();
519 if (T
!= MBB
->end()) {
520 auto EndBB
= MaybeAddNode(&*T
);
// NOTE(review): line 521 (before this edge) and lines 523-525 are missing.
522 Builder
.addEdge(LoopDepth
, GI
, EndBB
.first
);
526 for (MachineBasicBlock
*Succ
: MBB
->successors())
527 TraverseCFG(Succ
, GI
, LoopDepth
);
529 // ArgNodeSentinel is a pseudo-instruction that represents MF args in the
// gadget graph (the original continuation, line 530, is missing). It is the
// root node from which TraverseCFG starts, with depth 0.
531 GraphIter ArgNode
= MaybeAddNode(MachineGadgetGraph::ArgNodeSentinel
).first
;
532 TraverseCFG(&MF
.front(), ArgNode
, 0);
533 std::unique_ptr
<MachineGadgetGraph
> G
{Builder
.get(FenceCount
, GadgetCount
)};
534 LLVM_DEBUG(dbgs() << "Found " << G
->nodes_size() << " nodes\n");
// NOTE(review): lines 535-536 (presumably returning G) are missing.
// Accumulates into ElimNodes/ElimEdges the parts of G that are already
// mitigated:
//   - if G has fences: every fence node and all of its outgoing edges;
//   - every gadget edge whose sink cannot be reached from its source by CFG
//     edges outside ElimEdges. The insertion for this case (line 585) is
//     missing from this text.
538 // Returns the number of remaining gadget edges that could not be eliminated
539 int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
540 MachineGadgetGraph
&G
, EdgeSet
&ElimEdges
/* in, out */,
541 NodeSet
&ElimNodes
/* in, out */) const {
542 if (G
.NumFences
> 0) {
543 // Eliminate fences and CFG edges that ingress and egress the fence, as
544 // they are trivially mitigated.
545 for (const Edge
&E
: G
.edges()) {
546 const Node
*Dest
= E
.getDest();
547 if (isFence(Dest
->getValue())) {
548 ElimNodes
.insert(*Dest
);
// NOTE(review): line 549 is missing. The comment above says ingress edges are
// eliminated as well, but only the fence's outgoing edges are visibly added.
550 for (const Edge
&DE
: Dest
->edges())
551 ElimEdges
.insert(DE
);
556 // Find and eliminate gadget edges that have been mitigated.
557 int RemainingGadgets
= 0;
// Reused for every gadget source; cleared before each DFS.
558 NodeSet ReachableNodes
{G
};
559 for (const Node
&RootN
: G
.nodes()) {
560 if (llvm::none_of(RootN
.edges(), MachineGadgetGraph::isGadgetEdge
))
561 continue; // skip this node if it isn't a gadget source
563 // Find all of the nodes that are CFG-reachable from RootN using DFS
564 ReachableNodes
.clear();
565 std::function
<void(const Node
*, bool)> FindReachableNodes
=
566 [&](const Node
*N
, bool FirstNode
) {
// NOTE(review): line 567 is missing, and `FirstNode` is unused in the visible
// body — confirm how the root node itself is meant to be treated.
568 ReachableNodes
.insert(*N
);
569 for (const Edge
&E
: N
->edges()) {
570 const Node
*Dest
= E
.getDest();
571 if (MachineGadgetGraph::isCFGEdge(E
) && !ElimEdges
.contains(E
) &&
572 !ReachableNodes
.contains(*Dest
))
573 FindReachableNodes(Dest
, false);
576 FindReachableNodes(&RootN
, true);
578 // Any gadget whose sink is unreachable has been mitigated
579 for (const Edge
&E
: RootN
.edges()) {
580 if (MachineGadgetGraph::isGadgetEdge(E
)) {
581 if (ReachableNodes
.contains(*E
.getDest())) {
582 // This gadget's sink is reachable
// NOTE(review): line 583 is missing; RemainingGadgets is never incremented
// in the visible code.
584 } else { // This gadget's sink is unreachable, and therefore mitigated
// NOTE(review): the body of this else branch (line 585) and the closing
// braces (586-589) are missing.
590 return RemainingGadgets
;
// Removes the mitigated parts of Graph found by elimMitigatedEdgesAndNodes()
// and returns the resulting graph. When nothing can be eliminated, the graph
// is updated in place: NumFences is reset to 0 and NumGadgets is set to the
// number of remaining gadgets.
// NOTE(review): numbering skips 603 and 605-608. As shown:
//   - the in-place update is not separated from the GraphBuilder::trim call
//     by an `else`;
//   - the trim call's final argument(s) are missing;
//   - there is no return statement.
593 std::unique_ptr
<MachineGadgetGraph
>
594 X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges(
595 std::unique_ptr
<MachineGadgetGraph
> Graph
) const {
596 NodeSet ElimNodes
{*Graph
};
597 EdgeSet ElimEdges
{*Graph
};
598 int RemainingGadgets
=
599 elimMitigatedEdgesAndNodes(*Graph
, ElimEdges
, ElimNodes
);
// Nothing to eliminate: avoid rebuilding the graph and just update counts.
600 if (ElimEdges
.empty() && ElimNodes
.empty()) {
601 Graph
->NumFences
= 0;
602 Graph
->NumGadgets
= RemainingGadgets
;
604 Graph
= GraphBuilder::trim(*Graph
, ElimNodes
, ElimEdges
, 0 /* NumFences */,
// Uses the external `optimize_cut` plugin to choose which CFG edges to cut,
// then inserts LFENCEs on those edges. Returns the number of LFENCEs inserted.
// The graph is passed to the plugin in compressed-sparse-row form:
//   - Nodes[i] is the index of node i's first edge, plus a terminator entry
//     equal to the edge count;
//   - Edges[e] is the destination node index of edge e;
//   - EdgeValues[e] is the value of edge e;
//   - the plugin fills EdgeCuts.
// NOTE(review): numbering skips 613-614, 619, 639-640 and 650-651. Missing
// from this text:
//   - the statement guarded by the NumGadgets check;
//   - the body of the EdgeCuts copy-back loop (EdgeCuts is never read in the
//     visible code);
//   - whatever consumes the re-trimmed graph at the end.
// The trailing `Graph = GraphBuilder::trim(...)` suggests an enclosing loop
// (matching "go to step 2" in the file header) — confirm.
610 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
611 MachineFunction
&MF
, std::unique_ptr
<MachineGadgetGraph
> Graph
) const {
612 int FencesInserted
= 0;
615 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
616 Graph
= trimMitigatedEdges(std::move(Graph
));
617 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
618 if (Graph
->NumGadgets
== 0)
621 LLVM_DEBUG(dbgs() << "Cutting edges...\n");
622 EdgeSet CutEdges
{*Graph
};
623 auto Nodes
= std::make_unique
<unsigned int[]>(Graph
->nodes_size() +
624 1 /* terminator node */);
625 auto Edges
= std::make_unique
<unsigned int[]>(Graph
->edges_size());
626 auto EdgeCuts
= std::make_unique
<int[]>(Graph
->edges_size());
627 auto EdgeValues
= std::make_unique
<int[]>(Graph
->edges_size());
// NOTE(review): if a node has no out-edges, *N.edges_begin() refers to
// whatever follows in the edge array (possibly one past the end); confirm this
// is well-defined for ImmutableGraph.
628 for (const Node
&N
: Graph
->nodes()) {
629 Nodes
[Graph
->getNodeIndex(N
)] = Graph
->getEdgeIndex(*N
.edges_begin());
631 Nodes
[Graph
->nodes_size()] = Graph
->edges_size(); // terminator node
632 for (const Edge
&E
: Graph
->edges()) {
633 Edges
[Graph
->getEdgeIndex(E
)] = Graph
->getNodeIndex(*E
.getDest());
634 EdgeValues
[Graph
->getEdgeIndex(E
)] = E
.getValue();
636 OptimizeCut(Nodes
.get(), Graph
->nodes_size(), Edges
.get(), EdgeValues
.get(),
637 EdgeCuts
.get(), Graph
->edges_size());
// NOTE(review): `int I` is compared against the graph's size_type edge count;
// consider an unsigned/size_type index. The loop body (639-640) is missing.
638 for (int I
= 0; I
< Graph
->edges_size(); ++I
)
641 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
642 LLVM_DEBUG(dbgs() << "Cut " << CutEdges
.count() << " edges\n");
644 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n");
645 FencesInserted
+= insertFences(MF
, *Graph
, CutEdges
);
646 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
647 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted
<< " fences\n");
// Drop the cut edges so that a further trim pass sees the mitigated graph.
649 Graph
= GraphBuilder::trim(*Graph
, NodeSet
{*Graph
}, CutEdges
);
652 return FencesInserted
;
// Greedy alternative to the plugin. For each gadget edge it compares two
// candidate cuts, counting only CFG edges not already cut:
//   - all ingress CFG edges of the sink;
//   - all egress CFG edges of the source.
// The ingress cut is chosen only when its total edge value (loop depth) is
// strictly smaller; ties go to the egress cut. LFENCEs are then inserted on
// all cut edges. Returns the number of LFENCEs inserted.
// NOTE(review): numbering skips 663, 666, 688, 703-704 and 707-709. Missing
// from this text:
//   - the closing brace of the NumFences block;
//   - the body of the NumGadgets check;
//   - the `continue` for non-gadget edges;
//   - the declaration of EdgesToCut;
//   - the body of the final loop (no insertion into CutEdges is visible).
655 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic(
656 MachineFunction
&MF
, std::unique_ptr
<MachineGadgetGraph
> Graph
) const {
657 // If `MF` does not have any fences, then no gadgets would have been
658 // mitigated at this point.
659 if (Graph
->NumFences
> 0) {
660 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
661 Graph
= trimMitigatedEdges(std::move(Graph
));
662 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
665 if (Graph
->NumGadgets
== 0)
668 LLVM_DEBUG(dbgs() << "Cutting edges...\n");
669 EdgeSet CutEdges
{*Graph
};
671 // Begin by collecting all ingress CFG edges for each node
672 DenseMap
<const Node
*, SmallVector
<const Edge
*, 2>> IngressEdgeMap
;
673 for (const Edge
&E
: Graph
->edges())
674 if (MachineGadgetGraph::isCFGEdge(E
))
675 IngressEdgeMap
[E
.getDest()].push_back(&E
);
677 // For each gadget edge, make cuts that guarantee the gadget will be
678 // mitigated. A computationally efficient way to achieve this is to either:
679 // (a) cut all egress CFG edges from the gadget source, or
680 // (b) cut all ingress CFG edges to the gadget sink.
682 // Moreover, the algorithm tries not to make a cut into a loop by preferring
683 // to make a (b)-type cut if the gadget source resides at a greater loop depth
684 // than the gadget sink, or an (a)-type cut otherwise.
685 for (const Node
&N
: Graph
->nodes()) {
686 for (const Edge
&E
: N
.edges()) {
687 if (!MachineGadgetGraph::isGadgetEdge(E
))
690 SmallVector
<const Edge
*, 2> EgressEdges
;
691 SmallVector
<const Edge
*, 2> &IngressEdges
= IngressEdgeMap
[E
.getDest()];
692 for (const Edge
&EgressEdge
: N
.edges())
693 if (MachineGadgetGraph::isCFGEdge(EgressEdge
))
694 EgressEdges
.push_back(&EgressEdge
);
// Edges already in CutEdges cost nothing to "cut" again.
696 int EgressCutCost
= 0, IngressCutCost
= 0;
697 for (const Edge
*EgressEdge
: EgressEdges
)
698 if (!CutEdges
.contains(*EgressEdge
))
699 EgressCutCost
+= EgressEdge
->getValue();
700 for (const Edge
*IngressEdge
: IngressEdges
)
701 if (!CutEdges
.contains(*IngressEdge
))
702 IngressCutCost
+= IngressEdge
->getValue();
705 IngressCutCost
< EgressCutCost
? IngressEdges
: EgressEdges
;
706 for (const Edge
*E
: EdgesToCut
)
710 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
711 LLVM_DEBUG(dbgs() << "Cut " << CutEdges
.count() << " edges\n");
713 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n");
714 int FencesInserted
= insertFences(MF
, *Graph
, CutEdges
);
715 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
716 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted
<< " fences\n");
718 return FencesInserted
;
// Inserts an LFENCE for every edge in CutEdges. The insertion point depends on
// the edge's source node N:
//   - ArgNodeSentinel: at the beginning of the entry block;
//   - a branch: before the branch. Per the comment below, all of the branch's
//     egress CFG edges are then also treated as cut;
//   - anything else: immediately after the instruction.
// No LFENCE is inserted when a fence already sits immediately before or at the
// insertion point. Returns the number of LFENCEs inserted.
721 int X86LoadValueInjectionLoadHardeningPass::insertFences(
722 MachineFunction
&MF
, MachineGadgetGraph
&G
,
723 EdgeSet
&CutEdges
/* in, out */) const {
724 int FencesInserted
= 0;
725 for (const Node
&N
: G
.nodes()) {
726 for (const Edge
&E
: N
.edges()) {
727 if (CutEdges
.contains(E
)) {
728 MachineInstr
*MI
= N
.getValue(), *Prev
;
729 MachineBasicBlock
*MBB
; // Insert an LFENCE in this MBB
730 MachineBasicBlock::iterator InsertionPt
; // ...at this point
731 if (MI
== MachineGadgetGraph::ArgNodeSentinel
) {
732 // insert LFENCE at beginning of entry block
// NOTE(review): lines 733 and 735 are missing. MBB (dereferenced below) and
// Prev are not assigned on this path in the visible code.
734 InsertionPt
= MBB
->begin();
736 } else if (MI
->isBranch()) { // insert the LFENCE before the branch
737 MBB
= MI
->getParent();
// NOTE(review): line 738 (setting InsertionPt for the branch case) is missing.
739 Prev
= MI
->getPrevNode();
740 // Remove all egress CFG edges from this branch because the inserted
741 // LFENCE prevents gadgets from crossing the branch.
// NOTE(review): this inner `E` shadows the outer loop's `E`. The body of the
// `if` below and the closing brace (lines 744-745) are missing.
742 for (const Edge
&E
: N
.edges()) {
743 if (MachineGadgetGraph::isCFGEdge(E
))
746 } else { // insert the LFENCE after the instruction
747 MBB
= MI
->getParent();
748 InsertionPt
= MI
->getNextNode() ? MI
->getNextNode() : MBB
->end();
749 Prev
= InsertionPt
== MBB
->end()
750 ? (MBB
->empty() ? nullptr : &MBB
->back())
751 : InsertionPt
->getPrevNode();
753 // Ensure this insertion is not redundant (two LFENCEs in sequence).
754 if ((InsertionPt
== MBB
->end() || !isFence(&*InsertionPt
)) &&
755 (!Prev
|| !isFence(Prev
))) {
756 BuildMI(*MBB
, InsertionPt
, DebugLoc(), TII
->get(X86::LFENCE
));
// NOTE(review): lines 757-761 are missing; FencesInserted is never
// incremented in the visible code.
762 return FencesInserted
;
// Returns true if Reg overlaps (per TRI->regsOverlap) the base or index
// register of MI's memory operand, i.e. Reg contributes to a memory address.
// The early check excludes instructions that neither load nor store, and
// MFENCE/SFENCE/LFENCE.
765 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToAccessMemory(
766 const MachineInstr
&MI
, unsigned Reg
) const {
767 if (!MI
.mayLoadOrStore() || MI
.getOpcode() == X86::MFENCE
||
768 MI
.getOpcode() == X86::SFENCE
|| MI
.getOpcode() == X86::LFENCE
)
// NOTE(review): the statement guarded by this check (line 769) is missing.
771 const int MemRefBeginIdx
= X86::getFirstAddrOperandIdx(MI
);
772 if (MemRefBeginIdx
< 0) {
773 LLVM_DEBUG(dbgs() << "Warning: unable to obtain memory operand for loading "
// NOTE(review): the rest of this debug message (line 774) is missing.
775 MI
.print(dbgs()); dbgs() << '\n';);
// NOTE(review): lines 776-778 are missing. After the warning there is
// presumably an early return; otherwise the negative MemRefBeginIdx would be
// used as an operand index below — confirm.
779 const MachineOperand
&BaseMO
=
780 MI
.getOperand(MemRefBeginIdx
+ X86::AddrBaseReg
);
781 const MachineOperand
&IndexMO
=
782 MI
.getOperand(MemRefBeginIdx
+ X86::AddrIndexReg
);
783 return (BaseMO
.isReg() && BaseMO
.getReg() != X86::NoRegister
&&
784 TRI
->regsOverlap(BaseMO
.getReg(), Reg
)) ||
785 (IndexMO
.isReg() && IndexMO
.getReg() != X86::NoRegister
&&
786 TRI
->regsOverlap(IndexMO
.getReg(), Reg
));
// Intended to report whether MI is a conditional branch with a register use of
// Reg.
// NOTE(review): the returns on each path and the closing brace (lines 792 and
// 795-797) are missing from this text.
// NOTE(review): instrUsesRegToAccessMemory() uses TRI->regsOverlap(), but this
// compares registers for exact equality. A use of a sub- or super-register of
// Reg would therefore not match — confirm this is intended.
789 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToBranch(
790 const MachineInstr
&MI
, unsigned Reg
) const {
791 if (!MI
.isConditionalBranch())
793 for (const MachineOperand
&Use
: MI
.uses())
794 if (Use
.isReg() && Use
.getReg() == Reg
)
799 INITIALIZE_PASS_BEGIN(X86LoadValueInjectionLoadHardeningPass
, PASS_KEY
,
800 "X86 LVI load hardening", false, false)
801 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass
)
802 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass
)
803 INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier
)
804 INITIALIZE_PASS_END(X86LoadValueInjectionLoadHardeningPass
, PASS_KEY
,
805 "X86 LVI load hardening", false, false)
807 FunctionPass
*llvm::createX86LoadValueInjectionLoadHardeningPass() {
808 return new X86LoadValueInjectionLoadHardeningPass();