zip() now uses tuples rather than lists
[pythonc.git] / transform.py
blob97f61b3a88dececb41db35fa0d0821b4d3471850
1 ################################################################################
2 ##
3 ## Pythonc--Python to C++ translator
4 ##
5 ## Copyright 2011 Zach Wegner
6 ##
7 ## This file is part of Pythonc.
8 ##
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
24 import ast
25 import sys
27 import syntax
class Transformer(ast.NodeTransformer):
    """Lower a Python `ast` tree into the statement/expression objects of `syntax`.

    Nested expressions are flattened: every sub-expression that is not an
    atom is assigned to a fresh temporary (``temp_NN``) and that assignment
    is appended to ``self.statements``, the statement list currently being
    built.  Unsupported constructs are rejected with ``assert`` or
    ``RuntimeError``.
    """

    def __init__(self):
        self.temp_id = 0            # counter used by get_temp()
        self.statements = []        # statement list that flattening appends to
        # Read by the driver to emit function definitions; nothing in this
        # class appends to it, so presumably the syntax.*.flatten(self) calls
        # do -- TODO confirm.
        self.functions = []
        self.in_class = False       # True while translating a class body
        self.in_function = False    # True while translating a function body
        self.globals_set = None     # names bound as globals in the current function

    def get_temp(self):
        """Return a new, unique temporary identifier (temp_01, temp_02, ...)."""
        self.temp_id += 1
        return syntax.Identifier('temp_%02i' % self.temp_id)

    def flatten_node(self, node, statements=None):
        """Translate expression `node` and return an atom standing for its value.

        If the translated expression is not an atom, it is assigned to a new
        temporary and the temporary is returned instead.  Statements produced
        along the way go to `statements` when it is given, otherwise to the
        current ``self.statements``; ``self.statements`` is restored before
        returning.
        """
        old_stmts = self.statements
        if statements is not None:
            self.statements = statements
        node = self.visit(node)
        if node.is_atom():
            r = node
        else:
            temp = self.get_temp()
            self.statements.append(syntax.Assign(temp, node))
            r = temp
        self.statements = old_stmts
        return r

    def flatten_list(self, node_list):
        """Translate a list of statements into one flat list of syntax statements.

        Each statement's helper statements (temporaries created while
        flattening its expressions) are placed before the statement itself.
        Visitors returning a falsy value (e.g. ``pass``) contribute nothing.
        """
        old_stmts = self.statements
        statements = []
        for stmt in node_list:
            self.statements = []
            stmts = self.visit(stmt)
            if stmts:
                if isinstance(stmts, list):
                    statements += self.statements + stmts
                else:
                    statements += self.statements + [stmts]
        self.statements = old_stmts
        return statements

    def get_globals(self, node, globals_set, locals_set, all_vars_set):
        """Recursively collect variable names used under `node` into the given sets.

        globals_set  -- names declared with a ``global`` statement
        locals_set   -- names that are stored to, plus parameter names
        all_vars_set -- every name that appears at all
        """
        if isinstance(node, ast.Global):
            for name in node.names:
                globals_set.add(name)
        elif isinstance(node, ast.Name):
            all_vars_set.add(node.id)
            if isinstance(node.ctx, (ast.Store, ast.AugStore)):
                locals_set.add(node.id)
        elif isinstance(node, ast.arg):
            all_vars_set.add(node.arg)
            locals_set.add(node.arg)
        for i in ast.iter_child_nodes(node):
            self.get_globals(i, globals_set, locals_set, all_vars_set)

    def get_binding(self, name):
        """Return the scope kind ('global', 'local' or 'class') that `name` resolves to here."""
        if self.in_function:
            if name in self.globals_set:
                return 'global'
            return 'local'
        elif self.in_class:
            return 'class'
        return 'global'

    def generic_visit(self, node):
        """Reject any node type that has no dedicated visit_* method."""
        # Operator, context and helper nodes (ast.Add, ast.arguments, ...) carry
        # no location information, so the line number is looked up defensively
        # instead of turning the intended error into an AttributeError.
        lineno = getattr(node, 'lineno', '?')
        raise RuntimeError('can\'t translate %s (line %s)' % (type(node).__name__, lineno))

    def visit_children(self, node):
        """Visit every direct child of `node` and return the results as a list."""
        return [self.visit(i) for i in ast.iter_child_nodes(node)]

    def visit_Name(self, node):
        # Stores are handled by the statement visitors; only loads reach here.
        assert isinstance(node.ctx, ast.Load)
        # Before Python 3.4, True/False/None are parsed as plain names.
        if node.id in ['True', 'False']:
            return syntax.BoolConst(node.id == 'True')
        elif node.id == 'None':
            return syntax.NoneConst()
        return syntax.Load(node.id, self.get_binding(node.id))

    def visit_NameConstant(self, node):
        # Python 3.4 parses True/False/None as NameConstant rather than Name.
        if node.value is None:
            return syntax.NoneConst()
        return syntax.BoolConst(node.value)

    def visit_Num(self, node):
        if isinstance(node.n, float):
            raise RuntimeError('Pythonc currently does not support float literals')
        assert isinstance(node.n, int)
        return syntax.IntConst(node.n)

    def visit_Str(self, node):
        assert isinstance(node.s, str)
        return syntax.StringConst(node.s)

    def visit_Bytes(self, node):
        raise RuntimeError('Pythonc currently does not support bytes literals')

    # Unary Ops: operator nodes translate to the name of the special method.
    def visit_Invert(self, node): return '__invert__'
    def visit_Not(self, node): return '__not__'
    def visit_UAdd(self, node): return '__pos__'
    def visit_USub(self, node): return '__neg__'

    def visit_UnaryOp(self, node):
        op = self.visit(node.op)
        rhs = self.flatten_node(node.operand)
        return syntax.UnaryOp(op, rhs)

    # Binary Ops
    def visit_Add(self, node): return '__add__'
    def visit_BitAnd(self, node): return '__and__'
    def visit_BitOr(self, node): return '__or__'
    def visit_BitXor(self, node): return '__xor__'
    def visit_Div(self, node): return '__truediv__'
    def visit_FloorDiv(self, node): return '__floordiv__'
    def visit_LShift(self, node): return '__lshift__'
    def visit_Mod(self, node): return '__mod__'
    def visit_Mult(self, node): return '__mul__'
    def visit_Pow(self, node): return '__pow__'
    def visit_RShift(self, node): return '__rshift__'
    def visit_Sub(self, node): return '__sub__'

    def visit_BinOp(self, node):
        op = self.visit(node.op)
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.right)
        return syntax.BinaryOp(op, lhs, rhs)

    # Comparisons
    def visit_Eq(self, node): return '__eq__'
    def visit_NotEq(self, node): return '__ne__'
    def visit_Lt(self, node): return '__lt__'
    def visit_LtE(self, node): return '__le__'
    def visit_Gt(self, node): return '__gt__'
    def visit_GtE(self, node): return '__ge__'
    def visit_In(self, node): return '__contains__'
    def visit_NotIn(self, node): return '__ncontains__'
    def visit_Is(self, node): return '__is__'
    def visit_IsNot(self, node): return '__isnot__'

    def visit_Compare(self, node):
        # Only single comparisons (no chained `a < b < c`) are supported.
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        op = self.visit(node.ops[0])
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.comparators[0])
        # Sigh--Python has these ordered weirdly: `a in b` is b.__contains__(a).
        if op in ['__contains__', '__ncontains__']:
            lhs, rhs = rhs, lhs
        return syntax.BinaryOp(op, lhs, rhs)

    # Bool ops
    def visit_And(self, node): return 'and'
    def visit_Or(self, node): return 'or'

    def visit_BoolOp(self, node):
        """Translate `a and b and c` / `a or b or c` right-to-left.

        Each right operand keeps its own statement list so that its side
        statements only run when the short-circuit reaches it.
        """
        assert len(node.values) >= 2
        op = self.visit(node.op)
        rhs_stmts = []
        rhs_expr = self.flatten_node(node.values[-1], statements=rhs_stmts)
        for v in reversed(node.values[:-1]):
            lhs_stmts = []
            lhs = self.flatten_node(v, statements=lhs_stmts)
            bool_op = syntax.BoolOp(op, lhs, rhs_stmts, rhs_expr)
            rhs_expr = bool_op.flatten(self, lhs_stmts)
            rhs_stmts = lhs_stmts
        self.statements += rhs_stmts
        return rhs_expr

    def visit_IfExp(self, node):
        # Each branch gets its own statement list so only the taken branch runs.
        expr = self.flatten_node(node.test)
        true_stmts = []
        true_expr = self.flatten_node(node.body, statements=true_stmts)
        false_stmts = []
        false_expr = self.flatten_node(node.orelse, statements=false_stmts)
        if_exp = syntax.IfExp(expr, true_stmts, true_expr, false_stmts, false_expr)
        return if_exp.flatten(self)

    def visit_List(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.List(items)
        return l.flatten(self)

    def visit_Tuple(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.Tuple(items)
        return l.flatten(self)

    def visit_Dict(self, node):
        keys = [self.flatten_node(i) for i in node.keys]
        values = [self.flatten_node(i) for i in node.values]
        d = syntax.Dict(keys, values)
        return d.flatten(self)

    def visit_Set(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        d = syntax.Set(items)
        return d.flatten(self)

    def visit_Subscript(self, node):
        """Translate `x[i]` and `x[a:b:c]`; any other subscript form is rejected."""
        l = self.flatten_node(node.value)
        if isinstance(node.slice, ast.Index):
            index = self.flatten_node(node.slice.value)
            return syntax.Subscript(l, index)
        elif isinstance(node.slice, ast.Slice):
            # Omitted slice bounds become None, as in Python's slice().
            start, end, step = [self.flatten_node(a) if a is not None else syntax.NoneConst()
                                for a in (node.slice.lower, node.slice.upper, node.slice.step)]
            return syntax.Slice(l, start, end, step)
        # Previously fell through and returned None, producing broken output.
        raise RuntimeError('Pythonc currently does not support this kind of subscript')

    def visit_Attribute(self, node):
        # Attribute stores are handled in visit_Assign / visit_AugAssign.
        assert isinstance(node.ctx, ast.Load)
        l = self.flatten_node(node.value)
        attr = syntax.Attribute(l, syntax.StringConst(node.attr))
        return attr

    def visit_Call(self, node):
        """Translate a call into Call(fn, args, kwargs).

        Supported forms are positional and keyword arguments, or a single
        `*args` (optionally with keyword arguments).  `**kwargs` is rejected
        rather than silently dropped.
        """
        assert not node.kwargs
        fn = self.flatten_node(node.func)

        if node.starargs:
            assert not node.args
            args = self.flatten_node(node.starargs)
        else:
            args = syntax.List([self.flatten_node(a) for a in node.args])
            args = args.flatten(self)

        # Keyword arguments are honoured in both forms (they used to be
        # dropped silently when *args was present).
        keys = [syntax.StringConst(i.arg) for i in node.keywords]
        values = [self.flatten_node(i.value) for i in node.keywords]
        kwargs = syntax.Dict(keys, values)
        kwargs = kwargs.flatten(self)
        return syntax.Call(fn, args, kwargs)

    def visit_Assign(self, node):
        """Translate a single-target assignment to a name, tuple of names, attribute or item."""
        assert len(node.targets) == 1
        target = node.targets[0]
        value = self.flatten_node(node.value)
        if isinstance(target, ast.Name):
            return [syntax.Store(target.id, value, self.get_binding(target.id))]
        elif isinstance(target, ast.Tuple):
            # `a, b = v` becomes a = v[0]; b = v[1]
            assert all(isinstance(t, ast.Name) for t in target.elts)
            stmts = []
            for i, t in enumerate(target.elts):
                stmts += [syntax.Store(t.id, syntax.Subscript(value, syntax.IntConst(i)), self.get_binding(t.id))]
            return stmts
        elif isinstance(target, ast.Attribute):
            base = self.flatten_node(target.value)
            return [syntax.StoreAttr(base, syntax.StringConst(target.attr), value)]
        elif isinstance(target, ast.Subscript):
            assert isinstance(target.slice, ast.Index)
            base = self.flatten_node(target.value)
            index = self.flatten_node(target.slice.value)
            return [syntax.StoreSubscript(base, index, value)]
        else:
            assert False

    def visit_AugAssign(self, node):
        """Translate `target op= value` as `target = target op value`."""
        op = self.visit(node.op)
        value = self.flatten_node(node.value)
        if isinstance(node.target, ast.Name):
            target = node.target.id
            # XXX HACK: doesn't modify in place
            binop = syntax.BinaryOp(op, syntax.Load(target, self.get_binding(target)), value)
            return [syntax.Store(target, binop, self.get_binding(target))]
        elif isinstance(node.target, ast.Attribute):
            l = self.flatten_node(node.target.value)
            attr_name = syntax.StringConst(node.target.attr)
            attr = syntax.Attribute(l, attr_name)
            binop = syntax.BinaryOp(op, attr, value)
            return [syntax.StoreAttr(l, attr_name, binop)]
        elif isinstance(node.target, ast.Subscript):
            assert isinstance(node.target.slice, ast.Index)
            base = self.flatten_node(node.target.value)
            index = self.flatten_node(node.target.slice.value)
            old = syntax.Subscript(base, index)
            binop = syntax.BinaryOp(op, old, value)
            return [syntax.StoreSubscript(base, index, binop)]
        else:
            assert False

    def visit_Delete(self, node):
        # Only `del x[i]` with a single target is supported.
        assert len(node.targets) == 1
        target = node.targets[0]
        assert isinstance(target, ast.Subscript)
        assert isinstance(target.slice, ast.Index)

        name = self.flatten_node(target.value)
        value = self.flatten_node(target.slice.value)
        return [syntax.DeleteSubscript(name, value)]

    def visit_If(self, node):
        expr = self.flatten_node(node.test)
        stmts = self.flatten_list(node.body)
        if node.orelse:
            else_block = self.flatten_list(node.orelse)
        else:
            else_block = None
        return syntax.If(expr, stmts, else_block)

    def visit_Break(self, node):
        return syntax.Break()

    def visit_Continue(self, node):
        return syntax.Continue()

    def visit_For(self, node):
        """Translate a `for` loop over a name or a tuple of names (no `else:` clause)."""
        assert not node.orelse
        iter = self.flatten_node(node.iter)
        stmts = self.flatten_list(node.body)

        if isinstance(node.target, ast.Name):
            target = (node.target.id, self.get_binding(node.target.id))
        elif isinstance(node.target, ast.Tuple):
            target = [(t.id, self.get_binding(t.id)) for t in node.target.elts]
        else:
            assert False
        for_loop = syntax.For(target, iter, stmts)
        return for_loop.flatten(self)

    def visit_While(self, node):
        # The test's own statements are kept separate so they can be
        # re-evaluated on every iteration.
        assert not node.orelse
        test_stmts = []
        test = self.flatten_node(node.test, statements=test_stmts)
        stmts = self.flatten_list(node.body)
        return syntax.While(test_stmts, test, stmts)

    def visit_Comprehension(self, node, comp_type):
        """Shared translation of list/set/dict comprehensions and generator expressions.

        Only one `for` clause with at most one `if` is supported.  `comp_type`
        is one of 'list', 'set', 'dict' or 'generator'.
        """
        assert len(node.generators) == 1
        gen = node.generators[0]
        assert len(gen.ifs) <= 1

        if isinstance(gen.target, ast.Name):
            target = gen.target.id
        elif isinstance(gen.target, ast.Tuple):
            target = gen.target.elts
        else:
            assert False

        iter = self.flatten_node(gen.iter)
        cond_stmts = []
        expr_stmts = []
        cond = None
        if gen.ifs:
            cond = self.flatten_node(gen.ifs[0], statements=cond_stmts)
        if comp_type == 'dict':
            expr = self.flatten_node(node.key, statements=expr_stmts)
            expr2 = self.flatten_node(node.value, statements=expr_stmts)
        else:
            expr = self.flatten_node(node.elt, statements=expr_stmts)
            expr2 = None
        comp = syntax.Comprehension(comp_type, target, iter, cond_stmts, cond, expr_stmts, expr, expr2)
        return comp.flatten(self)

    def visit_ListComp(self, node):
        return self.visit_Comprehension(node, 'list')

    def visit_SetComp(self, node):
        return self.visit_Comprehension(node, 'set')

    def visit_DictComp(self, node):
        return self.visit_Comprehension(node, 'dict')

    def visit_GeneratorExp(self, node):
        return self.visit_Comprehension(node, 'generator')

    def visit_Return(self, node):
        if node.value is not None:
            expr = self.flatten_node(node.value)
            return syntax.Return(expr)
        else:
            return syntax.Return(None)

    def visit_Assert(self, node):
        # NOTE(review): the optional assert message (node.msg) is not translated.
        expr = self.flatten_node(node.test)
        return syntax.Assert(expr, node.lineno)

    def visit_arguments(self, node):
        """Translate a parameter list: plain positional names plus their defaults.

        `*args`, `**kwargs` and keyword-only parameters are not supported.
        """
        assert not node.vararg
        assert not node.kwarg
        # Keyword-only parameters (`def f(*, k)`) used to be dropped silently.
        assert not node.kwonlyargs

        args = [a.arg for a in node.args]
        # Defaults are expressions, not statements: flatten each one to a
        # single atom so `defaults` stays aligned with node.defaults.  Any
        # helper statements go to the enclosing statement list and so run
        # before the def, which is when Python evaluates defaults.
        # NOTE(review): this runs with in_function already set, so default
        # expressions are bound using the function's scope rules rather than
        # the enclosing scope's.
        defaults = [self.flatten_node(d) for d in node.defaults]
        args = syntax.Arguments(args, defaults)
        return args.flatten(self)

    def visit_FunctionDef(self, node):
        """Translate a top-level function or method (no nesting, no decorators)."""
        assert not self.in_function
        # Decorators would otherwise be dropped silently (visit_ClassDef
        # already rejects them for classes).
        assert not node.decorator_list

        # Get bindings of all variables. Globals are the variables that have "global x"
        # somewhere in the function, or are never written in the function.
        globals_set = set()
        locals_set = set()
        all_vars_set = set()
        self.get_globals(node, globals_set, locals_set, all_vars_set)
        globals_set |= (all_vars_set - locals_set)

        # Set some state and recursively visit child nodes, then restore state
        self.globals_set = globals_set
        self.in_function = True
        args = self.visit(node.args)
        body = self.flatten_list(node.body)
        self.globals_set = None
        self.in_function = False

        # exp_name is attached to methods by visit_ClassDef.
        exp_name = getattr(node, 'exp_name', None)
        fn = syntax.FunctionDef(node.name, args, body, exp_name, self.get_binding(node.name))
        return fn.flatten(self)

    def visit_ClassDef(self, node):
        """Translate a plain top-level class (no bases, keywords or decorators)."""
        assert not node.bases
        assert not node.keywords
        assert not node.starargs
        assert not node.kwargs
        assert not node.decorator_list
        assert not self.in_class
        assert not self.in_function

        # Give each method an exported name of the form _<Class>_<method>.
        for fn in node.body:
            if isinstance(fn, ast.FunctionDef):
                fn.exp_name = '_%s_%s' % (node.name, fn.name)

        self.in_class = True
        body = self.flatten_list(node.body)
        self.in_class = False

        c = syntax.ClassDef(node.name, body)
        return c.flatten(self)

    def visit_Expr(self, node):
        return self.visit(node.value)

    def visit_Module(self, node):
        return self.flatten_list(node.body)

    # These produce no code: `global` is consumed by get_globals, and the
    # contexts are inspected directly where needed.
    def visit_Pass(self, node): pass
    def visit_Load(self, node): pass
    def visit_Store(self, node): pass
    def visit_Global(self, node): pass
# Command-line driver: translate the Python module named by argv[1] into a
# C++ translation unit written to the file named by argv[2].
if len(sys.argv) < 3:
    sys.exit('usage: %s <input.py> <output.cpp>' % sys.argv[0])

# Read raw bytes so ast.parse decodes the source by Python's own rules
# (PEP 263 coding declaration, UTF-8 by default) rather than the locale's
# encoding; passing the filename improves SyntaxError messages.
with open(sys.argv[1], 'rb') as f:
    node = ast.parse(f.read(), sys.argv[1])

x = Transformer()
node = x.visit(node)

# Write UTF-8 explicitly so non-ASCII string constants cannot fail to encode
# under a non-UTF-8 locale.
with open(sys.argv[2], 'w', encoding='utf-8') as f:
    f.write('#include "backend.cpp"\n')
    syntax.export_consts(f)

    # Function bodies collected during translation come before main().
    for func in x.functions:
        f.write('%s\n' % func)

    f.write('int main(int argc, char **argv) {\n')
    f.write('    context ctx, *globals = &ctx;\n')
    f.write('    init_context(&ctx, argc, argv);\n')

    # Top-level module statements become the body of main().
    for stmt in node:
        f.write('    %s;\n' % stmt)

    f.write('}\n')