Merged release21-maint changes.
[python/dscho.git] / Lib / compiler / pyassem.py
blob06846233be4d0254deb67e31f223a191281c294e
1 """A flow graph representation for Python bytecode"""
3 import dis
4 import new
5 import string
6 import sys
7 import types
9 from compiler import misc
def xxx_sort(l):
    """Return a new list with the blocks of l ordered by ascending bid.

    The argument is not modified.  Blocks with equal bid keep their
    relative order.
    """
    # Decorate-sort-undecorate: pair every block with its bid and its
    # position.  The position makes each key unique, so two block objects
    # are never compared with each other, and no Python-level comparison
    # callback has to be invoked during the sort.
    decorated = []
    for pos in range(len(l)):
        blk = l[pos]
        decorated.append((blk.bid, pos, blk))
    decorated.sort()
    return [entry[2] for entry in decorated]
18 class FlowGraph:
19 def __init__(self):
20 self.current = self.entry = Block()
21 self.exit = Block("exit")
22 self.blocks = misc.Set()
23 self.blocks.add(self.entry)
24 self.blocks.add(self.exit)
26 def startBlock(self, block):
27 if self._debug:
28 if self.current:
29 print "end", repr(self.current)
30 print " next", self.current.next
31 print " ", self.current.get_children()
32 print repr(block)
33 self.current = block
35 def nextBlock(self, block=None):
36 # XXX think we need to specify when there is implicit transfer
37 # from one block to the next. might be better to represent this
38 # with explicit JUMP_ABSOLUTE instructions that are optimized
39 # out when they are unnecessary.
41 # I think this strategy works: each block has a child
42 # designated as "next" which is returned as the last of the
43 # children. because the nodes in a graph are emitted in
44 # reverse post order, the "next" block will always be emitted
45 # immediately after its parent.
46 # Worry: maintaining this invariant could be tricky
47 if block is None:
48 block = self.newBlock()
50 # Note: If the current block ends with an unconditional
51 # control transfer, then it is incorrect to add an implicit
52 # transfer to the block graph. The current code requires
53 # these edges to get the blocks emitted in the right order,
54 # however. :-( If a client needs to remove these edges, call
55 # pruneEdges().
57 self.current.addNext(block)
58 self.startBlock(block)
60 def newBlock(self):
61 b = Block()
62 self.blocks.add(b)
63 return b
65 def startExitBlock(self):
66 self.startBlock(self.exit)
68 _debug = 0
70 def _enable_debug(self):
71 self._debug = 1
73 def _disable_debug(self):
74 self._debug = 0
76 def emit(self, *inst):
77 if self._debug:
78 print "\t", inst
79 if inst[0] == 'RETURN_VALUE':
80 self.current.addOutEdge(self.exit)
81 if len(inst) == 2 and isinstance(inst[1], Block):
82 self.current.addOutEdge(inst[1])
83 self.current.emit(inst)
85 def getBlocksInOrder(self):
86 """Return the blocks in reverse postorder
88 i.e. each node appears before all of its successors
89 """
90 # XXX make sure every node that doesn't have an explicit next
91 # is set so that next points to exit
92 for b in self.blocks.elements():
93 if b is self.exit:
94 continue
95 if not b.next:
96 b.addNext(self.exit)
97 order = dfs_postorder(self.entry, {})
98 order.reverse()
99 self.fixupOrder(order, self.exit)
100 # hack alert
101 if not self.exit in order:
102 order.append(self.exit)
104 return order
106 def fixupOrder(self, blocks, default_next):
107 """Fixup bad order introduced by DFS."""
109 # XXX This is a total mess. There must be a better way to get
110 # the code blocks in the right order.
112 self.fixupOrderHonorNext(blocks, default_next)
113 self.fixupOrderForward(blocks, default_next)
115 def fixupOrderHonorNext(self, blocks, default_next):
116 """Fix one problem with DFS.
118 The DFS uses child block, but doesn't know about the special
119 "next" block. As a result, the DFS can order blocks so that a
120 block isn't next to the right block for implicit control
121 transfers.
123 index = {}
124 for i in range(len(blocks)):
125 index[blocks[i]] = i
127 for i in range(0, len(blocks) - 1):
128 b = blocks[i]
129 n = blocks[i + 1]
130 if not b.next or b.next[0] == default_next or b.next[0] == n:
131 continue
132 # The blocks are in the wrong order. Find the chain of
133 # blocks to insert where they belong.
134 cur = b
135 chain = []
136 elt = cur
137 while elt.next and elt.next[0] != default_next:
138 chain.append(elt.next[0])
139 elt = elt.next[0]
140 # Now remove the blocks in the chain from the current
141 # block list, so that they can be re-inserted.
142 l = []
143 for b in chain:
144 assert index[b] > i
145 l.append((index[b], b))
146 l.sort()
147 l.reverse()
148 for j, b in l:
149 del blocks[index[b]]
150 # Insert the chain in the proper location
151 blocks[i:i + 1] = [cur] + chain
152 # Finally, re-compute the block indexes
153 for i in range(len(blocks)):
154 index[blocks[i]] = i
156 def fixupOrderForward(self, blocks, default_next):
157 """Make sure all JUMP_FORWARDs jump forward"""
158 index = {}
159 chains = []
160 cur = []
161 for b in blocks:
162 index[b] = len(chains)
163 cur.append(b)
164 if b.next and b.next[0] == default_next:
165 chains.append(cur)
166 cur = []
167 chains.append(cur)
169 while 1:
170 constraints = []
172 for i in range(len(chains)):
173 l = chains[i]
174 for b in l:
175 for c in b.get_children():
176 if index[c] < i:
177 forward_p = 0
178 for inst in b.insts:
179 if inst[0] == 'JUMP_FORWARD':
180 if inst[1] == c:
181 forward_p = 1
182 if not forward_p:
183 continue
184 constraints.append((index[c], i))
186 if not constraints:
187 break
189 # XXX just do one for now
190 # do swaps to get things in the right order
191 goes_before, a_chain = constraints[0]
192 assert a_chain > goes_before
193 c = chains[a_chain]
194 chains.remove(c)
195 chains.insert(goes_before, c)
198 del blocks[:]
199 for c in chains:
200 for b in c:
201 blocks.append(b)
203 def getBlocks(self):
204 return self.blocks.elements()
206 def getRoot(self):
207 """Return nodes appropriate for use with dominator"""
208 return self.entry
210 def getContainedGraphs(self):
211 l = []
212 for b in self.getBlocks():
213 l.extend(b.getContainedGraphs())
214 return l
def dfs_postorder(b, seen):
    """Depth-first search of the graph rooted at b, return in postorder

    b must provide get_children().  seen is a dictionary of the blocks
    that were already visited; it is updated in place, and blocks found
    in it are not visited again.  The children of a block are requested
    exactly once, at the moment the block is first reached.

    The walk uses an explicit stack rather than recursion, so a long
    chain of blocks cannot exhaust the interpreter's recursion limit.
    """
    order = []
    seen[b] = b
    # Each frame is [block, its children, index of the next child to try].
    stack = [[b, b.get_children(), 0]]
    while stack:
        frame = stack[-1]
        node, children, pos = frame
        descended = 0
        while pos < len(children):
            child = children[pos]
            pos = pos + 1
            if child in seen:
                continue
            seen[child] = child
            # Remember where to resume once the child's subtree is done.
            frame[2] = pos
            stack.append([child, child.get_children(), 0])
            descended = 1
            break
        if not descended:
            # All children handled: the block itself comes last.
            order.append(node)
            stack.pop()
    return order
class Block:
    """A sequence of instructions plus its edges in the flow graph.

    insts holds the instruction tuples, inEdges/outEdges are sets of
    blocks, and next is a list holding at most one block: the block
    reached by implicit (fall-through) control transfer.
    """

    # Number of blocks created so far; used to hand out block ids.
    _count = 0

    # Opcodes listed here transfer control unconditionally, so the block
    # that follows is never reached by falling through.
    _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS',
                        'JUMP_ABSOLUTE', 'JUMP_FORWARD')

    def __init__(self, label=''):
        self.insts = []
        self.inEdges = misc.Set()
        self.outEdges = misc.Set()
        self.label = label
        self.bid = Block._count
        self.next = []
        Block._count = Block._count + 1

    def __repr__(self):
        if self.label:
            return "<block %s id=%d>" % (self.label, self.bid)
        else:
            return "<block id=%d>" % (self.bid)

    def __str__(self):
        insts = map(str, self.insts)
        return "<block %s %d:\n%s>" % (self.label, self.bid,
                                       '\n'.join(insts))

    def emit(self, inst):
        """Append instruction tuple inst.

        For an opcode whose name starts with JUMP the argument is also
        recorded as an out edge.
        """
        op = inst[0]
        if op[:4] == 'JUMP':
            self.outEdges.add(inst[1])
        self.insts.append(inst)

    def getInstructions(self):
        return self.insts

    def addInEdge(self, block):
        self.inEdges.add(block)

    def addOutEdge(self, block):
        self.outEdges.add(block)

    def addNext(self, block):
        """Set the fall-through successor; only one may ever be added."""
        self.next.append(block)
        assert len(self.next) == 1, map(str, self.next)

    def pruneNext(self):
        """Remove bogus edge for unconditional transfers

        Each block has a next edge that accounts for implicit control
        transfers, e.g. from a JUMP_IF_FALSE to the block that will be
        executed if the test is true.

        These edges must remain for the current assembler code to
        work. If they are removed, the dfs_postorder gets things in
        weird orders.  However, they shouldn't be there for other
        purposes, e.g. conversion to SSA form.  This method will
        remove the next edge when it follows an unconditional control
        transfer.
        """
        # Look at the opcode name only.  Instructions without an
        # argument, such as RETURN_VALUE, are 1-tuples, so unpacking the
        # last instruction into (op, arg) would skip them.
        try:
            op = self.insts[-1][0]
        except IndexError:
            # no instructions at all
            return
        if op in self._uncond_transfer:
            self.next = []

    def get_children(self):
        """Return the out-edge blocks followed by the next block.

        Side effect: the next block is removed from outEdges so that it
        is returned only once, as the last child.
        """
        if self.next and self.next[0] in self.outEdges:
            self.outEdges.remove(self.next[0])
        return self.outEdges.elements() + self.next

    def getContainedGraphs(self):
        """Return all graphs contained within this block.

        For example, a MAKE_FUNCTION block will contain a reference to
        the graph for the function body.
        """
        contained = []
        for inst in self.insts:
            if len(inst) == 1:
                continue
            op = inst[1]
            if hasattr(op, 'graph'):
                contained.append(op.graph)
        return contained
# flags for code objects (combined with "|" into PyFlowGraph.flags)
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002
CO_VARARGS = 0x0004
CO_VARKEYWORDS = 0x0008
CO_NESTED = 0x0010

# the FlowGraph is transformed in place; it exists in one of these states,
# which PyFlowGraph.getCode() walks through in the order given here
RAW, FLAT, CONV, DONE = "RAW", "FLAT", "CONV", "DONE"
326 class PyFlowGraph(FlowGraph):
327 super_init = FlowGraph.__init__
329 def __init__(self, name, filename, args=(), optimized=0):
330 self.super_init()
331 self.name = name
332 self.filename = filename
333 self.docstring = None
334 self.args = args # XXX
335 self.argcount = getArgCount(args)
336 if optimized:
337 self.flags = CO_OPTIMIZED | CO_NEWLOCALS
338 else:
339 self.flags = 0
340 self.consts = []
341 self.names = []
342 # Free variables found by the symbol table scan, including
343 # variables used only in nested scopes, are included here.
344 self.freevars = []
345 self.cellvars = []
346 # The closure list is used to track the order of cell
347 # variables and free variables in the resulting code object.
348 # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both
349 # kinds of variables.
350 self.closure = []
351 self.varnames = list(args) or []
352 for i in range(len(self.varnames)):
353 var = self.varnames[i]
354 if isinstance(var, TupleArg):
355 self.varnames[i] = var.getName()
356 self.stage = RAW
358 def setDocstring(self, doc):
359 self.docstring = doc
361 def setFlag(self, flag):
362 self.flags = self.flags | flag
363 if flag == CO_VARARGS:
364 self.argcount = self.argcount - 1
366 def setFreeVars(self, names):
367 self.freevars = list(names)
369 def setCellVars(self, names):
370 self.cellvars = names
372 def getCode(self):
373 """Get a Python code object"""
374 if self.stage == RAW:
375 self.flattenGraph()
376 if self.stage == FLAT:
377 self.convertArgs()
378 if self.stage == CONV:
379 self.makeByteCode()
380 if self.stage == DONE:
381 return self.newCodeObject()
382 raise RuntimeError, "inconsistent PyFlowGraph state"
384 def dump(self, io=None):
385 if io:
386 save = sys.stdout
387 sys.stdout = io
388 pc = 0
389 for t in self.insts:
390 opname = t[0]
391 if opname == "SET_LINENO":
392 print
393 if len(t) == 1:
394 print "\t", "%3d" % pc, opname
395 pc = pc + 1
396 else:
397 print "\t", "%3d" % pc, opname, t[1]
398 pc = pc + 3
399 if io:
400 sys.stdout = save
402 def flattenGraph(self):
403 """Arrange the blocks in order and resolve jumps"""
404 assert self.stage == RAW
405 self.insts = insts = []
406 pc = 0
407 begin = {}
408 end = {}
409 for b in self.getBlocksInOrder():
410 begin[b] = pc
411 for inst in b.getInstructions():
412 insts.append(inst)
413 if len(inst) == 1:
414 pc = pc + 1
415 else:
416 # arg takes 2 bytes
417 pc = pc + 3
418 end[b] = pc
419 pc = 0
420 for i in range(len(insts)):
421 inst = insts[i]
422 if len(inst) == 1:
423 pc = pc + 1
424 else:
425 pc = pc + 3
426 opname = inst[0]
427 if self.hasjrel.has_elt(opname):
428 oparg = inst[1]
429 offset = begin[oparg] - pc
430 insts[i] = opname, offset
431 elif self.hasjabs.has_elt(opname):
432 insts[i] = opname, begin[inst[1]]
433 self.stacksize = findDepth(self.insts)
434 self.stage = FLAT
436 hasjrel = misc.Set()
437 for i in dis.hasjrel:
438 hasjrel.add(dis.opname[i])
439 hasjabs = misc.Set()
440 for i in dis.hasjabs:
441 hasjabs.add(dis.opname[i])
443 def convertArgs(self):
444 """Convert arguments from symbolic to concrete form"""
445 assert self.stage == FLAT
446 self.consts.insert(0, self.docstring)
447 self.sort_cellvars()
448 for i in range(len(self.insts)):
449 t = self.insts[i]
450 if len(t) == 2:
451 opname = t[0]
452 oparg = t[1]
453 conv = self._converters.get(opname, None)
454 if conv:
455 self.insts[i] = opname, conv(self, oparg)
456 self.stage = CONV
458 def sort_cellvars(self):
459 """Sort cellvars in the order of varnames and prune from freevars.
461 cells = {}
462 for name in self.cellvars:
463 cells[name] = 1
464 self.cellvars = [name for name in self.varnames
465 if cells.has_key(name)]
466 for name in self.cellvars:
467 del cells[name]
468 self.cellvars = self.cellvars + cells.keys()
469 self.closure = self.cellvars + self.freevars
471 def _lookupName(self, name, list):
472 """Return index of name in list, appending if necessary"""
473 t = type(name)
474 for i in range(len(list)):
475 # must do a comparison on type first to prevent UnicodeErrors
476 if t == type(list[i]) and list[i] == name:
477 return i
478 end = len(list)
479 list.append(name)
480 return end
482 _converters = {}
483 def _convert_LOAD_CONST(self, arg):
484 if hasattr(arg, 'getCode'):
485 arg = arg.getCode()
486 return self._lookupName(arg, self.consts)
488 def _convert_LOAD_FAST(self, arg):
489 self._lookupName(arg, self.names)
490 return self._lookupName(arg, self.varnames)
491 _convert_STORE_FAST = _convert_LOAD_FAST
492 _convert_DELETE_FAST = _convert_LOAD_FAST
494 def _convert_NAME(self, arg):
495 return self._lookupName(arg, self.names)
496 _convert_LOAD_NAME = _convert_NAME
497 _convert_STORE_NAME = _convert_NAME
498 _convert_DELETE_NAME = _convert_NAME
499 _convert_IMPORT_NAME = _convert_NAME
500 _convert_IMPORT_FROM = _convert_NAME
501 _convert_STORE_ATTR = _convert_NAME
502 _convert_LOAD_ATTR = _convert_NAME
503 _convert_DELETE_ATTR = _convert_NAME
504 _convert_LOAD_GLOBAL = _convert_NAME
505 _convert_STORE_GLOBAL = _convert_NAME
506 _convert_DELETE_GLOBAL = _convert_NAME
508 def _convert_DEREF(self, arg):
509 self._lookupName(arg, self.names)
510 self._lookupName(arg, self.varnames)
511 return self._lookupName(arg, self.closure)
512 _convert_LOAD_DEREF = _convert_DEREF
513 _convert_STORE_DEREF = _convert_DEREF
515 def _convert_LOAD_CLOSURE(self, arg):
516 self._lookupName(arg, self.varnames)
517 return self._lookupName(arg, self.closure)
519 _cmp = list(dis.cmp_op)
520 def _convert_COMPARE_OP(self, arg):
521 return self._cmp.index(arg)
523 # similarly for other opcodes...
525 for name, obj in locals().items():
526 if name[:9] == "_convert_":
527 opname = name[9:]
528 _converters[opname] = obj
529 del name, obj, opname
531 def makeByteCode(self):
532 assert self.stage == CONV
533 self.lnotab = lnotab = LineAddrTable()
534 for t in self.insts:
535 opname = t[0]
536 if len(t) == 1:
537 lnotab.addCode(self.opnum[opname])
538 else:
539 oparg = t[1]
540 if opname == "SET_LINENO":
541 lnotab.nextLine(oparg)
542 hi, lo = twobyte(oparg)
543 try:
544 lnotab.addCode(self.opnum[opname], lo, hi)
545 except ValueError:
546 print opname, oparg
547 print self.opnum[opname], lo, hi
548 raise
549 self.stage = DONE
551 opnum = {}
552 for num in range(len(dis.opname)):
553 opnum[dis.opname[num]] = num
554 del num
556 def newCodeObject(self):
557 assert self.stage == DONE
558 if self.flags == 0:
559 nlocals = 0
560 else:
561 nlocals = len(self.varnames)
562 argcount = self.argcount
563 if self.flags & CO_VARKEYWORDS:
564 argcount = argcount - 1
565 return new.code(argcount, nlocals, self.stacksize, self.flags,
566 self.lnotab.getCode(), self.getConsts(),
567 tuple(self.names), tuple(self.varnames),
568 self.filename, self.name, self.lnotab.firstline,
569 self.lnotab.getTable(), tuple(self.freevars),
570 tuple(self.cellvars))
572 def getConsts(self):
573 """Return a tuple for the const slot of the code object
575 Must convert references to code (MAKE_FUNCTION) to code
576 objects recursively.
578 l = []
579 for elt in self.consts:
580 if isinstance(elt, PyFlowGraph):
581 elt = elt.getCode()
582 l.append(elt)
583 return tuple(l)
def isJump(opname):
    """Return 1 if opname starts with 'JUMP', otherwise None."""
    if opname[:4] != 'JUMP':
        return None
    return 1
class TupleArg:
    """Helper for marking func defs with nested tuples in arglist

    count is used to build the generated argument name (see getName);
    names holds the names that make up the tuple.
    """

    def __init__(self, count, names):
        self.count, self.names = count, names

    def __repr__(self):
        fields = (self.count, self.names)
        return "TupleArg(%s, %s)" % fields

    def getName(self):
        """Return the generated name: a dot followed by count."""
        return ".%d" % self.count
def getArgCount(args):
    """Return len(args) reduced, for every TupleArg in args, by the number
    of names that TupleArg contains once flattened.
    """
    total = len(args)
    for arg in args:
        if not isinstance(arg, TupleArg):
            continue
        total = total - len(misc.flatten(arg.names))
    return total
def twobyte(val):
    """Convert an int argument into high and low bytes

    Returns the pair (high, low), i.e. divmod(val, 256).  The high part
    is not limited to a byte when val is 65536 or larger.
    """
    # isinstance() rather than an exact type comparison.
    assert isinstance(val, int)
    return divmod(val, 256)
class LineAddrTable:
    """lnotab

    This class builds the lnotab, which is documented in compile.c.
    Here's a brief recap:

    For each SET_LINENO instruction after the first one, two bytes are
    added to lnotab.  (In some cases, multiple two-byte entries are
    added.)  The first byte is the distance in bytes between the
    instruction for the last SET_LINENO and the current SET_LINENO.
    The second byte is offset in line numbers.  If either offset is
    greater than 255, multiple two-byte entries are added -- see
    compile.c for the delicate details.
    """

    def __init__(self):
        self.code = []          # one-character strings, one per byte
        self.codeOffset = 0     # number of bytes added so far
        self.firstline = 0      # first line number seen (0: none yet)
        self.lastline = 0       # line number of the last table entry
        self.lastoff = 0        # code offset of the last table entry
        self.lnotab = []        # table entries as a flat list of ints

    def addCode(self, *args):
        """Append the given byte values to the code."""
        for arg in args:
            self.code.append(chr(arg))
        self.codeOffset = self.codeOffset + len(args)

    def nextLine(self, lineno):
        """Record that the code added from now on belongs to line lineno."""
        if self.firstline == 0:
            # The first line is stored separately; no table entry.
            self.firstline = lineno
            self.lastline = lineno
            return
        # compute deltas
        addr = self.codeOffset - self.lastoff
        line = lineno - self.lastline
        # Python assumes that lineno always increases with
        # increasing bytecode address (lnotab is unsigned char).
        # Depending on when SET_LINENO instructions are emitted
        # this is not always true.  Consider the code:
        #     a = (1,
        #          b)
        # In the bytecode stream, the assignment to "a" occurs
        # after the loading of "b".  This works with the C Python
        # compiler because it only generates a SET_LINENO instruction
        # for the assignment.
        if line <= 0:
            return
        push = self.lnotab.append
        # Split deltas that do not fit into one byte over several
        # entries.
        while addr > 255:
            push(255); push(0)
            addr -= 255
        while line > 255:
            push(addr); push(255)
            line -= 255
            addr = 0
        if addr > 0 or line > 0:
            push(addr); push(line)
        self.lastline = lineno
        self.lastoff = self.codeOffset

    def getCode(self):
        """Return the code bytes joined into one string."""
        return ''.join(self.code)

    def getTable(self):
        """Return the line number table as a string, one char per entry."""
        return ''.join(map(chr, self.lnotab))
class StackDepthTracker:
    """Estimate the stack depth needed by a list of instructions."""
    # XXX 1. need to keep track of stack depth on jumps
    # XXX 2. at least partly as a result, this code is broken

    def findDepth(self, insts):
        """Return the largest depth reached while walking insts in order.

        insts is a sequence of instruction tuples (opname[, arg]).  The
        change in depth for an opcode is looked up in this order: the
        effect table, the prefix patterns, and finally a method named
        like the opcode, which is called with the argument.  Opcodes
        found nowhere leave the depth alone.  The depth never drops
        below zero.  Jumps are not followed.
        """
        depth = 0
        maxDepth = 0
        for i in insts:
            opname = i[0]
            delta = self.effect.get(opname, 0)
            if delta != 0:
                # Known fixed effect, positive or negative.
                depth = depth + delta
            else:
                # now check patterns
                for pat, pat_delta in self.patterns:
                    if opname[:len(pat)] == pat:
                        delta = pat_delta
                        depth = depth + delta
                        break
                # if we still haven't found a match
                if delta == 0:
                    meth = getattr(self, opname, None)
                    if meth is not None:
                        depth = depth + meth(i[1])
            if depth < 0:
                depth = 0
            # Sample after the instruction has been applied, so that a
            # push by the last instruction is counted as well.
            if depth > maxDepth:
                maxDepth = depth
        return maxDepth

    effect = {
        'POP_TOP': -1,
        'DUP_TOP': 1,
        'SLICE+1': -1,
        'SLICE+2': -1,
        'SLICE+3': -2,
        'STORE_SLICE+0': -1,
        'STORE_SLICE+1': -2,
        'STORE_SLICE+2': -2,
        'STORE_SLICE+3': -3,
        'DELETE_SLICE+0': -1,
        'DELETE_SLICE+1': -2,
        'DELETE_SLICE+2': -2,
        'DELETE_SLICE+3': -3,
        'STORE_SUBSCR': -3,
        'DELETE_SUBSCR': -2,
        # PRINT_EXPR?
        'PRINT_ITEM': -1,
        'LOAD_LOCALS': 1,
        'RETURN_VALUE': -1,
        'EXEC_STMT': -2,
        'BUILD_CLASS': -2,
        'STORE_NAME': -1,
        'STORE_ATTR': -2,
        'DELETE_ATTR': -1,
        'STORE_GLOBAL': -1,
        'BUILD_MAP': 1,
        'COMPARE_OP': -1,
        'STORE_FAST': -1,
        'IMPORT_STAR': -1,
        'IMPORT_NAME': 0,
        'IMPORT_FROM': 1,
        }
    # use pattern match
    patterns = [
        ('BINARY_', -1),
        ('LOAD_', 1),
        ]

    # special cases:
    # UNPACK_SEQUENCE, BUILD_TUPLE,
    # BUILD_LIST, CALL_FUNCTION, MAKE_FUNCTION, BUILD_SLICE
    def UNPACK_SEQUENCE(self, count):
        return count
    def BUILD_TUPLE(self, count):
        return -count
    def BUILD_LIST(self, count):
        return -count
    def CALL_FUNCTION(self, argc):
        hi, lo = divmod(argc, 256)
        return lo + hi * 2
    def CALL_FUNCTION_VAR(self, argc):
        return self.CALL_FUNCTION(argc)+1
    def CALL_FUNCTION_KW(self, argc):
        return self.CALL_FUNCTION(argc)+1
    def CALL_FUNCTION_VAR_KW(self, argc):
        return self.CALL_FUNCTION(argc)+2
    def MAKE_FUNCTION(self, argc):
        return -argc
    def BUILD_SLICE(self, argc):
        if argc == 2:
            return -1
        elif argc == 3:
            return -2

# Module-level shortcut bound to one shared tracker instance.
findDepth = StackDepthTracker().findDepth