8 #include <boost/format.hpp>
9 #include <boost/log/trivial.hpp>
10 #include <boost/graph/breadth_first_search.hpp>
13 #include "instruction.h"
14 #include "checkpoint.h"
18 /*================================================================================================*/
20 extern bool in_tainting
;
22 extern std::map
<ADDRINT
, instruction
> addr_ins_static_map
;
23 extern std::map
<UINT32
, instruction
> order_ins_dynamic_map
;
25 extern UINT32 total_rollback_times
;
27 extern vdep_graph dta_graph
;
28 extern map_ins_io dta_inss_io
;
30 extern std::vector
<ADDRINT
> explored_trace
;
32 extern std::map
<UINT32
, ptr_branch
> order_input_dep_ptr_branch_map
;
33 extern std::map
<UINT32
, ptr_branch
> order_input_indep_ptr_branch_map
;
34 extern std::map
<UINT32
, ptr_branch
> order_tainted_ptr_branch_map
;
36 extern ptr_branch exploring_ptr_branch
;
38 extern std::vector
<ptr_branch
> total_input_dep_ptr_branches
;
40 // extern UINT32 input_dep_branch_num;
42 extern std::vector
<ptr_checkpoint
> saved_ptr_checkpoints
;
43 extern ptr_checkpoint master_ptr_checkpoint
;
45 extern std::map
< UINT32
,
46 std::vector
<ptr_checkpoint
> > exepoint_checkpoints_map
;
48 extern UINT8 received_msg_num
;
49 extern ADDRINT received_msg_addr
;
50 extern UINT32 received_msg_size
;
51 extern ADDRINT received_msg_struct_addr
;
53 extern KNOB
<BOOL
> print_debug_text
;
54 extern KNOB
<UINT32
> max_trace_length
;
56 extern UINT32 max_trace_size
;
58 /*================================================================================================*/
60 static std::map
<vdep_vertex_desc
,
61 vdep_vertex_desc
> prec_vertex_desc
;
63 static bool function_has_been_called
= false;
65 /*================================================================================================*/
67 class dep_bfs_visitor
: public boost::default_bfs_visitor
70 template <typename Edge
, typename Graph
>
71 void tree_edge(Edge e
, const Graph
& g
)
73 prec_vertex_desc
[boost::target(e
, g
)] = boost::source(e
, g
);
77 /*================================================================================================*/
79 inline void mark_resolved(ptr_branch
& omitted_ptr_branch
)
81 omitted_ptr_branch
->is_resolved
= true;
82 omitted_ptr_branch
->is_bypassed
= false;
83 omitted_ptr_branch
->is_just_resolved
= false;
88 /*================================================================================================*/
90 std::vector
<UINT32
> backward_trace(vdep_vertex_desc root_vertex
, vdep_vertex_desc last_vertex
)
92 std::vector
<UINT32
> backward_trace
;
94 vdep_vertex_desc current_vertex
= last_vertex
;
95 vdep_vertex_desc backward_vertex
;
97 vdep_edge_desc current_edge
;
100 while (current_vertex
!= root_vertex
)
102 backward_vertex
= prec_vertex_desc
[current_vertex
];
104 boost::tie(current_edge
, edge_exist
) = boost::edge(backward_vertex
,
105 current_vertex
, dta_graph
);
108 backward_trace
.push_back(dta_graph
[current_edge
].second
);
112 BOOST_LOG_TRIVIAL(fatal
) << "Edge not found in backward trace construction.";
113 PIN_ExitApplication(0);
116 current_vertex
= backward_vertex
;
119 return backward_trace
;
122 /*================================================================================================*/
// Computes, for every tainted branch, the memory addresses its condition is linked to.
//
// First pass: for each MEM_VAR vertex of dta_graph, a BFS is run from that vertex, which
// fills prec_vertex_desc through dep_bfs_visitor. For each recorded tree edge
// (prec_vertex_iter->second -> prec_vertex_iter->first), every branch of
// order_tainted_ptr_branch_map whose trace.size() equals the edge property's `.second`
// receives the vertex's `.mem` address:
//   - in dep_input_addrs when it lies in [received_msg_addr, received_msg_addr + received_msg_size),
//   - otherwise in dep_other_addrs,
// and dep_backward_traces[address] is set to backward_trace(MEM_VAR vertex, edge source).
// A missing edge is logged as fatal and the application exits.
//
// Second pass: tainted branches with a non-empty dep_input_addrs go to
// order_input_dep_ptr_branch_map and are appended to total_input_dep_ptr_branches. When
// exploring_ptr_branch is set, only branches whose trace is longer than its trace are
// appended. All other branches go to order_input_indep_ptr_branch_map. Both maps are
// keyed by trace.size().
//
// NOTE(review): the inner scan over order_tainted_ptr_branch_map runs once per BFS tree edge.
// That map is filled with key explored_trace.size() (logging_cond_br_analyzer). If that key
// always equals the branch's trace.size() (not visible here — verify), a single find() would
// replace the scan.
124 inline void compute_branch_mem_dependency()
126 vdep_vertex_iter vertex_iter
;
127 vdep_vertex_iter last_vertex_iter
;
129 vdep_edge_desc edge_desc
;
132 dep_bfs_visitor dep_vis
;
134 std::map
<UINT32
, ptr_branch
>::iterator order_ptr_branch_iter
;
135 ptr_branch current_ptr_branch
;
137 std::map
<vdep_vertex_desc
, vdep_vertex_desc
>::iterator prec_vertex_iter
;
139 ADDRINT current_addr
;
// First pass: one BFS per MEM_VAR vertex of dta_graph.
141 boost::tie(vertex_iter
, last_vertex_iter
) = boost::vertices(dta_graph
);
142 for (; vertex_iter
!= last_vertex_iter
; ++vertex_iter
)
144 if (dta_graph
[*vertex_iter
].type
== MEM_VAR
)
146 // std::map<vdep_vertex_desc, vdep_vertex_desc>().swap(prec_vertex_desc);
// Reset the predecessor map so it only describes the BFS rooted at *vertex_iter.
147 prec_vertex_desc
.clear();
149 boost::breadth_first_search(dta_graph
, *vertex_iter
, boost::visitor(dep_vis
));
// Each entry (v -> u) of prec_vertex_desc stands for the BFS tree edge u -> v.
151 for (prec_vertex_iter
= prec_vertex_desc
.begin();
152 prec_vertex_iter
!= prec_vertex_desc
.end(); ++prec_vertex_iter
)
154 boost::tie(edge_desc
, edge_exist
) = boost::edge(prec_vertex_iter
->second
,
155 prec_vertex_iter
->first
,
// Attach the MEM_VAR address to every tainted branch whose trace.size() equals the
// edge property's `.second`.
159 order_ptr_branch_iter
= order_tainted_ptr_branch_map
.begin();
160 for (; order_ptr_branch_iter
!= order_tainted_ptr_branch_map
.end();
161 ++order_ptr_branch_iter
)
163 current_ptr_branch
= order_ptr_branch_iter
->second
;
164 if (dta_graph
[edge_desc
].second
== current_ptr_branch
->trace
.size())
166 current_addr
= dta_graph
[*vertex_iter
].mem
;
// Addresses inside the received message buffer are recorded as input dependencies,
// all others as "other" dependencies.
168 if ((received_msg_addr
<= current_addr
) &&
169 (current_addr
< received_msg_addr
+ received_msg_size
))
171 current_ptr_branch
->dep_input_addrs
.insert(current_addr
);
175 current_ptr_branch
->dep_other_addrs
.insert(current_addr
);
// Keep the path from the BFS root (the MEM_VAR vertex) to the edge's source vertex.
178 current_ptr_branch
->dep_backward_traces
[current_addr
]
179 = backward_trace(*vertex_iter
, prec_vertex_iter
->second
);
// Reached when the predecessor edge cannot be found in dta_graph.
185 BOOST_LOG_TRIVIAL(fatal
) << "Backward edge not found in BFS.";
186 PIN_ExitApplication(0);
// Second pass: split the tainted branches into input-dependent / input-independent maps.
192 order_ptr_branch_iter
= order_tainted_ptr_branch_map
.begin();
193 for (; order_ptr_branch_iter
!= order_tainted_ptr_branch_map
.end();
194 ++order_ptr_branch_iter
)
196 current_ptr_branch
= order_ptr_branch_iter
->second
;
197 if (!current_ptr_branch
->dep_input_addrs
.empty())
199 order_input_dep_ptr_branch_map
[current_ptr_branch
->trace
.size()]
200 = current_ptr_branch
;
// While a branch is being explored, only branches with a longer trace are collected.
202 if (exploring_ptr_branch
)
204 if (current_ptr_branch
->trace
.size() > exploring_ptr_branch
->trace
.size())
206 total_input_dep_ptr_branches
.push_back(current_ptr_branch
);
211 total_input_dep_ptr_branches
.push_back(current_ptr_branch
);
216 order_input_indep_ptr_branch_map
[current_ptr_branch
->trace
.size()]
217 = current_ptr_branch
;
224 /*================================================================================================*/
// Chooses a rollback checkpoint for every tainted branch of order_tainted_ptr_branch_map.
//
// Branches without dep_input_addrs get their checkpoint reset. For the others, each address
// in dep_input_addrs is looked up (std::find) in the dep_mems of the checkpoints of
// saved_ptr_checkpoints, scanned front to back (creation order). A match adds the address
// to nearest_checkpoints[checkpoint].
// When a match was found, saved_ptr_checkpoints is then scanned back to front for a
// checkpoint whose trace is shorter than the branch's trace and whose dep_mems also contain
// the address. econ_execution_length[matched checkpoint] is set to the difference between
// the two checkpoints' trace sizes.
// Finally the branch's checkpoint becomes nearest_checkpoints.rbegin()->first. If
// nearest_checkpoints is empty, a fatal message is logged and the application exits.
//
// NOTE(review): nearest_checkpoint_found has no initializer and is only assigned inside the
// checkpoint loop. If it is tested after that loop (the braces are not visible in this
// chunk), it is read uninitialized when saved_ptr_checkpoints is empty.
// NOTE(review): if nearest_checkpoints is a std::map keyed by ptr_checkpoint with the default
// comparator, rbegin() picks the largest pointer value, not necessarily the latest checkpoint
// in the trace — verify the container/comparator.
// NOTE(review): if dep_mems is a std::set, dep_mems.find()/count() would avoid the linear
// std::find scans.
// NOTE(review): nearest_ptr_checkpoint and intersec_mems are not used in the visible code.
226 inline void compute_branch_min_checkpoint()
228 std::vector
<ptr_checkpoint
>::iterator ptr_checkpoint_iter
;
229 std::vector
<ptr_checkpoint
>::reverse_iterator ptr_checkpoint_reverse_iter
;
230 std::set
<ADDRINT
>::iterator addr_iter
;
231 std::map
<UINT32
, ptr_branch
>::iterator order_ptr_branch_iter
;
233 ptr_branch current_ptr_branch
;
234 ptr_checkpoint nearest_ptr_checkpoint
;
236 bool nearest_checkpoint_found
;
237 std::set
<ADDRINT
> intersec_mems
;
239 order_ptr_branch_iter
= order_tainted_ptr_branch_map
.begin();
240 for (; order_ptr_branch_iter
!= order_tainted_ptr_branch_map
.end(); ++order_ptr_branch_iter
)
242 current_ptr_branch
= order_ptr_branch_iter
->second
;
// Branches that do not depend on input bytes need no checkpoint.
243 if (current_ptr_branch
->dep_input_addrs
.empty())
245 current_ptr_branch
->checkpoint
.reset();
247 else // compute the nearest checkpoint for current_ptr_branch
249 // for each *addr_iter in current_ptr_branch->dep_input_addrs,
250 // find the earliest checkpoint that uses it
251 addr_iter
= current_ptr_branch
->dep_input_addrs
.begin();
252 for (; addr_iter
!= current_ptr_branch
->dep_input_addrs
.end(); ++addr_iter
)
// Forward scan over the saved checkpoints, in the order they were appended.
254 ptr_checkpoint_iter
= saved_ptr_checkpoints
.begin();
255 for (; ptr_checkpoint_iter
!= saved_ptr_checkpoints
.end(); ++ptr_checkpoint_iter
)
257 nearest_checkpoint_found
= false;
258 // *addr_iter is found in (*ptr_checkpoint_iter)->dep_mems
259 if (std::find((*ptr_checkpoint_iter
)->dep_mems
.begin(),
260 (*ptr_checkpoint_iter
)->dep_mems
.end(), *addr_iter
)
261 != (*ptr_checkpoint_iter
)->dep_mems
.end())
263 nearest_checkpoint_found
= true;
264 current_ptr_branch
->nearest_checkpoints
[*ptr_checkpoint_iter
].insert(*addr_iter
);
269 // find the ideal checkpoint by finding reversely the checkpoint list
270 if (nearest_checkpoint_found
)
272 ptr_checkpoint_reverse_iter
= saved_ptr_checkpoints
.rbegin();
273 for (; ptr_checkpoint_reverse_iter
!= saved_ptr_checkpoints
.rend();
274 ++ptr_checkpoint_reverse_iter
)
// Only checkpoints taken before the branch (shorter trace) are candidates.
276 if ((*ptr_checkpoint_reverse_iter
)->trace
.size() < current_ptr_branch
->trace
.size())
278 if (std::find((*ptr_checkpoint_reverse_iter
)->dep_mems
.begin(),
279 (*ptr_checkpoint_reverse_iter
)->dep_mems
.end(), *addr_iter
)
280 != (*ptr_checkpoint_reverse_iter
)->dep_mems
.end())
// Trace-size distance between the reverse-scan checkpoint and the matched checkpoint.
282 current_ptr_branch
->econ_execution_length
[*ptr_checkpoint_iter
] =
283 (*ptr_checkpoint_reverse_iter
)->trace
.size() - (*ptr_checkpoint_iter
)->trace
.size();
284 // std::cout << current_ptr_branch->econ_execution_length[*ptr_checkpoint_iter] << std::endl;
// Pick the branch's rollback checkpoint, or abort when none was found.
293 if (current_ptr_branch
->nearest_checkpoints
.size() != 0)
295 BOOST_LOG_TRIVIAL(info
)
296 << boost::format("The branch at %d:%d (%s: %s) has %d nearest checkpoints.")
297 % current_ptr_branch
->trace
.size()
298 % current_ptr_branch
->br_taken
299 % remove_leading_zeros(StringFromAddrint(current_ptr_branch
->addr
))
300 % order_ins_dynamic_map
[current_ptr_branch
->trace
.size()].disass
301 % current_ptr_branch
->nearest_checkpoints
.size();
303 current_ptr_branch
->checkpoint
= current_ptr_branch
->nearest_checkpoints
.rbegin()->first
;
307 BOOST_LOG_TRIVIAL(fatal
)
308 << boost::format("Cannot found any nearest checkpoint for the branch at %d.!")
309 % current_ptr_branch
->trace
.size();
311 PIN_ExitApplication(0);
319 /*================================================================================================*/
// Ends the exploration phase and starts rollbacking.
//
// Logs the number of explored instructions, then runs compute_branch_mem_dependency() and
// compute_branch_min_checkpoint(). It logs the checkpoint and branch counts and removes Pin
// instrumentation with PIN_RemoveInstrumentation(). After that:
//   - if exploring_ptr_branch is set, rollback_with_input_replacement is called with
//     saved_ptr_checkpoints[0] and an `inputs[!exploring_ptr_branch->br_taken][0]` entry;
//   - on the other path, the tainting graph and the explored trace are journaled. If there is
//     at least one input-dependent branch, rollback_with_input_replacement is called with
//     saved_ptr_checkpoints[0] and an `inputs[first_ptr_branch->br_taken][0]` entry.
//     Otherwise an info message is logged and the application exits.
// NOTE(review): the object whose `->inputs` is read and the `else` separating the two paths
// are on lines not visible in this chunk — confirm against the full source.
// NOTE(review): saved_ptr_checkpoints[0] is indexed without an emptiness check.
// NOTE(review): the first-rollback path indexes inputs with br_taken, while the exploring path
// uses !br_taken — confirm this asymmetry is intended.
321 inline void prepare_new_rollbacking_phase()
323 BOOST_LOG_TRIVIAL(info
)
324 << boost::format("\033[33mStop exploring, %d instructions analyzed. Start detecting checkpoints\033[0m")
325 % explored_trace
.size();
327 // journal_tainting_graph("tainting_graph.dot");
328 // PIN_ExitApplication(0);
// Classify branches by their memory dependencies, then choose a checkpoint for each.
330 compute_branch_mem_dependency();
331 compute_branch_min_checkpoint();
333 BOOST_LOG_TRIVIAL(info
)
334 << boost::format("\033[33mStop detecting, %d checkpoints and %d/%d branches detected. Start rollbacking.\033[0m")
335 % saved_ptr_checkpoints
.size()
336 % order_input_dep_ptr_branch_map
.size()
337 % order_tainted_ptr_branch_map
.size();
339 // journal_tainting_log();
342 PIN_RemoveInstrumentation();
344 if (exploring_ptr_branch
)
346 rollback_with_input_replacement(saved_ptr_checkpoints
[0],
348 ->inputs
[!exploring_ptr_branch
->br_taken
][0].get());
352 journal_tainting_graph("tainting_graph.dot");
353 journal_explored_trace("explored_trace.log");
354 // journal_static_trace("static_trace");
356 // the first rollbacking phase
357 if (!order_input_dep_ptr_branch_map
.empty())
359 ptr_branch first_ptr_branch
= order_input_dep_ptr_branch_map
.begin()->second
;
360 rollback_with_input_replacement(saved_ptr_checkpoints
[0],
362 ->inputs
[first_ptr_branch
->br_taken
][0].get());
366 BOOST_LOG_TRIVIAL(info
) << "There is no branch needed to resolve.";
367 PIN_ExitApplication(0);
374 /*================================================================================================*/
376 VOID
logging_syscall_instruction_analyzer(ADDRINT ins_addr
)
378 prepare_new_rollbacking_phase();
382 /*================================================================================================*/
384 VOID
logging_general_instruction_analyzer(ADDRINT ins_addr
)
386 if ((explored_trace
.size() < max_trace_size
) &&
387 (!addr_ins_static_map
[ins_addr
].contained_image
.empty()))
389 explored_trace
.push_back(ins_addr
);
390 order_ins_dynamic_map
[explored_trace
.size()] = addr_ins_static_map
[ins_addr
];
391 std::cout
<< addr_ins_static_map
[ins_addr
].disass
<< "\n";
393 else // trace length limit reached
395 prepare_new_rollbacking_phase();
401 /*================================================================================================*/
// Analysis callback for memory reads.
//
// If the read range [mem_read_addr, mem_read_addr + mem_read_size) overlaps the received
// message range [received_msg_addr, received_msg_addr + received_msg_size), a checkpoint is
// built from (ins_addr, p_ctxt, explored_trace, mem_read_addr, mem_read_size). It is
// appended to saved_ptr_checkpoints and logged at trace level.
// Each read byte address is added to order_ins_dynamic_map[explored_trace.size()].src_mems.
// NOTE(review): the declaration of the last parameter, p_ctxt, is not visible in this chunk.
// NOTE(review): the log line uses addr_ins_static_map[ins_addr], which inserts a default
// record if ins_addr is unknown.
403 VOID
logging_mem_read_instruction_analyzer(ADDRINT ins_addr
,
404 ADDRINT mem_read_addr
, UINT32 mem_read_size
,
407 // a new checkpoint found
// Two half-open ranges overlap iff max(starts) < min(ends).
408 if (std::max(mem_read_addr
, received_msg_addr
) <
409 std::min(mem_read_addr
+ mem_read_size
, received_msg_addr
+ received_msg_size
))
411 ptr_checkpoint
new_ptr_checkpoint(new checkpoint(ins_addr
, p_ctxt
, explored_trace
,
412 mem_read_addr
, mem_read_size
));
413 saved_ptr_checkpoints
.push_back(new_ptr_checkpoint
);
415 BOOST_LOG_TRIVIAL(trace
)
416 << boost::format("Checkpoint detected at %d (%s).")
417 % new_ptr_checkpoint
->trace
.size() % addr_ins_static_map
[ins_addr
].disass
;
// Record every byte read by the instruction at the current trace position.
420 for (UINT32 idx
= 0; idx
< mem_read_size
; ++idx
)
422 order_ins_dynamic_map
[explored_trace
.size()].src_mems
.insert(mem_read_addr
+ idx
);
428 /*================================================================================================*/
// Analysis callback for memory writes.
//
// If any checkpoint has been saved, the write is forwarded to
// saved_ptr_checkpoints[0]->mem_written_logging(ins_addr, mem_written_addr, mem_written_size).
// A copy of the saved_ptr_checkpoints vector is stored in exepoint_checkpoints_map under the
// current trace length. Each written byte address is added to
// order_ins_dynamic_map[explored_trace.size()].dst_mems.
// NOTE(review): whether the exepoint_checkpoints_map assignment sits inside the `if` cannot be
// confirmed from this chunk (the closing brace is not visible).
// NOTE(review): copying the whole vector on every write costs O(#checkpoints) per write.
430 VOID
logging_mem_write_instruction_analyzer(ADDRINT ins_addr
,
431 ADDRINT mem_written_addr
, UINT32 mem_written_size
)
433 if (!saved_ptr_checkpoints
.empty())
435 saved_ptr_checkpoints
[0]->mem_written_logging(ins_addr
,
436 mem_written_addr
, mem_written_size
);
// Snapshot of the checkpoint list at this execution point.
439 exepoint_checkpoints_map
[explored_trace
.size()] = saved_ptr_checkpoints
;
// Record every byte written by the instruction at the current trace position.
441 for (UINT32 idx
= 0; idx
< mem_written_size
; ++idx
)
443 order_ins_dynamic_map
[explored_trace
.size()].dst_mems
.insert(mem_written_addr
+ idx
);
449 /*================================================================================================*/
// Analysis callback for conditional branches.
//
// Creates a branch record for (ins_addr, br_taken) and stores its first input via
// store_input(). When exploring_ptr_branch is set and the new branch's trace is not longer
// than the exploring branch's trace, the new branch is marked resolved (mark_resolved).
// The branch is recorded in order_tainted_ptr_branch_map under the current
// explored_trace length.
451 VOID
logging_cond_br_analyzer(ADDRINT ins_addr
, bool br_taken
)
453 ptr_branch
new_ptr_branch(new branch(ins_addr
, br_taken
));
455 // save the first input
456 store_input(new_ptr_branch
, br_taken
);
458 // verify if the branch is a new tainted branch
459 if (exploring_ptr_branch
&&
460 (new_ptr_branch
->trace
.size() <= exploring_ptr_branch
->trace
.size()))
462 // mark it as resolved
463 mark_resolved(new_ptr_branch
);
466 order_tainted_ptr_branch_map
[explored_trace
.size()] = new_ptr_branch
;
471 /*================================================================================================*/
473 VOID
logging_before_recv_functions_analyzer(ADDRINT msg_addr
)
475 std::cout
<< "msg_addr logged\n";
476 received_msg_addr
= msg_addr
;
// Hook invoked after a recv-family call (per its name). It prints a notice, increments
// received_msg_num and stores msg_length in received_msg_size. The size is overwritten,
// not accumulated, so only the latest message's length is kept.
// NOTE(review): received_msg_num is declared UINT8 and wraps after 255 increments.
480 VOID
logging_after_recv_functions_analyzer(UINT32 msg_length
)
484 std::cout
<< "msg_size logged in recv\n";
485 received_msg_num
++; received_msg_size
= msg_length
;
490 /*================================================================================================*/
493 #include <WinSock2.h>
496 VOID
logging_before_wsarecv_functions_analyzer(ADDRINT msg_struct_adddr
)
498 std::cout
<< "before\n" << std::fflush
;
499 received_msg_struct_addr
= msg_struct_adddr
;
500 received_msg_addr
= reinterpret_cast<ADDRINT
>((reinterpret_cast<WINDOWS::LPWSABUF
>(received_msg_struct_addr
))->buf
);
501 function_has_been_called
= true;
505 VOID
logging_after_wsarecv_funtions_analyzer()
507 if (function_has_been_called
)
509 std::cout
<< "after\n";
510 received_msg_size
= (reinterpret_cast<WINDOWS::LPWSABUF
>(received_msg_struct_addr
))->len
;
511 std::cerr
<< received_msg_size
<< "\n";
512 if (received_msg_size
> 0)
516 function_has_been_called
= false;