2 # TreeFragments - parsing of strings to trees
6 from StringIO
import StringIO
7 from Scanning
import PyrexScanner
, StringSourceDescriptor
8 from Symtab
import ModuleScope
10 from Visitor
import VisitorTransform
11 from Nodes
import Node
, StatListNode
12 from ExprNodes
import NameNode
18 Support for parsing strings into code trees.
class StringParseContext(Main.Context):
    """
    A Main.Context used for code that is parsed from a string.

    The context remembers the name it was created with; find_module
    only resolves that name and 'cython'.
    """

    def __init__(self, name, include_directories=None):
        # Build a fresh list per instance when no directories are given.
        if include_directories is None:
            include_directories = []
        Main.Context.__init__(
            self, include_directories, {}, create_testscope=False)
        self.module_name = name

    def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
        """
        Return a new ModuleScope for module_name.

        Only the name given at construction and 'cython' are accepted;
        any other name raises AssertionError. The relative_to, pos and
        need_pxd arguments are accepted but not used here.
        """
        accepted = (self.module_name, 'cython')
        if module_name in accepted:
            return ModuleScope(module_name, parent_module=None, context=self)
        raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
33 def parse_from_strings(name
, code
, pxds
={}, level
=None, initial_pos
=None,
34 context
=None, allow_struct_enum_decorator
=False):
36 Utility method to parse a (unicode) string of code. This is mostly
37 used for internal Cython compiler purposes (creating code snippets
38 that transforms should emit, as well as unit testing).
40 code - a unicode string containing Cython (module-level) code
41 name - a descriptive name for the code source (to use in error messages etc.)
45 The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is
46 set to the scope used when parsing.
49 context
= StringParseContext(name
)
50 # Since source files carry an encoding, it makes sense in this context
51 # to use a unicode string so that code fragments don't have to bother
52 # with encoding. This means that test code passed in should not have an
54 assert isinstance(code
, unicode), "unicode code snippets only please"
58 if initial_pos
is None:
59 initial_pos
= (name
, 1, 0)
60 code_source
= StringSourceDescriptor(name
, code
)
62 scope
= context
.find_module(module_name
, pos
= initial_pos
, need_pxd
= 0)
66 scanner
= PyrexScanner(buf
, code_source
, source_encoding
= encoding
,
67 scope
= scope
, context
= context
, initial_pos
= initial_pos
)
68 ctx
= Parsing
.Ctx(allow_struct_enum_decorator
=allow_struct_enum_decorator
)
71 tree
= Parsing
.p_module(scanner
, 0, module_name
, ctx
=ctx
)
75 tree
= Parsing
.p_code(scanner
, level
=level
, ctx
=ctx
)
80 class TreeCopier(VisitorTransform
):
81 def visit_Node(self
, node
):
89 class ApplyPositionAndCopy(TreeCopier
):
90 def __init__(self
, pos
):
91 super(ApplyPositionAndCopy
, self
).__init
__()
94 def visit_Node(self
, node
):
95 copy
= super(ApplyPositionAndCopy
, self
).visit_Node(node
)
99 class TemplateTransform(VisitorTransform
):
101 Makes a copy of a template tree while doing substitutions.
103 A dictionary "substitutions" should be passed in when calling
104 the transform; mapping names to replacement nodes. Then replacement
106 - If an ExprStatNode contains a single NameNode, whose name is
107 a key in the substitutions dictionary, the ExprStatNode is
108 replaced with a copy of the tree given in the dictionary.
109 It is the responsibility of the caller that the replacement
110 node is a valid statement.
111 - If a single NameNode is otherwise encountered, it is replaced
112 if its name is listed in the substitutions dictionary in the
113 same way. It is the responsibility of the caller to make sure
114 that the replacement node is a valid expression.
116 Also a list "temps" should be passed. Any names listed will
117 be transformed into anonymous, temporary names.
119 Currently supported for tempnames is:
121 (various function and class definition nodes etc. should be added to this)
123 Each replacement node gets the position of the substituted node
124 recursively applied to every member node.
127 temp_name_counter
= 0
129 def __call__(self
, node
, substitutions
, temps
, pos
):
130 self
.substitutions
= substitutions
135 TemplateTransform
.temp_name_counter
+= 1
136 handle
= UtilNodes
.TempHandle(PyrexTypes
.py_object_type
)
137 tempmap
[temp
] = handle
138 temphandles
.append(handle
)
139 self
.tempmap
= tempmap
140 result
= super(TemplateTransform
, self
).__call
__(node
)
142 result
= UtilNodes
.TempsBlockNode(self
.get_pos(node
),
147 def get_pos(self
, node
):
153 def visit_Node(self
, node
):
157 c
= node
.clone_node()
158 if self
.pos
is not None:
160 self
.visitchildren(c
)
163 def try_substitution(self
, node
, key
):
164 sub
= self
.substitutions
.get(key
)
167 if pos
is None: pos
= node
.pos
168 return ApplyPositionAndCopy(pos
)(sub
)
170 return self
.visit_Node(node
) # make copy as usual
172 def visit_NameNode(self
, node
):
173 temphandle
= self
.tempmap
.get(node
.name
)
175 # Replace name with temporary
176 return temphandle
.ref(self
.get_pos(node
))
178 return self
.try_substitution(node
, node
.name
)
180 def visit_ExprStatNode(self
, node
):
181 # If an expression-as-statement consists of only a replaceable
182 # NameNode, we replace the entire statement, not only the NameNode
183 if isinstance(node
.expr
, NameNode
):
184 return self
.try_substitution(node
, node
.expr
.name
)
186 return self
.visit_Node(node
)
def copy_code_tree(node):
    """Run a freshly created TreeCopier over node and return its result."""
    copier = TreeCopier()
    return copier(node)
# Matches the (possibly empty) run of leading spaces at the start of a string.
# The pattern has no backslashes, so a raw-string prefix is unnecessary; a
# plain u"" literal yields the same pattern and, unlike ur"", is also valid
# syntax on Python 3.3+.
INDENT_RE = re.compile(u"^ *")
192 def strip_common_indent(lines
):
193 "Strips empty lines and common indentation from the list of strings given in lines"
194 # TODO: Facilitate textwrap.indent instead
195 lines
= [x
for x
in lines
if x
.strip() != u
""]
196 minindent
= min([len(INDENT_RE
.match(x
).group(0)) for x
in lines
])
197 lines
= [x
[minindent
:] for x
in lines
]
200 class TreeFragment(object):
201 def __init__(self
, code
, name
="(tree fragment)", pxds
={}, temps
=[], pipeline
=[], level
=None, initial_pos
=None):
202 if isinstance(code
, unicode):
203 def fmt(x
): return u
"\n".join(strip_common_indent(x
.split(u
"\n")))
207 for key
, value
in pxds
.iteritems():
208 fmt_pxds
[key
] = fmt(value
)
209 mod
= t
= parse_from_strings(name
, fmt_code
, fmt_pxds
, level
=level
, initial_pos
=initial_pos
)
211 t
= t
.body
# Make sure a StatListNode is at the top
212 if not isinstance(t
, StatListNode
):
213 t
= StatListNode(pos
=mod
.pos
, stats
=[t
])
214 for transform
in pipeline
:
215 if transform
is None:
219 elif isinstance(code
, Node
):
220 if pxds
!= {}: raise NotImplementedError()
223 raise ValueError("Unrecognized code format (accepts unicode and Node)")
227 return copy_code_tree(self
.root
)
def substitute(self, nodes={}, temps=[], pos = None):
    """
    Apply a new TemplateTransform to self.root and return its result.

    nodes - passed to the transform as its "substitutions" mapping
    temps - extra temp names, appended to self.temps for this call
    pos   - passed through to the transform unchanged
    """
    transform = TemplateTransform()
    root = self.root
    # "+" builds a new list, so neither self.temps nor the caller's
    # (or the default) temps list is modified here.
    all_temps = self.temps + temps
    return transform(root, substitutions=nodes, temps=all_temps, pos=pos)
234 class SetPosTransform(VisitorTransform
):
235 def __init__(self
, pos
):
236 super(SetPosTransform
, self
).__init
__()
239 def visit_Node(self
, node
):
241 self
.visitchildren(node
)