4242#include " X86TargetMachine.h"
4343#include " llvm/ADT/DenseMap.h"
4444#include " llvm/ADT/DenseSet.h"
45+ #include " llvm/ADT/STLExtras.h"
4546#include " llvm/ADT/SmallSet.h"
4647#include " llvm/ADT/Statistic.h"
4748#include " llvm/ADT/StringRef.h"
@@ -104,9 +105,9 @@ static cl::opt<bool> EmitDotVerify(
104105 cl::init(false ), cl::Hidden);
105106
106107static llvm::sys::DynamicLibrary OptimizeDL;
107- typedef int (*OptimizeCutT)(unsigned int *nodes , unsigned int nodes_size ,
108- unsigned int *edges , int *edge_values ,
109- int *cut_edges /* out */ , unsigned int edges_size );
108+ typedef int (*OptimizeCutT)(unsigned int *Nodes , unsigned int NodesSize ,
109+ unsigned int *Edges , int *EdgeValues ,
110+ int *CutEdges /* out */ , unsigned int EdgesSize );
110111static OptimizeCutT OptimizeCut = nullptr ;
111112
112113namespace {
@@ -148,9 +149,10 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
148149
149150private:
150151 using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
152+ using Edge = MachineGadgetGraph::Edge;
153+ using Node = MachineGadgetGraph::Node;
151154 using EdgeSet = MachineGadgetGraph::EdgeSet;
152155 using NodeSet = MachineGadgetGraph::NodeSet;
153- using Gadget = std::pair<MachineInstr *, MachineInstr *>;
154156
155157 const X86Subtarget *STI;
156158 const TargetInstrInfo *TII;
@@ -162,8 +164,8 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
162164 const MachineDominanceFrontier &MDF) const ;
163165 int hardenLoadsWithPlugin (MachineFunction &MF,
164166 std::unique_ptr<MachineGadgetGraph> Graph) const ;
165- int hardenLoadsWithGreedyHeuristic (
166- MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const ;
167+ int hardenLoadsWithHeuristic (MachineFunction &MF,
168+ std::unique_ptr<MachineGadgetGraph> Graph) const ;
167169 int elimMitigatedEdgesAndNodes (MachineGadgetGraph &G,
168170 EdgeSet &ElimEdges /* in, out */ ,
169171 NodeSet &ElimNodes /* in, out */ ) const ;
@@ -198,7 +200,7 @@ struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
198200 using ChildIteratorType = typename Traits::ChildIteratorType;
199201 using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;
200202
201- DOTGraphTraits (bool isSimple = false ) : DefaultDOTGraphTraits(isSimple ) {}
203+ DOTGraphTraits (bool IsSimple = false ) : DefaultDOTGraphTraits(IsSimple ) {}
202204
203205 std::string getNodeLabel (NodeRef Node, GraphType *) {
204206 if (Node->getValue () == MachineGadgetGraph::ArgNodeSentinel)
@@ -243,7 +245,7 @@ void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
243245 AU.setPreservesCFG ();
244246}
245247
246- static void WriteGadgetGraph (raw_ostream &OS, MachineFunction &MF,
248+ static void writeGadgetGraph (raw_ostream &OS, MachineFunction &MF,
247249 MachineGadgetGraph *G) {
248250 WriteGraph (OS, G, /* ShortNames*/ false ,
249251 " Speculative gadgets for \" " + MF.getName () + " \" function" );
@@ -279,7 +281,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
279281 return false ; // didn't find any gadgets
280282
281283 if (EmitDotVerify) {
282- WriteGadgetGraph (outs (), MF, Graph.get ());
284+ writeGadgetGraph (outs (), MF, Graph.get ());
283285 return false ;
284286 }
285287
@@ -292,7 +294,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
292294 raw_fd_ostream FileOut (FileName, FileError);
293295 if (FileError)
294296 errs () << FileError.message ();
295- WriteGadgetGraph (FileOut, MF, Graph.get ());
297+ writeGadgetGraph (FileOut, MF, Graph.get ());
296298 FileOut.close ();
297299 LLVM_DEBUG (dbgs () << " Emitting gadget graph... Done\n " );
298300 if (EmitDotOnly)
@@ -313,7 +315,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
313315 }
314316 FencesInserted = hardenLoadsWithPlugin (MF, std::move (Graph));
315317 } else { // Use the default greedy heuristic
316- FencesInserted = hardenLoadsWithGreedyHeuristic (MF, std::move (Graph));
318+ FencesInserted = hardenLoadsWithHeuristic (MF, std::move (Graph));
317319 }
318320
319321 if (FencesInserted > 0 )
@@ -540,47 +542,46 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
540542
541543// Returns the number of remaining gadget edges that could not be eliminated
542544int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes (
543- MachineGadgetGraph &G, MachineGadgetGraph:: EdgeSet &ElimEdges /* in, out */ ,
544- MachineGadgetGraph:: NodeSet &ElimNodes /* in, out */ ) const {
545+ MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */ ,
546+ NodeSet &ElimNodes /* in, out */ ) const {
545547 if (G.NumFences > 0 ) {
546548 // Eliminate fences and CFG edges that ingress and egress the fence, as
547549 // they are trivially mitigated.
548- for (const auto &E : G.edges ()) {
549- const MachineGadgetGraph:: Node *Dest = E.getDest ();
550+ for (const Edge &E : G.edges ()) {
551+ const Node *Dest = E.getDest ();
550552 if (isFence (Dest->getValue ())) {
551553 ElimNodes.insert (*Dest);
552554 ElimEdges.insert (E);
553- for (const auto &DE : Dest->edges ())
555+ for (const Edge &DE : Dest->edges ())
554556 ElimEdges.insert (DE);
555557 }
556558 }
557559 }
558560
559561 // Find and eliminate gadget edges that have been mitigated.
560562 int MitigatedGadgets = 0 , RemainingGadgets = 0 ;
561- MachineGadgetGraph:: NodeSet ReachableNodes{G};
562- for (const auto &RootN : G.nodes ()) {
563+ NodeSet ReachableNodes{G};
564+ for (const Node &RootN : G.nodes ()) {
563565 if (llvm::none_of (RootN.edges (), MachineGadgetGraph::isGadgetEdge))
564566 continue ; // skip this node if it isn't a gadget source
565567
566568 // Find all of the nodes that are CFG-reachable from RootN using DFS
567569 ReachableNodes.clear ();
568- std::function<void (const MachineGadgetGraph::Node *, bool )>
569- FindReachableNodes =
570- [&](const MachineGadgetGraph::Node *N, bool FirstNode) {
571- if (!FirstNode)
572- ReachableNodes.insert (*N);
573- for (const auto &E : N->edges ()) {
574- const MachineGadgetGraph::Node *Dest = E.getDest ();
575- if (MachineGadgetGraph::isCFGEdge (E) &&
576- !ElimEdges.contains (E) && !ReachableNodes.contains (*Dest))
577- FindReachableNodes (Dest, false );
578- }
579- };
570+ std::function<void (const Node *, bool )> FindReachableNodes =
571+ [&](const Node *N, bool FirstNode) {
572+ if (!FirstNode)
573+ ReachableNodes.insert (*N);
574+ for (const Edge &E : N->edges ()) {
575+ const Node *Dest = E.getDest ();
576+ if (MachineGadgetGraph::isCFGEdge (E) && !ElimEdges.contains (E) &&
577+ !ReachableNodes.contains (*Dest))
578+ FindReachableNodes (Dest, false );
579+ }
580+ };
580581 FindReachableNodes (&RootN, true );
581582
582583 // Any gadget whose sink is unreachable has been mitigated
583- for (const auto &E : RootN.edges ()) {
584+ for (const Edge &E : RootN.edges ()) {
584585 if (MachineGadgetGraph::isGadgetEdge (E)) {
585586 if (ReachableNodes.contains (*E.getDest ())) {
586587 // This gadget's sink is reachable
@@ -598,8 +599,8 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
598599std::unique_ptr<MachineGadgetGraph>
599600X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges (
600601 std::unique_ptr<MachineGadgetGraph> Graph) const {
601- MachineGadgetGraph:: NodeSet ElimNodes{*Graph};
602- MachineGadgetGraph:: EdgeSet ElimEdges{*Graph};
602+ NodeSet ElimNodes{*Graph};
603+ EdgeSet ElimEdges{*Graph};
603604 int RemainingGadgets =
604605 elimMitigatedEdgesAndNodes (*Graph, ElimEdges, ElimNodes);
605606 if (ElimEdges.empty () && ElimNodes.empty ()) {
@@ -630,11 +631,11 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
630631 auto Edges = std::make_unique<unsigned int []>(Graph->edges_size ());
631632 auto EdgeCuts = std::make_unique<int []>(Graph->edges_size ());
632633 auto EdgeValues = std::make_unique<int []>(Graph->edges_size ());
633- for (const auto &N : Graph->nodes ()) {
634+ for (const Node &N : Graph->nodes ()) {
634635 Nodes[Graph->getNodeIndex (N)] = Graph->getEdgeIndex (*N.edges_begin ());
635636 }
636637 Nodes[Graph->nodes_size ()] = Graph->edges_size (); // terminator node
637- for (const auto &E : Graph->edges ()) {
638+ for (const Edge &E : Graph->edges ()) {
638639 Edges[Graph->getEdgeIndex (E)] = Graph->getNodeIndex (*E.getDest ());
639640 EdgeValues[Graph->getEdgeIndex (E)] = E.getValue ();
640641 }
@@ -651,74 +652,67 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
651652 LLVM_DEBUG (dbgs () << " Inserting LFENCEs... Done\n " );
652653 LLVM_DEBUG (dbgs () << " Inserted " << FencesInserted << " fences\n " );
653654
654- Graph = GraphBuilder::trim (*Graph, MachineGadgetGraph::NodeSet{*Graph},
655- CutEdges);
655+ Graph = GraphBuilder::trim (*Graph, NodeSet{*Graph}, CutEdges);
656656 } while (true );
657657
658658 return FencesInserted;
659659}
660660
661- int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic (
661+ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic (
662662 MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
663- LLVM_DEBUG (dbgs () << " Eliminating mitigated paths...\n " );
664- Graph = trimMitigatedEdges (std::move (Graph));
665- LLVM_DEBUG (dbgs () << " Eliminating mitigated paths... Done\n " );
663+ // If `MF` does not have any fences, then no gadgets would have been
664+ // mitigated at this point.
665+ if (Graph->NumFences > 0 ) {
666+ LLVM_DEBUG (dbgs () << " Eliminating mitigated paths...\n " );
667+ Graph = trimMitigatedEdges (std::move (Graph));
668+ LLVM_DEBUG (dbgs () << " Eliminating mitigated paths... Done\n " );
669+ }
670+
666671 if (Graph->NumGadgets == 0 )
667672 return 0 ;
668673
669674 LLVM_DEBUG (dbgs () << " Cutting edges...\n " );
670- MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph};
671- MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph};
672- auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) {
673- return !ElimEdges.contains (E) && !CutEdges.contains (E) &&
674- MachineGadgetGraph::isCFGEdge (E);
675- };
676- auto IsGadgetEdge = [&ElimEdges,
677- &CutEdges](const MachineGadgetGraph::Edge &E) {
678- return !ElimEdges.contains (E) && !CutEdges.contains (E) &&
679- MachineGadgetGraph::isGadgetEdge (E);
680- };
681-
682- // FIXME: this is O(E^2), we could probably do better.
683- do {
684- // Find the cheapest CFG edge that will eliminate a gadget (by being
685- // egress from a SOURCE node or ingress to a SINK node), and cut it.
686- const MachineGadgetGraph::Edge *CheapestSoFar = nullptr ;
687-
688- // First, collect all gadget source and sink nodes.
689- MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph};
690- for (const auto &N : Graph->nodes ()) {
691- if (ElimNodes.contains (N))
675+ EdgeSet CutEdges{*Graph};
676+
677+ // Begin by collecting all ingress CFG edges for each node
678+ DenseMap<const Node *, SmallVector<const Edge *, 2 >> IngressEdgeMap;
679+ for (const Edge &E : Graph->edges ())
680+ if (MachineGadgetGraph::isCFGEdge (E))
681+ IngressEdgeMap[E.getDest ()].push_back (&E);
682+
683+ // For each gadget edge, make cuts that guarantee the gadget will be
684+ // mitigated. A computationally efficient way to achieve this is to either:
685+ // (a) cut all egress CFG edges from the gadget source, or
686+ // (b) cut all ingress CFG edges to the gadget sink.
687+ //
688+ // Moreover, the algorithm tries not to make a cut into a loop by preferring
689+ // to make a (b)-type cut if the gadget source resides at a greater loop depth
690+ // than the gadget sink, or an (a)-type cut otherwise.
691+ for (const Node &N : Graph->nodes ()) {
692+ for (const Edge &E : N.edges ()) {
693+ if (!MachineGadgetGraph::isGadgetEdge (E))
692694 continue ;
693- for (const auto &E : N.edges ()) {
694- if (IsGadgetEdge (E)) {
695- GadgetSources.insert (N);
696- GadgetSinks.insert (*E.getDest ());
697- }
698- }
699- }
700695
701- // Next, look for the cheapest CFG edge which, when cut, is guaranteed to
702- // mitigate at least one gadget by either:
703- // (a) being egress from a gadget source, or
704- // (b) being ingress to a gadget sink.
705- for (const auto &N : Graph->nodes ()) {
706- if (ElimNodes.contains (N))
707- continue ;
708- for (const auto &E : N.edges ()) {
709- if (IsCFGEdge (E)) {
710- if (GadgetSources.contains (N) || GadgetSinks.contains (*E.getDest ())) {
711- if (!CheapestSoFar || E.getValue () < CheapestSoFar->getValue ())
712- CheapestSoFar = &E;
713- }
714- }
715- }
696+ SmallVector<const Edge *, 2 > EgressEdges;
697+ SmallVector<const Edge *, 2 > &IngressEdges = IngressEdgeMap[E.getDest ()];
698+ for (const Edge &EgressEdge : N.edges ())
699+ if (MachineGadgetGraph::isCFGEdge (EgressEdge))
700+ EgressEdges.push_back (&EgressEdge);
701+
702+ int EgressCutCost = 0 , IngressCutCost = 0 ;
703+ for (const Edge *EgressEdge : EgressEdges)
704+ if (!CutEdges.contains (*EgressEdge))
705+ EgressCutCost += EgressEdge->getValue ();
706+ for (const Edge *IngressEdge : IngressEdges)
707+ if (!CutEdges.contains (*IngressEdge))
708+ IngressCutCost += IngressEdge->getValue ();
709+
710+ auto &EdgesToCut =
711+ IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges;
712+ for (const Edge *E : EdgesToCut)
713+ CutEdges.insert (*E);
716714 }
717-
718- assert (CheapestSoFar && " Failed to cut an edge" );
719- CutEdges.insert (*CheapestSoFar);
720- ElimEdges.insert (*CheapestSoFar);
721- } while (elimMitigatedEdgesAndNodes (*Graph, ElimEdges, ElimNodes));
715+ }
722716 LLVM_DEBUG (dbgs () << " Cutting edges... Done\n " );
723717 LLVM_DEBUG (dbgs () << " Cut " << CutEdges.count () << " edges\n " );
724718
@@ -734,8 +728,8 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
734728 MachineFunction &MF, MachineGadgetGraph &G,
735729 EdgeSet &CutEdges /* in, out */ ) const {
736730 int FencesInserted = 0 ;
737- for (const auto &N : G.nodes ()) {
738- for (const auto &E : N.edges ()) {
731+ for (const Node &N : G.nodes ()) {
732+ for (const Edge &E : N.edges ()) {
739733 if (CutEdges.contains (E)) {
740734 MachineInstr *MI = N.getValue (), *Prev;
741735 MachineBasicBlock *MBB; // Insert an LFENCE in this MBB
@@ -751,7 +745,7 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
751745 Prev = MI->getPrevNode ();
752746 // Remove all egress CFG edges from this branch because the inserted
753747 // LFENCE prevents gadgets from crossing the branch.
754- for (const auto &E : N.edges ()) {
748+ for (const Edge &E : N.edges ()) {
755749 if (MachineGadgetGraph::isCFGEdge (E))
756750 CutEdges.insert (E);
757751 }
0 commit comments