1 # cython: infer_types=True
4 # Tree visitor and transform framework
8 from Cython
.Compiler
import TypeSlots
9 from Cython
.Compiler
import Builtin
10 from Cython
.Compiler
import Nodes
11 from Cython
.Compiler
import ExprNodes
12 from Cython
.Compiler
import Errors
13 from Cython
.Compiler
import DebugFlags
18 class TreeVisitor(object):
20 Base class for writing visitors for a Cython tree, contains utilities for
21 recursing such trees using visitors. Each node is
22 expected to have a child_attrs iterable containing the names of attributes
23 containing child nodes or lists of child nodes. Lists are not considered
24 part of the tree structure (i.e. contained nodes are considered direct
25 children of the parent node).
visitchildren visits each of the children of a given node (see the visitchildren
documentation). When recursing the tree using visitchildren, an attribute
29 access_path is maintained which gives information about the current location
30 in the tree as a stack of tuples: (parent_node, attrname, index), representing
31 the node, attribute and optional list index that was taken in each step in the path to
36 >>> class SampleNode(object):
37 ... child_attrs = ["head", "body"]
38 ... def __init__(self, value, head=None, body=None):
39 ... self.value = value
42 ... def __repr__(self): return "SampleNode(%s)" % self.value
44 >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
45 >>> class MyVisitor(TreeVisitor):
46 ... def visit_SampleNode(self, node):
47 ... print "in", node.value, self.access_path
48 ... self.visitchildren(node)
49 ... print "out", node.value
51 >>> MyVisitor().visit(tree)
53 in 1 [(SampleNode(0), 'head', None)]
55 in 2 [(SampleNode(0), 'body', 0)]
57 in 3 [(SampleNode(0), 'body', 1)]
62 super(TreeVisitor
, self
).__init
__()
63 self
.dispatch_table
= {}
66 def dump_node(self
, node
, indent
=0):
67 ignored
= list(node
.child_attrs
or []) + [u
'child_attrs', u
'pos',
68 u
'gil_message', u
'cpp_message',
71 pos
= getattr(node
, 'pos', None)
76 source
= os
.path
.basename(source
.get_description())
77 values
.append(u
'%s:%s:%s' % (source
, pos
[1], pos
[2]))
78 attribute_names
= dir(node
)
79 attribute_names
.sort()
80 for attr
in attribute_names
:
83 if attr
.startswith(u
'_') or attr
.endswith(u
'_'):
86 value
= getattr(node
, attr
)
87 except AttributeError:
89 if value
is None or value
== 0:
91 elif isinstance(value
, list):
92 value
= u
'[...]/%d' % len(value
)
93 elif not isinstance(value
, (str, unicode, long, int, float)):
97 values
.append(u
'%s = %s' % (attr
, value
))
98 return u
'%s(%s)' % (node
.__class
__.__name
__,
101 def _find_node_path(self
, stacktrace
):
103 last_traceback
= stacktrace
105 while hasattr(stacktrace
, 'tb_frame'):
106 frame
= stacktrace
.tb_frame
107 node
= frame
.f_locals
.get(u
'self')
108 if isinstance(node
, Nodes
.Node
):
110 method_name
= code
.co_name
111 pos
= (os
.path
.basename(code
.co_filename
),
113 nodes
.append((node
, method_name
, pos
))
114 last_traceback
= stacktrace
115 stacktrace
= stacktrace
.tb_next
116 return (last_traceback
, nodes
)
118 def _raise_compiler_error(self
, child
, e
):
121 for parent
, attribute
, index
in self
.access_path
:
122 node
= getattr(parent
, attribute
)
127 index
= u
'[%d]' % index
128 trace
.append(u
'%s.%s%s = %s' % (
129 parent
.__class
__.__name
__, attribute
, index
,
130 self
.dump_node(node
)))
131 stacktrace
, called_nodes
= self
._find
_node
_path
(sys
.exc_info()[2])
133 for node
, method_name
, pos
in called_nodes
:
135 trace
.append(u
"File '%s', line %d, in %s: %s" % (
136 pos
[0], pos
[1], method_name
, self
.dump_node(node
)))
137 raise Errors
.CompilerCrash(
138 getattr(last_node
, 'pos', None), self
.__class
__.__name
__,
139 u
'\n'.join(trace
), e
, stacktrace
)
142 def find_handler(self
, obj
):
143 # to resolve, try entire hierarchy
146 mro
= inspect
.getmro(cls
)
147 handler_method
= None
149 handler_method
= getattr(self
, pattern
% mro_cls
.__name
__, None)
150 if handler_method
is not None:
151 return handler_method
152 print type(self
), cls
154 print self
.access_path
155 print self
.access_path
[-1][0].pos
156 print self
.access_path
[-1][0].__dict
__
157 raise RuntimeError("Visitor %r does not accept object: %s" % (self
, obj
))
def visit(self, obj):
    """Return the result of ``self._visit(obj)``.

    This is a thin forwarding wrapper: the argument is handed to
    ``_visit`` unchanged and the value it produces is returned as-is.
    """
    outcome = self._visit(obj)
    return outcome
163 def _visit(self
, obj
):
166 handler_method
= self
.dispatch_table
[type(obj
)]
168 handler_method
= self
.find_handler(obj
)
169 self
.dispatch_table
[type(obj
)] = handler_method
170 return handler_method(obj
)
171 except Errors
.CompileError
:
173 except Errors
.AbortError
:
176 if DebugFlags
.debug_no_exception_intercept
:
178 self
._raise
_compiler
_error
(obj
, e
)
181 def _visitchild(self
, child
, parent
, attrname
, idx
):
182 self
.access_path
.append((parent
, attrname
, idx
))
183 result
= self
._visit
(child
)
184 self
.access_path
.pop()
def visitchildren(self, parent, attrs=None):
    """Forward to ``self._visitchildren`` and return its result.

    ``parent`` and ``attrs`` are passed through unchanged; ``attrs``
    defaults to ``None``.
    """
    visited = self._visitchildren(parent, attrs)
    return visited
191 @cython.locals(idx
=int)
192 def _visitchildren(self
, parent
, attrs
):
194 Visits the children of the given parent. If parent is None, returns
195 immediately (returning None).
197 The return value is a dictionary giving the results for each
198 child (mapping the attribute name to either the return value
199 or a list of return values (in the case of multiple children
202 if parent
is None: return None
204 for attr
in parent
.child_attrs
:
205 if attrs
is not None and attr
not in attrs
: continue
206 child
= getattr(parent
, attr
)
207 if child
is not None:
208 if type(child
) is list:
209 childretval
= [self
._visitchild
(x
, parent
, attr
, idx
) for idx
, x
in enumerate(child
)]
211 childretval
= self
._visitchild
(child
, parent
, attr
, None)
212 assert not isinstance(childretval
, list), 'Cannot insert list here: %s in %r' % (attr
, parent
)
213 result
[attr
] = childretval
217 class VisitorTransform(TreeVisitor
):
A tree transform is a base class for visitors that want to do stream
220 processing of the structure (rather than attributes etc.) of a tree.
222 It implements __call__ to simply visit the argument node.
224 It requires the visitor methods to return the nodes which should take
225 the place of the visited node in the result tree (which can be the same
or one or more replacements). Specifically, if the return value from
229 - [] or None; the visited node will be removed (set to None if an attribute and
230 removed if in a list)
231 - A single node; the visited node will be replaced by the returned node.
232 - A list of nodes; the visited nodes will be replaced by all the nodes in the
233 list. This will only work if the node was already a member of a list; if it
234 was not, an exception will be raised. (Typically you want to ensure that you
235 are within a StatListNode or similar before doing this.)
237 def visitchildren(self
, parent
, attrs
=None):
238 result
= self
._visitchildren
(parent
, attrs
)
239 for attr
, newnode
in result
.iteritems():
240 if type(newnode
) is not list:
241 setattr(parent
, attr
, newnode
)
243 # Flatten the list one level and remove any None
251 setattr(parent
, attr
, newlist
)
254 def recurse_to_children(self
, node
):
255 self
.visitchildren(node
)
258 def __call__(self
, root
):
259 return self
._visit
(root
)
261 class CythonTransform(VisitorTransform
):
263 Certain common conventions and utilities for Cython transforms.
265 - Sets up the context of the pipeline in self.context
266 - Tracks directives in effect in self.current_directives
def __init__(self, context):
    """Run the base-class initialiser, then store ``context`` on the instance.

    ``context`` is kept as ``self.context``.
    """
    parent_init = super(CythonTransform, self).__init__
    parent_init()
    self.context = context
272 def __call__(self
, node
):
274 if isinstance(node
, ModuleNode
.ModuleNode
):
275 self
.current_directives
= node
.directives
276 return super(CythonTransform
, self
).__call
__(node
)
278 def visit_CompilerDirectivesNode(self
, node
):
279 old
= self
.current_directives
280 self
.current_directives
= node
.directives
281 self
.visitchildren(node
)
282 self
.current_directives
= old
285 def visit_Node(self
, node
):
286 self
.visitchildren(node
)
289 class ScopeTrackingTransform(CythonTransform
):
290 # Keeps track of type of scopes
291 #scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct'
292 #scope_node: the node that owns the current scope
294 def visit_ModuleNode(self
, node
):
295 self
.scope_type
= 'module'
296 self
.scope_node
= node
297 self
.visitchildren(node
)
300 def visit_scope(self
, node
, scope_type
):
301 prev
= self
.scope_type
, self
.scope_node
302 self
.scope_type
= scope_type
303 self
.scope_node
= node
304 self
.visitchildren(node
)
305 self
.scope_type
, self
.scope_node
= prev
def visit_CClassDefNode(self, node):
    """Visit ``node`` through ``self.visit_scope`` with scope type 'cclass'."""
    kind = 'cclass'
    return self.visit_scope(node, kind)
def visit_PyClassDefNode(self, node):
    """Visit ``node`` through ``self.visit_scope`` with scope type 'pyclass'."""
    kind = 'pyclass'
    return self.visit_scope(node, kind)
def visit_FuncDefNode(self, node):
    """Visit ``node`` through ``self.visit_scope`` with scope type 'function'."""
    kind = 'function'
    return self.visit_scope(node, kind)
def visit_CStructOrUnionDefNode(self, node):
    """Visit ``node`` through ``self.visit_scope`` with scope type 'struct'."""
    kind = 'struct'
    return self.visit_scope(node, kind)
321 class EnvTransform(CythonTransform
):
323 This transformation keeps a stack of the environments.
325 def __call__(self
, root
):
327 self
.enter_scope(root
, root
.scope
)
328 return super(EnvTransform
, self
).__call
__(root
)
def current_env(self):
    """Return item 1 of the last entry on ``self.env_stack``.

    Entries are pushed by ``enter_scope`` as ``(node, scope)`` pairs, so
    this is the scope half of the most recent pair.
    """
    top_entry = self.env_stack[-1]
    return top_entry[1]
def current_scope_node(self):
    """Return item 0 of the last entry on ``self.env_stack``.

    Entries are pushed by ``enter_scope`` as ``(node, scope)`` pairs, so
    this is the node half of the most recent pair.
    """
    top_entry = self.env_stack[-1]
    return top_entry[0]
def global_scope(self):
    """Return ``global_scope()`` of the environment given by ``self.current_env()``."""
    env = self.current_env()
    return env.global_scope()
def enter_scope(self, node, scope):
    """Push the pair ``(node, scope)`` onto ``self.env_stack``."""
    entry = (node, scope)
    self.env_stack.append(entry)
342 def exit_scope(self
):
345 def visit_FuncDefNode(self
, node
):
346 self
.enter_scope(node
, node
.local_scope
)
347 self
.visitchildren(node
)
351 def visit_GeneratorBodyDefNode(self
, node
):
352 self
.visitchildren(node
)
355 def visit_ClassDefNode(self
, node
):
356 self
.enter_scope(node
, node
.scope
)
357 self
.visitchildren(node
)
361 def visit_CStructOrUnionDefNode(self
, node
):
362 self
.enter_scope(node
, node
.scope
)
363 self
.visitchildren(node
)
367 def visit_ScopedExprNode(self
, node
):
369 self
.enter_scope(node
, node
.expr_scope
)
370 self
.visitchildren(node
)
373 self
.visitchildren(node
)
376 def visit_CArgDeclNode(self
, node
):
377 # default arguments are evaluated in the outer scope
379 attrs
= [ attr
for attr
in node
.child_attrs
if attr
!= 'default' ]
380 self
.visitchildren(node
, attrs
)
381 self
.enter_scope(node
, self
.current_env().outer_scope
)
382 self
.visitchildren(node
, ('default',))
385 self
.visitchildren(node
)
389 class NodeRefCleanupMixin(object):
391 Clean up references to nodes that were replaced.
393 NOTE: this implementation assumes that the replacement is
394 done first, before hitting any further references during
395 normal tree traversal. This needs to be arranged by calling
396 "self.visitchildren()" at a proper place in the transform
397 and by ordering the "child_attrs" of nodes appropriately.
def __init__(self, *args):
    """Forward ``*args`` to the next ``__init__`` and start with no replacements.

    ``self._replacements`` begins as an empty dict.
    """
    next_init = super(NodeRefCleanupMixin, self).__init__
    next_init(*args)
    self._replacements = {}
403 def visit_CloneNode(self
, node
):
405 if arg
not in self
._replacements
:
406 self
.visitchildren(node
)
408 node
.arg
= self
._replacements
.get(arg
, arg
)
411 def visit_ResultRefNode(self
, node
):
412 expr
= node
.expression
413 if expr
is None or expr
not in self
._replacements
:
414 self
.visitchildren(node
)
415 expr
= node
.expression
417 node
.expression
= self
._replacements
.get(expr
, expr
)
420 def replace(self
, node
, replacement
):
421 self
._replacements
[node
] = replacement
425 find_special_method_for_binary_operator
= {
435 '//': '__floordiv__',
444 'in': '__contains__',
448 find_special_method_for_unary_operator
= {
456 class MethodDispatcherTransform(EnvTransform
):
458 Base class for transformations that want to intercept on specific
459 builtin functions or methods of builtin types, including special
460 methods triggered by Python operators. Must run after declaration
461 analysis when entries were assigned.
463 Naming pattern for handler methods is as follows:
465 * builtin functions: _handle_(general|simple|any)_function_NAME
467 * builtin methods: _handle_(general|simple|any)_method_TYPENAME_METHODNAME
469 # only visit call nodes and Python operations
470 def visit_GeneralCallNode(self
, node
):
471 self
.visitchildren(node
)
472 function
= node
.function
473 if not function
.type.is_pyobject
:
475 arg_tuple
= node
.positional_args
476 if not isinstance(arg_tuple
, ExprNodes
.TupleNode
):
478 keyword_args
= node
.keyword_args
479 if keyword_args
and not isinstance(keyword_args
, ExprNodes
.DictNode
):
480 # can't handle **kwargs
482 args
= arg_tuple
.args
483 return self
._dispatch
_to
_handler
(node
, function
, args
, keyword_args
)
485 def visit_SimpleCallNode(self
, node
):
486 self
.visitchildren(node
)
487 function
= node
.function
488 if function
.type.is_pyobject
:
489 arg_tuple
= node
.arg_tuple
490 if not isinstance(arg_tuple
, ExprNodes
.TupleNode
):
492 args
= arg_tuple
.args
495 return self
._dispatch
_to
_handler
(node
, function
, args
, None)
497 def visit_PrimaryCmpNode(self
, node
):
499 # not currently handled below
500 self
.visitchildren(node
)
502 return self
._visit
_binop
_node
(node
)
def visit_BinopNode(self, node):
    """Delegate to ``self._visit_binop_node`` and return its result."""
    handled = self._visit_binop_node(node)
    return handled
507 def _visit_binop_node(self
, node
):
508 self
.visitchildren(node
)
509 # FIXME: could special case 'not_in'
510 special_method_name
= find_special_method_for_binary_operator(node
.operator
)
511 if special_method_name
:
512 operand1
, operand2
= node
.operand1
, node
.operand2
513 if special_method_name
== '__contains__':
514 operand1
, operand2
= operand2
, operand1
515 obj_type
= operand1
.type
516 if obj_type
.is_builtin_type
:
517 type_name
= obj_type
.name
519 type_name
= "object" # safety measure
520 node
= self
._dispatch
_to
_method
_handler
(
521 special_method_name
, None, False, type_name
,
522 node
, None, [operand1
, operand2
], None)
525 def visit_UnopNode(self
, node
):
526 self
.visitchildren(node
)
527 special_method_name
= find_special_method_for_unary_operator(node
.operator
)
528 if special_method_name
:
529 operand
= node
.operand
530 obj_type
= operand
.type
531 if obj_type
.is_builtin_type
:
532 type_name
= obj_type
.name
534 type_name
= "object" # safety measure
535 node
= self
._dispatch
_to
_method
_handler
(
536 special_method_name
, None, False, type_name
,
537 node
, None, [operand
], None)
540 ### dispatch to specific handlers
542 def _find_handler(self
, match_name
, has_kwargs
):
543 call_type
= has_kwargs
and 'general' or 'simple'
544 handler
= getattr(self
, '_handle_%s_%s' % (call_type
, match_name
), None)
546 handler
= getattr(self
, '_handle_any_%s' % match_name
, None)
549 def _delegate_to_assigned_value(self
, node
, function
, arg_list
, kwargs
):
550 assignment
= function
.cf_state
[0]
551 value
= assignment
.rhs
553 if not value
.entry
or len(value
.entry
.cf_assignments
) > 1:
554 # the variable might have been reassigned => play safe
556 elif value
.is_attribute
and value
.obj
.is_name
:
557 if not value
.obj
.entry
or len(value
.obj
.entry
.cf_assignments
) > 1:
558 # the underlying variable might have been reassigned => play safe
562 return self
._dispatch
_to
_handler
(
563 node
, value
, arg_list
, kwargs
)
565 def _dispatch_to_handler(self
, node
, function
, arg_list
, kwargs
):
567 # we only consider functions that are either builtin
568 # Python functions or builtins that were already replaced
569 # into a C function call (defined in the builtin scope)
570 if not function
.entry
:
573 function
.entry
.is_builtin
or
574 function
.entry
is self
.current_env().builtin_scope().lookup_here(function
.name
))
576 if function
.cf_state
and function
.cf_state
.is_single
:
577 # we know the value of the variable
578 # => see if it's usable instead
579 return self
._delegate
_to
_assigned
_value
(
580 node
, function
, arg_list
, kwargs
)
582 function_handler
= self
._find
_handler
(
583 "function_%s" % function
.name
, kwargs
)
584 if function_handler
is None:
585 return self
._handle
_function
(node
, function
.name
, function
, arg_list
, kwargs
)
587 return function_handler(node
, function
, arg_list
, kwargs
)
589 return function_handler(node
, function
, arg_list
)
590 elif function
.is_attribute
and function
.type.is_pyobject
:
591 attr_name
= function
.attribute
592 self_arg
= function
.obj
593 obj_type
= self_arg
.type
594 is_unbound_method
= False
595 if obj_type
.is_builtin_type
:
596 if (obj_type
is Builtin
.type_type
and self_arg
.is_name
and
597 arg_list
and arg_list
[0].type.is_pyobject
):
598 # calling an unbound method like 'list.append(L,x)'
599 # (ignoring 'type.mro()' here ...)
600 type_name
= self_arg
.name
602 is_unbound_method
= True
604 type_name
= obj_type
.name
606 type_name
= "object" # safety measure
607 return self
._dispatch
_to
_method
_handler
(
608 attr_name
, self_arg
, is_unbound_method
, type_name
,
609 node
, function
, arg_list
, kwargs
)
613 def _dispatch_to_method_handler(self
, attr_name
, self_arg
,
614 is_unbound_method
, type_name
,
615 node
, function
, arg_list
, kwargs
):
616 method_handler
= self
._find
_handler
(
617 "method_%s_%s" % (type_name
, attr_name
), kwargs
)
618 if method_handler
is None:
619 if (attr_name
in TypeSlots
.method_name_to_slot
620 or attr_name
== '__new__'):
621 method_handler
= self
._find
_handler
(
622 "slot%s" % attr_name
, kwargs
)
623 if method_handler
is None:
624 return self
._handle
_method
(
625 node
, type_name
, attr_name
, function
,
626 arg_list
, is_unbound_method
, kwargs
)
627 if self_arg
is not None:
628 arg_list
= [self_arg
] + list(arg_list
)
630 return method_handler(
631 node
, function
, arg_list
, is_unbound_method
, kwargs
)
633 return method_handler(
634 node
, function
, arg_list
, is_unbound_method
)
636 def _handle_function(self
, node
, function_name
, function
, arg_list
, kwargs
):
637 """Fallback handler"""
640 def _handle_method(self
, node
, type_name
, attr_name
, function
,
641 arg_list
, is_unbound_method
, kwargs
):
642 """Fallback handler"""
646 class RecursiveNodeReplacer(VisitorTransform
):
648 Recursively replace all occurrences of a node in a subtree by
def __init__(self, orig_node, new_node):
    """Run the base-class initialiser and store both nodes on the instance.

    ``orig_node`` is kept as ``self.orig_node`` and ``new_node`` as
    ``self.new_node``.
    """
    parent_init = super(RecursiveNodeReplacer, self).__init__
    parent_init()
    self.orig_node = orig_node
    self.new_node = new_node
655 def visit_Node(self
, node
):
656 self
.visitchildren(node
)
657 if node
is self
.orig_node
:
662 def recursively_replace_node(tree
, old_node
, new_node
):
663 replace_in
= RecursiveNodeReplacer(old_node
, new_node
)
667 class NodeFinder(TreeVisitor
):
669 Find out if a node appears in a subtree.
671 def __init__(self
, node
):
672 super(NodeFinder
, self
).__init
__()
676 def visit_Node(self
, node
):
679 elif node
is self
.node
:
682 self
._visitchildren
(node
, None)
684 def tree_contains(tree
, node
):
685 finder
= NodeFinder(node
)
691 def replace_node(ptr
, value
):
692 """Replaces a node. ptr is of the form used on the access path stack
693 (parent, attrname, listidx|None)
695 parent
, attrname
, listidx
= ptr
697 setattr(parent
, attrname
, value
)
699 getattr(parent
, attrname
)[listidx
] = value
701 class PrintTree(TreeVisitor
):
702 """Prints a representation of the tree to standard output.
703 Subclass and override repr_of to provide more information
706 TreeVisitor
.__init
__(self
)
712 self
._indent
= self
._indent
[:-2]
714 def __call__(self
, tree
, phase
=None):
715 print("Parse tree dump at phase '%s'" % phase
)
# Don't do anything about process_list, the default gives
720 # nice-looking name[idx] nodes which will visually appear
721 # under the parent-node, not displaying the list itself in
723 def visit_Node(self
, node
):
724 if len(self
.access_path
) == 0:
727 parent
, attr
, idx
= self
.access_path
[-1]
729 name
= "%s[%d]" % (attr
, idx
)
732 print("%s- %s: %s" % (self
._indent
, name
, self
.repr_of(node
)))
734 self
.visitchildren(node
)
738 def repr_of(self
, node
):
742 result
= node
.__class
__.__name
__
743 if isinstance(node
, ExprNodes
.NameNode
):
744 result
+= "(type=%s, name=\"%s\")" % (repr(node
.type), node
.name
)
745 elif isinstance(node
, Nodes
.DefNode
):
746 result
+= "(name=\"%s\")" % node
.name
747 elif isinstance(node
, ExprNodes
.ExprNode
):
749 result
+= "(type=%s)" % repr(t
)
752 path
= pos
[0].get_description()
754 path
= path
.split('/')[-1]
756 path
= path
.split('\\')[-1]
757 result
+= "(pos=(%s:%s:%s))" % (path
, pos
[1], pos
[2])
761 if __name__
== "__main__":