2 from graph
import Graph
10 def __init__(self
, n
):
16 UNK
= Singleton("UNK")
17 DYN
= Singleton("DYN")
21 def __init__(self
, addr
):
29 def __getitem__(self
, i
):
32 def def_addrs(self
, regs_only
=True):
33 """Return all variable definitions for this basic block,
34 as set of (var, inst_addr) pairs. Note that this includes
35 multiple entries for the same var, if it is redefined
36 multiple times within the basic block.
40 inst_defs
= i
.defs(regs_only
)
45 def defs(self
, regs_only
=True):
46 """Return set of all variables defined in this basic block."""
49 defs |
= i
.defs(regs_only
)
53 """Return set of all variables used in this basic block."""
60 return "%s(%s)" % (self
.__class
__.__name
__, self
.addr
)
62 def write(self
, stream
, indent
, s
):
63 for l
in str(s
).splitlines():
64 stream
.write(" " * indent
)
65 stream
.write(l
+ "\n")
67 def dump(self
, stream
, indent
=0, printer
=str):
71 self
.write(stream
, indent
, out
)
73 TYPE_SORT
= ("REG", "ADDR", "MEM", "EXPR", "COND", "VALUE")
76 return TYPE_SORT
.index(t
.__name
__)
78 # Helper predicates for types below
81 return isinstance(e
, VALUE
)
84 return isinstance(e
, ADDR
)
87 return isinstance(e
, REG
)
90 return isinstance(e
, MEM
)
93 return isinstance(e
, EXPR
)
96 return isinstance(e
, COND
)
98 def is_sfunc(e
, name
):
99 return is_expr(e
) and e
.args
[0] == SFUNC(name
)
103 # Something which is a simple expression
109 "Get registers referenced by the expression"
112 def side_effect(self
):
119 class REG(SimpleExpr
):
121 def __init__(self
, name
):
122 #assert isinstance(name, str)
128 return self
.__str
__()
129 type = "REG_S" if self
.signed
else "REG"
130 return self
.comment
+ type + "(%s)" % self
.name
133 cast
= "(i32)" if self
.signed
else ""
134 return self
.comment
+ cast
+ "$" + str(self
.name
)
136 def __eq__(self
, other
):
137 return type(self
) == type(other
) and self
.name
== other
.name
139 def __lt__(self
, other
):
140 if type(self
) != type(other
):
141 return type_sort(type(self
)) < type_sort(type(other
))
143 n1
= utils
.natural_sort_key(self
.name
)
144 n2
= utils
.natural_sort_key(other
.name
)
147 def __contains__(self
, other
):
151 return hash(self
.name
)
157 class VALUE(SimpleExpr
):
159 def __init__(self
, val
, base
=16):
165 return self
.__str
__()
166 return self
.comment
+ "VALUE(%#x)" % self
.val
169 if isinstance(self
.val
, int) and self
.base
== 16:
170 val
= "%#x" % self
.val
173 return self
.comment
+ val
175 def __eq__(self
, other
):
176 return type(self
) == type(other
) and self
.val
== other
.val
178 def __lt__(self
, other
):
179 if type(self
) != type(other
):
180 return type_sort(type(self
)) < type_sort(type(other
))
181 return self
.val
< other
.val
183 def __contains__(self
, other
):
187 return hash(self
.val
)
190 class STR(SimpleExpr
):
192 def __init__(self
, s
):
197 return self
.__str
__()
198 return self
.comment
+ "STR(%s)" % self
.val
201 return self
.comment
+ self
.val
204 class ADDR(SimpleExpr
):
206 resolver
= staticmethod(lambda x
: x
)
208 def __init__(self
, addr
):
212 return self
.comment
+ "ADDR(%s)" % self
.addr
215 return self
.comment
+ self
.resolver(self
.addr
)
217 def __eq__(self
, other
):
218 return type(self
) == type(other
) and self
.addr
== other
.addr
220 def __lt__(self
, other
):
221 if type(self
) != type(other
):
222 return type_sort(type(self
)) < type_sort(type(other
))
223 return self
.addr
< other
.addr
225 def __contains__(self
, other
):
229 return hash(self
.addr
)
232 class CVAR(SimpleExpr
):
234 def __init__(self
, name
):
238 return self
.comment
+ "CVAR(%s)" % self
.name
241 return self
.comment
+ self
.name
243 def __eq__(self
, other
):
244 return type(self
) == type(other
) and self
.name
== other
.name
247 return hash(self
.name
)
250 class MEM(SimpleExpr
):
251 def __init__(self
, type, expr
):
256 return self
.comment
+ "*(%s*)%r" % (self
.type, self
.expr
)
259 if isinstance(self
.expr
, EXPR
):
260 return self
.comment
+ "*(%s*)(%s)" % (self
.type, self
.expr
)
262 return self
.comment
+ "*(%s*)%s" % (self
.type, self
.expr
)
264 def __eq__(self
, other
):
265 return type(self
) == type(other
) and self
.type == other
.type and \
266 self
.expr
== other
.expr
268 def __lt__(self
, other
):
269 if type(self
) == type(other
):
270 return self
.expr
< other
.expr
271 return type_sort(type(self
)) < type_sort(type(other
))
273 def __contains__(self
, other
):
276 return other
in self
.expr
279 return hash(self
.type) ^
hash(self
.expr
)
282 return self
.expr
.regs()
285 class SFIELD(SimpleExpr
):
287 def __init__(self
, type, addr
, field
):
293 return self
.comment
+ "SFIELD(%s, %s, %s)" % (self
.type, self
.addr
, self
.field
)
296 return self
.comment
+ "((%s*)%s)->%s" % (self
.type, self
.addr
, self
.field
)
299 class SFUNC(SimpleExpr
):
301 def __init__(self
, name
):
305 return "(SFUNC)%s" % (self
.name
)
308 return "%s" % self
.name
310 def __eq__(self
, other
):
311 return type(self
) == type(other
) and self
.name
== other
.name
313 def __contains__(self
, other
):
317 return hash(self
.name
)
320 class TYPE(SimpleExpr
):
322 def __init__(self
, name
):
326 return "(TYPE)%s" % (self
.name
)
329 return "%s" % self
.name
331 def __eq__(self
, other
):
332 return type(self
) == type(other
) and self
.name
== other
.name
334 def __contains__(self
, other
):
338 return hash(self
.name
)
341 assert self
.name
[0] in ("i", "u")
342 return int(self
.name
[1:])
346 "A recursive expression."
347 def __init__(self
, op
, *args
):
349 if isinstance(args
[0], list):
356 return "EXPR(%s%s)" % (self
.op
, self
.args
)
361 # See e.g. http://en.cppreference.com/w/c/language/operator_precedence
363 "+=": 14, "-=": 14, "*=": 14, "/=": 14, "%=": 14,
364 "<<=": 14, ">>=": 14, "&=": 14, "|=": 14, "^=": 14,
366 "|": 10, "^": 9, "&": 8,
368 ">": 6, "<": 6, ">=": 6, "<=": 6,
371 "*": 3, "/": 3, "%": 3,
372 # All the below is highest precedence
373 "CAST": 1, "SFUNC": 1, "NEG": 1, "!": 1,
377 # Render this expr's arg, wrapped in parens if needed
379 def strarg(expr
, arg
):
380 if isinstance(arg
, (set, frozenset)):
381 s
= utils
.repr_stable(arg
)
384 preced_my
= EXPR
.preced(expr
)
385 preced_arg
= EXPR
.preced(arg
)
386 full_assoc
= expr
.op
in {"+", "*", "&", "^", "|"}
387 if preced_arg
== preced_my
and full_assoc
:
388 # Render repeated fully associative operators without extra parens
390 elif preced_arg
> preced_my
or (preced_arg
== preced_my
and preced_arg
!= 1):
391 # Otherwise, if precedence rules require parens, render them, unless
392 # the arg is a unary/primary term
395 # Parens would not be required per the precedence rules, but
396 # handle common cases of confusing precedence in C, where parens
397 # are usually suggested.
398 if expr
.op
in ("&", "^", "|") and preced_arg
!= 1:
399 # Any binary op subexpression of bitwise ops in parens
401 elif expr
.op
in ("<<", ">>") and preced_arg
!= 1:
402 # Any binary op subexpression of shift in parens
407 if not SimpleExpr
.simple_repr
:
408 return self
.__repr
__()
410 if self
.op
== "SFUNC":
411 return str(self
.args
[0]) + "(" + ", ".join([str(a
) for a
in self
.args
[1:]]) + ")"
412 if self
.op
== "CAST":
413 return "(" + str(self
.args
[0]) + ")" + self
.strarg(self
, self
.args
[1])
420 assert len(self
.args
) == 1
422 return s
+ self
.strarg(self
, self
.args
[0])
424 l
= [self
.strarg(self
, self
.args
[0])]
425 for a
in self
.args
[1:]:
426 if self
.op
== "+" and is_value(a
) and a
.val
< 0:
428 a
= VALUE(-a
.val
, a
.base
)
431 l
.append(self
.strarg(self
, a
))
434 def __eq__(self
, other
):
435 return type(self
) == type(other
) and self
.op
== other
.op
and self
.args
== other
.args
437 def __lt__(self
, other
):
438 if type(self
) == type(other
):
439 return str(self
) < str(other
)
440 return type_sort(type(self
)) < type_sort(type(other
))
442 def __contains__(self
, other
):
451 return hash(self
.op
) ^
hash(tuple(self
.args
))
454 # One for operation itself
468 def defs(self
, regs_only
=True):
469 assert not self
.side_effect()
472 def side_effect(self
):
473 if self
.op
== "SFUNC":
474 return self
.args
[0].name
not in {
475 "BIT", "abs", "bitfield", "count_leading_zeroes",
480 def foreach_subexpr(self
, func
):
481 # If func returned True, it means it handled entire subexpression,
482 # so we don't recurse into it.
483 # Note that this function recurses only within EXPR tree, it doesn't
484 # recurse e.g. inside MEM.
489 a
.foreach_subexpr(func
)
499 annotate_calls
= False
502 def __init__(self
, dest
, op
, args
, addr
=None):
510 "If instruction may transfer control, return jump address(es), otherwise return None."
511 if self
.op
in ("call", "goto"):
512 if isinstance(self
.args
[0], ADDR
):
513 return self
.args
[0].addr
516 for i
in range(0, len(self
.args
), 2):
517 if isinstance(self
.args
[i
+ 1], ADDR
):
518 res
.append(self
.args
[i
+ 1].addr
)
524 def side_effect(self
):
525 if self
.op
== "call":
527 if self
.op
in ("=", "SFUNC"):
528 assert len(self
.args
) == 1, self
.args
529 return self
.args
[0].side_effect()
533 def uses(self
, cfg
=None):
534 # Avoid circular import. TODO: fix properly
538 """Return set of all registers used by this instruction. Function
539 calls (and maybe SFUNCs) require special treatment."""
540 if self
.op
== "call":
543 if isinstance(addr
, ADDR
):
544 # Direct call with known address
546 if addr
in progdb
.FUNC_DB
and "params" in progdb
.FUNC_DB
[addr
]:
547 return uses | progdb
.FUNC_DB
[addr
]["params"]
549 # Indirect call or not params in funcdb
550 # TODO: need to allow saving callsite info in funcdb
551 return uses | arch
.call_params(addr
)
553 if self
.op
== "return":
555 return arch
.ret_uses(cfg
)
561 if is_mem(self
.dest
):
562 for r
in self
.dest
.regs():
567 def defs(self
, regs_only
=True, cfg
=None):
568 # Avoid circular import. TODO: fix properly
571 """Return set of all registers defined by this instruction. Function
572 calls (and maybe SFUNCs) require special treatment."""
573 if self
.op
== "call":
575 if isinstance(addr
, ADDR
):
576 # Direct call with known address
578 if addr
in progdb
.FUNC_DB
and "modifieds" in progdb
.FUNC_DB
[addr
]:
579 return progdb
.FUNC_DB
[addr
]["modifieds"]
581 # Indirect call or not params in funcdb
582 # TODO: need to allow saving callsite info in funcdb
583 return arch
.call_defs(addr
)
587 if not regs_only
or isinstance(self
.dest
, REG
):
592 def foreach_subexpr(self
, func
):
595 if is_expr(arg
) or is_cond(arg
):
596 arg
.foreach_subexpr(func
)
606 comments
= self
.comments
.copy()
608 if "org_inst" in comments
:
609 s
= "// " + str(comments
.pop("org_inst")) + "\n"
610 if self
.addr
is not None:
611 s
+= "/*%s*/ " % self
.addr
612 if self
.dest
is None:
616 s
+= "%s(%s)" % (self
.op
, self
.args
)
619 # Simplify repr for assignment
620 s
+= "%s = %s" % (self
.dest
, self
.args
)
622 s
+= "%s = %s(%s)" % (self
.dest
, self
.op
, self
.args
)
624 s
+= " # " + repr(comments
)
628 if not SimpleExpr
.simple_repr
:
629 return self
.__repr
__()
631 comments
= self
.comments
.copy()
635 addr
= "/*%s*/ " % self
.addr
638 return addr
+ self
.args
[0]
641 if self
.show_comments
and "org_inst" in comments
:
642 s
= self
.comment
+ " " + str(comments
.pop("org_inst")) + " "
646 if self
.op
== "call" and self
.annotate_calls
:
647 comments
["uses"] = sorted(self
.uses())
648 comments
["defs"] = sorted(self
.defs())
651 if self
.show_comments
and comments
:
652 tail
+= " " + self
.comment
+ " " + utils
.repr_stable_dict(comments
)
654 if self
.op
== "return":
655 args
= ", ".join([str(a
) for a
in self
.args
])
658 return s
+ self
.op
+ args
+ tail
659 if self
.op
in ("goto", "call"):
660 return s
+ "%s %s" % (self
.op
, self
.args
[0]) + tail
662 joined
= ", ".join(["%s goto %s" % (self
.args
[i
] or "else", self
.args
[i
+ 1]) for i
in range(0, len(self
.args
), 2)])
663 return s
+ "if " + joined
+ tail
665 if self
.op
== "DEAD":
666 return s
+ "(dead)" + tail
668 if self
.op
== "SFUNC":
669 assert self
.dest
is None
670 assert len(self
.args
) == 1, repr(self
.args
)
671 return s
+ str(self
.args
[0]) + tail
673 assert self
.op
== "=", repr(self
.op
)
674 assert len(self
.args
) == 1, (self
.op
, repr(self
.args
))
676 if self
.op
== "=" and not is_expr(self
.args
[0]):
677 s
+= "%s = %s" % (self
.dest
, self
.args
[0])
679 e
= copy
.copy(self
.args
[0])
682 if not (op
== "!" or op
[0].isalpha()):
684 assert len(args
) >= 2, repr(args
)
685 if self
.dest
== args
[0]:
686 s
+= "%s %s= " % (self
.dest
, op
)
687 # Render operator as a compound statement operator
688 # (lowest priority, no extra parens).
692 s
+= "%s = " % self
.dest
694 if self
.dest
is not None:
695 s
+= "%s = " % self
.dest
702 def __eq__(self
, other
):
703 return self
.op
== other
.op
and self
.dest
== other
.dest
and self
.args
== other
.args
707 """This class is a container of EXPR used as a condition in the
708 'if (cond) goto' statement. It's needed because the same condition
709 is used both in the Inst representing such a statement and a lebel
710 of a CFG edge connecting basic blocks. If condition is updated,
711 e.g. while transforming its Inst, the change should be mirrored
712 to the CFG edge. Using COND class, this can be easily achieved:
713 the same COND instance is referenced both in Inst and edge, and
714 we can freely update or even completely replace EXPR it contains,
715 while both users will stay up to date.
737 def __init__(self
, expr
):
743 return self
.__class
__(EXPR(self
.NEG
[op
], self
.expr
.args
))
745 return self
.__class
__(self
.expr
.args
[0])
747 return self
.__class
__(EXPR("!", self
.expr
))
750 "Swap arguments in-place."
751 self
.expr
.args
[0], self
.expr
.args
[1] = self
.expr
.args
[1], self
.expr
.args
[0]
752 self
.expr
.op
= self
.SWAP
[self
.expr
.op
]
755 if is_value(self
.expr
.args
[0]) and not is_value(self
.expr
.args
[1]):
758 def is_relation(self
):
759 return is_expr(self
.expr
) and self
.expr
.op
in self
.NEG
765 return "(%s)" % self
.expr
768 # if self.op in ("in", "not in"):
769 # return "COND(%r %s %s)" % (self.arg1, self.op, utils.repr_stable(self.arg2))
770 return "COND(%r)" % self
.expr
772 def __eq__(self
, other
):
773 return type(self
) == type(other
) and self
.expr
== other
.expr
775 def __contains__(self
, other
):
776 return other
in self
.expr
779 return hash(self
.expr
) ^
hash(self
.__class
__)
782 return self
.expr
.regs()
784 def defs(self
, regs_only
=True):
785 return self
.expr
.defs(regs_only
)
788 return self
.expr
.uses()
790 def foreach_subexpr(self
, func
):
791 self
.expr
.foreach_subexpr(func
)
801 def __init__(self
, l
):
804 def append(self
, op
, arg2
):
805 self
.args
.extend([op
, arg2
])
808 return self
.__class
__([self
.NEG
[x
] if isinstance(x
, str) else x
.neg() for x
in self
.args
])
814 r
= " ".join([str(x
) for x
in self
.args
])
818 return "CCond%s" % str(self
)
820 def repr_state(state
):
823 for k
, v
in sorted(state
.items()):
827 res
.append("%s=%s" % (k
, v
))
830 res
+= " UNK: " + ",".join(unk
)
831 return "{" + res
+ "}"
835 """Print BBlocks in a CFG. Various printing params can be overriden
838 header_reg_prefix
= "$"
839 addr_in_header
= False
841 def __init__(self
, cfg
, stream
=sys
.stdout
):
844 # Current bblock addr
846 # Current CFG node properties
847 self
.node_props
= None
850 # Current BBlock properties
851 self
.bblock_props
= None
852 self
.inst_printer
= str
855 def bblock_order(self
):
856 "Return iterator over bblocks to be printed."
857 return self
.cfg
.iter_sorted_nodes()
859 def print_graph_header(self
):
861 print("// Graph props:", file=self
.stream
)
862 for k
in sorted(self
.cfg
.props
.keys()):
863 v
= self
.cfg
.props
[k
]
864 v
= utils
.repr_stable(v
)
865 print("// %s: %s" % (k
, v
), file=self
.stream
)
866 print(file=self
.stream
)
869 def print_header(self
):
870 if self
.addr_in_header
:
871 print("// %s" % self
.addr
, file=self
.stream
)
872 print("// Predecessors: %s" % sorted(self
.cfg
.pred(self
.addr
)), file=self
.stream
)
875 print("// Node props:", file=self
.stream
)
876 for k
in sorted(self
.node_props
.keys()):
877 v
= self
.node_props
[k
]
878 v
= utils
.repr_stable(v
)
879 v
= v
.replace("$", self
.header_reg_prefix
)
880 print("// %s: %s" % (k
, v
), file=self
.stream
)
882 if self
.bblock_props
:
883 print("// BBlock props:", file=self
.stream
)
885 for k
in sorted(self
.bblock_props
.keys()):
886 v
= self
.bblock_props
[k
]
887 if k
.startswith("state_"):
890 v
= utils
.repr_stable(v
)
891 v
= v
.replace("$", self
.header_reg_prefix
)
892 print("// %s: %s" % (k
, v
), file=self
.stream
)
895 def print_trailer(self
):
896 succ
= self
.cfg
.succ(self
.addr
)
897 exits
= [(self
.cfg
.edge(self
.addr
, x
).get("cond"), x
) for x
in succ
]
898 print("Exits:", sorted(exits
, key
=lambda x
: utils
.natural_sort_key(str(x
))), file=self
.stream
)
901 def print_label(self
):
902 print("%s:" % self
.addr
, file=self
.stream
)
904 def print_inst(self
, inst
):
905 if inst
.op
== "DEAD" and self
.no_dead
:
907 return self
.inst_printer(inst
)
909 def print_separator(self
):
910 self
.stream
.write("\n")
913 self
.print_graph_header()
915 for self
.addr
, info
in self
.bblock_order():
916 self
.node_props
= info
.copy()
917 self
.bblock
= self
.node_props
.pop("val")
918 self
.bblock_props
= self
.bblock
.props
920 self
.print_separator()
923 if self
.bblock
is not None:
924 self
.bblock
.dump(self
.stream
, 0, self
.print_inst
)
926 print(" ", self
.bblock
, file=self
.stream
)
931 def dump_bblocks(cfg
, stream
=sys
.stdout
, printer
=str, no_graph_header
=False):
932 p
= CFGPrinter(cfg
, stream
)
933 p
.inst_printer
= printer
935 p
.print_graph_header
= lambda: None