Remove a ?? in the description of Mac OS support.
[python/dscho.git] / Lib / compiler / pyassem.py
blob74ea562f44e4bfb0cde6384ba53007a16d04c285
1 """A flow graph representation for Python bytecode"""
import dis
import new
import string
import sys
import types

from compiler import misc
class FlowGraph:
    """A control-flow graph built from Block objects.

    Instructions are emitted into the "current" block.  The graph always
    owns an entry block (where emission starts) and a block labelled
    "exit"; every block created through newBlock() is registered in
    self.blocks.
    """

    def __init__(self):
        entry = Block()
        self.entry = entry
        self.current = entry
        self.exit = Block("exit")
        self.blocks = misc.Set()
        self.blocks.add(entry)
        self.blocks.add(self.exit)

    def startBlock(self, block):
        """Send subsequent emit() calls to block."""
        self.current = block

    def nextBlock(self, block=None):
        """Chain block (a fresh one if None) after the current block.

        The block becomes the current block's fall-through successor via
        addNext(), and emission switches to it.  Block.children() lists
        that successor last, so the reverse postorder produced by
        getBlocks() normally places it right after its parent (unless it
        was already reached through another edge).  Keeping that
        invariant relies on each block having a single "next".
        """
        # XXX implicit transfer from one block to the next is not
        # recorded anywhere except through addNext().
        target = block
        if target is None:
            target = self.newBlock()
        self.current.addNext(target)
        self.startBlock(target)

    def newBlock(self):
        """Create a Block, register it with the graph and return it."""
        fresh = Block()
        self.blocks.add(fresh)
        return fresh

    def startExitBlock(self):
        """Send subsequent emit() calls to the graph's exit block."""
        self.startBlock(self.exit)

    def emit(self, *inst):
        """Append the instruction tuple inst to the current block.

        A RETURN_VALUE additionally records an edge to the exit block.
        """
        # XXX should jump instructions implicitly call nextBlock?
        if inst[0] == 'RETURN_VALUE':
            self.current.addOutEdge(self.exit)
        self.current.emit(inst)

    def getBlocks(self):
        """Return the blocks in reverse postorder.

        i.e. each node appears before all of its successors.  As a side
        effect, every block other than exit that has no "next" block is
        given exit as its fall-through successor.
        """
        exitBlock = self.exit
        for blk in self.blocks.elements():
            if blk is not exitBlock and not blk.next:
                blk.addNext(exitBlock)
        order = dfs_postorder(self.entry, {})
        order.reverse()
        # hack: exit may not be reachable from entry; always include it
        if exitBlock not in order:
            order.append(exitBlock)
        return order
def dfs_postorder(b, seen):
    """Depth-first search of the graph rooted at b; return nodes in postorder.

    b must provide children().  seen is a dict of already-visited nodes
    (node -> node); it is updated in place, and nodes already in it are
    not visited again.
    """
    order = []
    _dfs_postorder_into(b, seen, order)
    return order

def _dfs_postorder_into(b, seen, order):
    # Append b's unvisited subtree to order in postorder.  Sharing one
    # list avoids the repeated list concatenation (quadratic copying)
    # of building a new list at every level.
    seen[b] = b
    for c in b.children():
        if c in seen:
            continue
        _dfs_postorder_into(c, seen, order)
    order.append(b)
class Block:
    """A basic block: an instruction list plus links to other blocks.

    bid is a sequence number taken from a class-wide counter.  outEdges
    holds jump targets and other explicit successors; next holds the
    single fall-through successor.
    """
    _count = 0

    def __init__(self, label=''):
        self.label = label
        self.insts = []
        self.inEdges = misc.Set()
        self.outEdges = misc.Set()
        self.next = []
        self.bid = Block._count
        Block._count = Block._count + 1

    def __repr__(self):
        size = len(self.insts)
        if not self.label:
            return "<block id=%d len=%d>" % (self.bid, size)
        return "<block %s id=%d len=%d>" % (self.label, self.bid, size)

    def __str__(self):
        body = '\n'.join(map(str, self.insts))
        return "<block %s %d:\n%s>" % (self.label, self.bid, body)

    def emit(self, inst):
        """Append inst; for JUMP* opcodes, record inst[1] as an out-edge."""
        if inst[0][:4] == 'JUMP':
            self.outEdges.add(inst[1])
        self.insts.append(inst)

    def getInstructions(self):
        return self.insts

    def addInEdge(self, block):
        self.inEdges.add(block)

    def addOutEdge(self, block):
        self.outEdges.add(block)

    def addNext(self, block):
        """Set the fall-through successor; only one is allowed."""
        self.next.append(block)
        assert len(self.next) == 1, map(str, self.next)

    def children(self):
        """Return out-edge targets followed by the fall-through block."""
        return self.outEdges.elements() + self.next
# Flag bits for code objects (combined into PyFlowGraph.flags).
CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS = (
    0x0001, 0x0002, 0x0004, 0x0008)

# A PyFlowGraph is transformed in place and is always in one of these
# stages, in this order: RAW -> FLAT -> CONV -> DONE.
RAW, FLAT, CONV, DONE = "RAW", "FLAT", "CONV", "DONE"
class PyFlowGraph(FlowGraph):
    """A FlowGraph that assembles into a Python code object.

    getCode() drives the graph through the stages RAW -> FLAT -> CONV ->
    DONE (flattenGraph, convertArgs, makeByteCode) and then builds the
    code object with new.code().
    """
    super_init = FlowGraph.__init__

    def __init__(self, name, filename, args=(), optimized=0):
        self.super_init()
        self.name = name
        self.filename = filename
        self.docstring = None
        self.args = args # XXX
        self.argcount = getArgCount(args)
        if optimized:
            self.flags = CO_OPTIMIZED | CO_NEWLOCALS
        else:
            self.flags = 0
        self.consts = []
        self.names = []
        # list() already returns a fresh (possibly empty) list
        self.varnames = list(args)
        # nested-tuple parameters are stored under a synthetic name
        for i in range(len(self.varnames)):
            var = self.varnames[i]
            if isinstance(var, TupleArg):
                self.varnames[i] = var.getName()
        self.stage = RAW

    def setDocstring(self, doc):
        """Record doc as the docstring; it is inserted as consts[0]."""
        self.docstring = doc
        self.consts.insert(0, doc)

    def setFlag(self, flag):
        """OR flag into self.flags; CO_VARARGS also decrements argcount."""
        self.flags = self.flags | flag
        if flag == CO_VARARGS:
            self.argcount = self.argcount - 1

    def getCode(self):
        """Get a Python code object, running any stages not yet done.

        Raises RuntimeError if the stage attribute is not a known stage.
        """
        if self.stage == RAW:
            self.flattenGraph()
        if self.stage == FLAT:
            self.convertArgs()
        if self.stage == CONV:
            self.makeByteCode()
        if self.stage == DONE:
            return self.newCodeObject()
        raise RuntimeError("inconsistent PyFlowGraph state")

    def dump(self, io=None):
        """Write a listing of the flattened instructions.

        Output goes to io when given, otherwise to sys.stdout.  Reads
        self.insts, which flattenGraph() creates.
        """
        # Write to the stream directly instead of swapping sys.stdout,
        # so an exception cannot leave sys.stdout redirected.
        out = io or sys.stdout
        pc = 0
        for t in self.insts:
            opname = t[0]
            if opname == "SET_LINENO":
                out.write("\n")
            if len(t) == 1:
                out.write("\t%3d %s\n" % (pc, opname))
                pc = pc + 1
            else:
                out.write("\t%3d %s %s\n" % (pc, opname, t[1]))
                pc = pc + 3

    def flattenGraph(self):
        """Arrange the blocks in order and resolve jumps"""
        assert self.stage == RAW
        self.insts = insts = []
        pc = 0
        begin = {}
        # pass 1: lay out the blocks and note each block's start offset
        for b in self.getBlocks():
            begin[b] = pc
            for inst in b.getInstructions():
                insts.append(inst)
                if len(inst) == 1:
                    pc = pc + 1
                else:
                    # arg takes 2 bytes
                    pc = pc + 3
        # pass 2: replace block targets with offsets
        pc = 0
        for i in range(len(insts)):
            inst = insts[i]
            if len(inst) == 1:
                pc = pc + 1
            else:
                pc = pc + 3
            opname = inst[0]
            if self.hasjrel.has_elt(opname):
                # pc is already past this instruction, so relative
                # offsets are measured from the next instruction
                insts[i] = opname, begin[inst[1]] - pc
            elif self.hasjabs.has_elt(opname):
                insts[i] = opname, begin[inst[1]]
        self.stacksize = findDepth(self.insts)
        self.stage = FLAT

    hasjrel = misc.Set()
    for i in dis.hasjrel:
        hasjrel.add(dis.opname[i])
    hasjabs = misc.Set()
    for i in dis.hasjabs:
        hasjabs.add(dis.opname[i])
    del i  # don't leak the loop variable as a class attribute

    def convertArgs(self):
        """Convert arguments from symbolic to concrete form"""
        assert self.stage == FLAT
        for i in range(len(self.insts)):
            t = self.insts[i]
            if len(t) == 2:
                opname = t[0]
                oparg = t[1]
                conv = self._converters.get(opname, None)
                if conv:
                    self.insts[i] = opname, conv(self, oparg)
        self.stage = CONV

    def _lookupName(self, name, list):
        """Return index of name in list, appending if necessary"""
        if name in list:
            i = list.index(name)
            # this is cheap, but incorrect in some cases, e.g 2 vs. 2L
            if type(name) == type(list[i]):
                return i
            # equal but differently typed entry: require an exact type match
            for i in range(len(list)):
                elt = list[i]
                if type(elt) == type(name) and elt == name:
                    return i
        end = len(list)
        list.append(name)
        return end

    _converters = {}
    def _convert_LOAD_CONST(self, arg):
        return self._lookupName(arg, self.consts)

    def _convert_LOAD_FAST(self, arg):
        self._lookupName(arg, self.names)
        return self._lookupName(arg, self.varnames)
    _convert_STORE_FAST = _convert_LOAD_FAST
    _convert_DELETE_FAST = _convert_LOAD_FAST

    def _convert_NAME(self, arg):
        return self._lookupName(arg, self.names)
    _convert_LOAD_NAME = _convert_NAME
    _convert_STORE_NAME = _convert_NAME
    _convert_DELETE_NAME = _convert_NAME
    _convert_IMPORT_NAME = _convert_NAME
    _convert_IMPORT_FROM = _convert_NAME
    _convert_STORE_ATTR = _convert_NAME
    _convert_LOAD_ATTR = _convert_NAME
    _convert_DELETE_ATTR = _convert_NAME
    _convert_LOAD_GLOBAL = _convert_NAME
    _convert_STORE_GLOBAL = _convert_NAME
    _convert_DELETE_GLOBAL = _convert_NAME

    _cmp = list(dis.cmp_op)
    def _convert_COMPARE_OP(self, arg):
        return self._cmp.index(arg)

    # similarly for other opcodes...

    # Register every _convert_<OPNAME> under OPNAME.  Snapshot the items
    # first: the loop itself binds new names in this class namespace.
    for name, obj in list(locals().items()):
        if name[:9] == "_convert_":
            opname = name[9:]
            _converters[opname] = obj
    del name, obj, opname

    def makeByteCode(self):
        """Encode self.insts as bytecode and build the line-number table."""
        assert self.stage == CONV
        self.lnotab = lnotab = LineAddrTable()
        for t in self.insts:
            opname = t[0]
            if len(t) == 1:
                lnotab.addCode(self.opnum[opname])
            else:
                oparg = t[1]
                if opname == "SET_LINENO":
                    lnotab.nextLine(oparg)
                hi, lo = twobyte(oparg)
                try:
                    lnotab.addCode(self.opnum[opname], lo, hi)
                except ValueError:
                    # chr() rejected a byte value; show the offender
                    sys.stdout.write("%s %s\n" % (opname, oparg))
                    sys.stdout.write("%s %s %s\n"
                                     % (self.opnum[opname], lo, hi))
                    raise
        self.stage = DONE

    opnum = {}
    for num in range(len(dis.opname)):
        opnum[dis.opname[num]] = num
    del num

    def newCodeObject(self):
        """Build the code object from the assembled pieces."""
        assert self.stage == DONE
        if self.flags == 0:
            nlocals = 0
        else:
            nlocals = len(self.varnames)
        argcount = self.argcount
        if self.flags & CO_VARKEYWORDS:
            argcount = argcount - 1
        return new.code(argcount, nlocals, self.stacksize, self.flags,
                        self.lnotab.getCode(), self.getConsts(),
                        tuple(self.names), tuple(self.varnames),
                        self.filename, self.name, self.lnotab.firstline,
                        self.lnotab.getTable())

    def getConsts(self):
        """Return a tuple for the const slot of the code object

        Must convert references to code (MAKE_FUNCTION) to code
        objects recursively.
        """
        l = []
        for elt in self.consts:
            if isinstance(elt, PyFlowGraph):
                elt = elt.getCode()
            l.append(elt)
        return tuple(l)
def isJump(opname):
    """Return 1 if opname starts with 'JUMP', else None."""
    if opname[:4] != 'JUMP':
        return None
    return 1
class TupleArg:
    """Marker for a nested-tuple parameter in a def's argument list.

    count numbers the parameter and is embedded in its synthetic local
    name; names is the (possibly nested) structure of names it binds.
    """

    def __init__(self, count, names):
        self.count, self.names = count, names

    def __repr__(self):
        return "TupleArg(%s, %s)" % (self.count, self.names)

    def getName(self):
        """Return the synthetic variable name used for this parameter."""
        return ".nested%d" % self.count
def getArgCount(args):
    """Return len(args) minus the names bound inside TupleArg entries.

    For each TupleArg in args, the number of names in its flattened
    names structure is subtracted.
    """
    count = len(args)
    for arg in args:
        if not isinstance(arg, TupleArg):
            continue
        count = count - len(misc.flatten(arg.names))
    return count
def twobyte(val):
    """Split an int argument into a (high, low) byte pair."""
    # type(0) is the plain int type, i.e. the same object as types.IntType
    assert type(val) == type(0)
    high, low = divmod(val, 256)
    return high, low
class LineAddrTable:
    """lnotab

    This class builds the lnotab, which is undocumented but described
    by com_set_lineno in compile.c. Here's an attempt at explanation:

    For each SET_LINENO instruction after the first one, two bytes are
    added to lnotab.  (In some cases, multiple two-byte entries are
    added.)  The first byte is the distance in bytes between the
    instruction for the last SET_LINENO and the current SET_LINENO.
    The second byte is the offset in line numbers.  If either offset is
    greater than 255, multiple two-byte entries are added.  Following
    compile.c, the address offset is emitted first, in 255-byte chunks
    paired with a line offset of 0.  The line offset is then emitted in
    255-line chunks, so a line number never advances before the address
    it belongs to.
    """

    def __init__(self):
        self.code = []
        self.codeOffset = 0
        self.firstline = 0
        self.lastline = 0
        self.lastoff = 0
        self.lnotab = []

    def addCode(self, *args):
        """Append the given byte values to the code string."""
        for arg in args:
            self.code.append(chr(arg))
        self.codeOffset = self.codeOffset + len(args)

    def nextLine(self, lineno):
        """Note that code from the current offset on belongs to lineno."""
        if self.firstline == 0:
            self.firstline = lineno
            self.lastline = lineno
        else:
            # compute deltas
            addr = self.codeOffset - self.lastoff
            line = lineno - self.lastline
            # Python assumes that lineno always increases with
            # increasing bytecode address (lnotab is unsigned char).
            # Depending on when SET_LINENO instructions are emitted
            # this is not always true.  Consider the code:
            #     a = (1,
            #          b)
            # In the bytecode stream, the assignment to "a" occurs
            # after the loading of "b".  This works with the C Python
            # compiler because it only generates a SET_LINENO instruction
            # for the assignment.  Non-increasing line numbers are
            # therefore skipped, and their address delta carries over.
            if line > 0:
                push = self.lnotab.append
                while addr > 255:
                    push(255)
                    push(0)
                    addr = addr - 255
                while line > 255:
                    push(addr)
                    push(255)
                    line = line - 255
                    addr = 0
                # line is now in 1..255 and addr in 0..255
                push(addr)
                push(line)
                self.lastline = lineno
                self.lastoff = self.codeOffset

    def getCode(self):
        """Return the accumulated bytecode as a string."""
        return ''.join(self.code)

    def getTable(self):
        """Return the line-number table as a string."""
        return ''.join(map(chr, self.lnotab))
class StackDepthTracker:
    """Estimate the maximum stack depth of a flat instruction list.

    An instruction's stack effect is looked up, in order, in the effect
    table, then the opcode-prefix patterns, then a method named after
    the opcode that computes the effect from the instruction's argument.
    Unknown opcodes are assumed to have no effect.
    """
    # XXX 1. need to keep track of stack depth on jumps
    # XXX 2. at least partly as a result, this code is broken

    def findDepth(self, insts):
        """Return the largest depth reached while walking insts in order.

        insts is a sequence of (opname,) or (opname, oparg) tuples.
        Jumps are not followed, and depth is clamped at 0.
        """
        depth = 0
        maxDepth = 0
        for i in insts:
            opname = i[0]
            delta = self.effect.get(opname, 0)
            if delta == 0:
                # not in the table: try the prefix patterns
                for pat, pat_delta in self.patterns:
                    if opname[:len(pat)] == pat:
                        delta = pat_delta
                        break
            if delta == 0:
                # still no match: ask a per-opcode method, if any
                meth = getattr(self, opname, None)
                if meth is not None:
                    delta = meth(i[1])
            # Every table entry is applied, including the +1 entries
            # (DUP_TOP, BUILD_MAP) whose effect used to be dropped.
            depth = depth + delta
            if depth < 0:
                depth = 0
            # Record after each instruction so the final depth counts too.
            if depth > maxDepth:
                maxDepth = depth
        return maxDepth

    effect = {
        'POP_TOP': -1,
        'DUP_TOP': 1,
        'SLICE+1': -1,
        'SLICE+2': -1,
        'SLICE+3': -2,
        'STORE_SLICE+0': -1,
        'STORE_SLICE+1': -2,
        'STORE_SLICE+2': -2,
        'STORE_SLICE+3': -3,
        'DELETE_SLICE+0': -1,
        'DELETE_SLICE+1': -2,
        'DELETE_SLICE+2': -2,
        'DELETE_SLICE+3': -3,
        'STORE_SUBSCR': -3,
        'DELETE_SUBSCR': -2,
        # PRINT_EXPR?
        'PRINT_ITEM': -1,
        'LOAD_LOCALS': 1,
        'RETURN_VALUE': -1,
        'EXEC_STMT': -2,
        'BUILD_CLASS': -2,
        'STORE_NAME': -1,
        'STORE_ATTR': -2,
        'DELETE_ATTR': -1,
        'STORE_GLOBAL': -1,
        'BUILD_MAP': 1,
        'COMPARE_OP': -1,
        'STORE_FAST': -1,
        }
    # use pattern match
    patterns = [
        ('BINARY_', -1),
        ('LOAD_', 1),
        ('IMPORT_', 1),
        ]

    # special cases:
    # UNPACK_SEQUENCE, BUILD_TUPLE,
    # BUILD_LIST, CALL_FUNCTION, MAKE_FUNCTION, BUILD_SLICE
    #
    # XXX UNPACK_SEQUENCE and the CALL_FUNCTION* family overestimate
    # (they ignore the items they pop).  Since jumps are not tracked,
    # that slack is left in place rather than risk an underestimate.
    def UNPACK_SEQUENCE(self, count):
        return count
    def BUILD_TUPLE(self, count):
        # pops count items and pushes the new tuple
        return 1 - count
    def BUILD_LIST(self, count):
        # pops count items and pushes the new list
        return 1 - count
    def CALL_FUNCTION(self, argc):
        hi, lo = divmod(argc, 256)
        return lo + hi * 2
    def CALL_FUNCTION_VAR(self, argc):
        return self.CALL_FUNCTION(argc)+1
    def CALL_FUNCTION_KW(self, argc):
        return self.CALL_FUNCTION(argc)+1
    def CALL_FUNCTION_VAR_KW(self, argc):
        return self.CALL_FUNCTION(argc)+2
    def MAKE_FUNCTION(self, argc):
        return -argc
    def BUILD_SLICE(self, argc):
        if argc == 2:
            return -1
        elif argc == 3:
            return -2

# Shared tracker; PyFlowGraph.flattenGraph uses this to set stacksize.
findDepth = StackDepthTracker().findDepth