Added a few test cases for zip() with lists of different lengths
[pythonc.git] / transform.py
blobaeb098b1a6bccd7c2063e24191ab3f2ca707354e
1 ################################################################################
2 ##
3 ## Pythonc--Python to C++ translator
4 ##
5 ## Copyright 2011 Zach Wegner
6 ##
7 ## This file is part of Pythonc.
8 ##
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
24 import ast
25 import sys
27 import syntax
# Builtins exposed as plain functions.  The driver script emits this list as
# the LIST_BUILTIN_FUNCTIONS(x) X-macro in the generated C++.
builtin_functions = ('fread isinstance len open ord print print_nonl repr '
                     'sorted').split()

# Builtins exposed as classes.  The driver script emits this list as the
# LIST_BUILTIN_CLASSES(x) X-macro in the generated C++.
builtin_classes = ('bool dict enumerate int list range reversed set str '
                   'tuple zip').split()

# Names that always get a slot in the global symbol table (visit_Module adds
# them), each of which also gets a `sym_id_<name>` #define in the output.
builtin_symbols = builtin_functions + builtin_classes + ['__name__', '__args__']
class Transformer(ast.NodeTransformer):
    """Lower a Python ``ast`` tree into flattened ``syntax`` nodes.

    Expressions are flattened into atoms.  Any non-atomic subexpression is
    assigned to a fresh ``temp_NN`` identifier, and that assignment is queued
    on ``self.statements`` so it is emitted before the statement that uses it.
    Names are resolved to ``(scope, index)`` bindings into the 'global',
    'class' and 'local' symbol tables built by ``visit_Module`` and
    ``visit_FunctionDef``.
    """

    # Expression contexts that mean "this Name is being written".  ast.AugStore
    # is a deprecated stub that Python 3 parsers never emit, and it may be
    # absent from newer interpreters, so include it only when it exists.
    _STORE_CONTEXTS = (ast.Store,) + ((ast.AugStore,) if hasattr(ast, 'AugStore') else ())

    # Nodes that iterate and therefore reserve a hidden iterator temporary.
    _LOOP_NODES = (ast.For, ast.ListComp, ast.DictComp, ast.SetComp,
                   ast.GeneratorExp)

    def __init__(self):
        self.temp_id = 0          # counter behind get_temp_name()
        self.statements = []      # setup statements queued while flattening
        self.functions = []       # emitted by the driver script before main()
        self.in_class = False     # currently visiting a class body
        self.in_function = False  # currently visiting a function body
        self.globals_set = None   # names bound globally in the current function

    def get_temp_name(self):
        """Return a fresh, unique temporary name such as ``temp_01``."""
        self.temp_id += 1
        return 'temp_%02i' % self.temp_id

    def get_temp(self):
        """Return a ``syntax.Identifier`` naming a fresh temporary."""
        return syntax.Identifier(self.get_temp_name())

    def flatten_node(self, node, statements=None):
        """Visit ``node`` and return an atomic expression for its value.

        If the visited result is not an atom, it is assigned to a new
        temporary and the temporary is returned instead.  Setup statements go
        to ``statements`` when it is given (so the caller can place them,
        e.g. inside a branch), and to the current ``self.statements``
        otherwise.
        """
        old_stmts = self.statements
        if statements is not None:
            self.statements = statements
        node = self.visit(node)
        if node.is_atom():
            result = node
        else:
            result = self.get_temp()
            self.statements.append(syntax.Assign(result, node))
        self.statements = old_stmts
        return result

    def flatten_list(self, node_list):
        """Visit each node in ``node_list`` and return all resulting statements.

        Each visit starts with an empty ``self.statements``.  The setup
        statements it queues are emitted ahead of whatever the visit returned
        (a single node or a list of nodes).
        """
        old_stmts = self.statements
        statements = []
        for stmt in node_list:
            self.statements = []
            result = self.visit(stmt)
            # Keep queued setup statements even if the visit itself produced
            # nothing to emit; dropping them would lose evaluated temporaries.
            statements += self.statements
            if result:
                statements += result if isinstance(result, list) else [result]
        self.statements = old_stmts
        return statements

    def index_global_class_symbols(self, node, globals_set, class_set):
        """Recursively collect every name that may need a global/class slot.

        ``globals_set`` receives names declared ``global``, assigned names and
        ``def``/``class`` names anywhere under ``node``.  The assigned and
        defined names also go to ``class_set``.  Every loop or comprehension
        node gets a hidden iterator temporary (stored as ``node.iter_temp``),
        which is reserved in ``globals_set``.
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        # XXX make this check scope
        elif isinstance(node, ast.Name) and isinstance(node.ctx, self._STORE_CONTEXTS):
            globals_set.add(node.id)
            class_set.add(node.id)
        elif isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            globals_set.add(node.name)
            class_set.add(node.name)
        elif isinstance(node, self._LOOP_NODES):
            # HACK: reserve a symbol-table slot for the loop's iterator temp;
            # visit_For / visit_Comprehension read it back from node.iter_temp.
            node.iter_temp = self.get_temp_name()
            globals_set.add(node.iter_temp)
        for child in ast.iter_child_nodes(node):
            self.index_global_class_symbols(child, globals_set, class_set)

    def get_globals(self, node, globals_set, locals_set, all_vars_set):
        """Recursively classify the names used under a function node.

        ``globals_set`` gets names declared ``global``.  ``locals_set`` gets
        assigned names, parameters and loop iterator temporaries.
        ``all_vars_set`` gets every referenced name and parameter.  Loop
        nodes must already carry ``iter_temp`` (see
        ``index_global_class_symbols``).
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        elif isinstance(node, ast.Name):
            all_vars_set.add(node.id)
            if isinstance(node.ctx, self._STORE_CONTEXTS):
                locals_set.add(node.id)
        elif isinstance(node, ast.arg):
            all_vars_set.add(node.arg)
            locals_set.add(node.arg)
        elif isinstance(node, self._LOOP_NODES):
            locals_set.add(node.iter_temp)
        for child in ast.iter_child_nodes(node):
            self.get_globals(child, globals_set, locals_set, all_vars_set)

    def get_binding(self, name):
        """Return the ``(scope, index)`` slot that ``name`` resolves to.

        Inside a function, a name is global if it is in ``self.globals_set``
        and local otherwise.  In a class body it is a class symbol.  At
        module level it is global.

        Raises:
            RuntimeError: ``name`` has no slot in the selected scope (for
                example a builtin Pythonc does not provide).
        """
        if self.in_function:
            scope = 'global' if name in self.globals_set else 'local'
        elif self.in_class:
            scope = 'class'
        else:
            scope = 'global'
        symbols = self.symbol_idx[scope]
        if name not in symbols:
            raise RuntimeError('can\'t resolve symbol %r in %s scope' % (name, scope))
        return (scope, symbols[name])

    def generic_visit(self, node):
        """Reject any AST node type that has no dedicated visitor."""
        # Operator/context nodes carry no position info; don't let a missing
        # lineno (AttributeError) mask the real error.
        location = getattr(node, 'lineno', '?')
        raise RuntimeError('can\'t translate %s (line %s)' % (node, location))

    def visit_children(self, node):
        """Visit every direct child of ``node`` and return the results."""
        return [self.visit(child) for child in ast.iter_child_nodes(node)]

    def visit_Name(self, node):
        assert isinstance(node.ctx, ast.Load)
        # Before Python 3.4, True/False/None parse as plain Name nodes.
        if node.id in ['True', 'False']:
            return syntax.BoolConst(node.id == 'True')
        elif node.id == 'None':
            return syntax.NoneConst()
        return syntax.Load(node.id, self.get_binding(node.id))

    def visit_Num(self, node):
        if isinstance(node.n, float):
            raise RuntimeError('Pythonc currently does not support float literals')
        assert isinstance(node.n, int)
        return syntax.IntConst(node.n)

    def visit_Str(self, node):
        assert isinstance(node.s, str)
        return syntax.StringConst(node.s)

    def visit_Bytes(self, node):
        raise RuntimeError('Pythonc currently does not support bytes literals')

    # Unary Ops: operator nodes map to the dunder method implementing them.
    def visit_Invert(self, node): return '__invert__'
    def visit_Not(self, node): return '__not__'
    def visit_UAdd(self, node): return '__pos__'
    def visit_USub(self, node): return '__neg__'

    def visit_UnaryOp(self, node):
        op = self.visit(node.op)
        rhs = self.flatten_node(node.operand)
        return syntax.UnaryOp(op, rhs)

    # Binary Ops
    def visit_Add(self, node): return '__add__'
    def visit_BitAnd(self, node): return '__and__'
    def visit_BitOr(self, node): return '__or__'
    def visit_BitXor(self, node): return '__xor__'
    def visit_Div(self, node): return '__truediv__'
    def visit_FloorDiv(self, node): return '__floordiv__'
    def visit_LShift(self, node): return '__lshift__'
    def visit_Mod(self, node): return '__mod__'
    def visit_Mult(self, node): return '__mul__'
    def visit_Pow(self, node): return '__pow__'
    def visit_RShift(self, node): return '__rshift__'
    def visit_Sub(self, node): return '__sub__'

    def visit_BinOp(self, node):
        op = self.visit(node.op)
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.right)
        return syntax.BinaryOp(op, lhs, rhs)

    # Comparisons
    def visit_Eq(self, node): return '__eq__'
    def visit_NotEq(self, node): return '__ne__'
    def visit_Lt(self, node): return '__lt__'
    def visit_LtE(self, node): return '__le__'
    def visit_Gt(self, node): return '__gt__'
    def visit_GtE(self, node): return '__ge__'
    def visit_In(self, node): return '__contains__'
    def visit_NotIn(self, node): return '__ncontains__'
    def visit_Is(self, node): return '__is__'
    def visit_IsNot(self, node): return '__isnot__'

    def visit_Compare(self, node):
        """Lower a single (non-chained) comparison to a BinaryOp."""
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        op = self.visit(node.ops[0])
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.comparators[0])
        # Sigh--Python has these ordered weirdly: `a in b` is b.__contains__(a)
        if op in ['__contains__', '__ncontains__']:
            lhs, rhs = rhs, lhs
        return syntax.BinaryOp(op, lhs, rhs)

    # Bool ops
    def visit_And(self, node): return 'and'
    def visit_Or(self, node): return 'or'

    def visit_BoolOp(self, node):
        """Lower ``a and b and c`` (or ``or``) right-to-left into nested BoolOps.

        Each operand's setup statements are kept separate from the enclosing
        statement list, presumably so they only run when short-circuiting
        does not skip that operand.
        """
        assert len(node.values) >= 2
        op = self.visit(node.op)
        rhs_stmts = []
        rhs_expr = self.flatten_node(node.values[-1], statements=rhs_stmts)
        for v in reversed(node.values[:-1]):
            lhs_stmts = []
            lhs = self.flatten_node(v, statements=lhs_stmts)
            bool_op = syntax.BoolOp(op, lhs, rhs_stmts, rhs_expr)
            rhs_expr = bool_op.flatten(self, lhs_stmts)
            rhs_stmts = lhs_stmts
        self.statements += rhs_stmts
        return rhs_expr

    def visit_IfExp(self, node):
        expr = self.flatten_node(node.test)
        # Each branch keeps its own setup statements.
        true_stmts = []
        true_expr = self.flatten_node(node.body, statements=true_stmts)
        false_stmts = []
        false_expr = self.flatten_node(node.orelse, statements=false_stmts)
        if_exp = syntax.IfExp(expr, true_stmts, true_expr, false_stmts, false_expr)
        return if_exp.flatten(self)

    def visit_List(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.List(items)
        return l.flatten(self)

    def visit_Tuple(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        l = syntax.Tuple(items)
        return l.flatten(self)

    def visit_Dict(self, node):
        keys = [self.flatten_node(i) for i in node.keys]
        values = [self.flatten_node(i) for i in node.values]
        d = syntax.Dict(keys, values)
        return d.flatten(self)

    def visit_Set(self, node):
        items = [self.flatten_node(i) for i in node.elts]
        d = syntax.Set(items)
        return d.flatten(self)

    def visit_Subscript(self, node):
        """Lower ``x[i]`` and ``x[a:b:c]``; other slice kinds are rejected."""
        l = self.flatten_node(node.value)
        if isinstance(node.slice, ast.Index):
            index = self.flatten_node(node.slice.value)
            return syntax.Subscript(l, index)
        elif isinstance(node.slice, ast.Slice):
            # Omitted slice bounds become None, as in Python.
            start, end, step = [
                self.flatten_node(a) if a is not None else syntax.NoneConst()
                for a in (node.slice.lower, node.slice.upper, node.slice.step)]
            return syntax.Slice(l, start, end, step)
        # Previously fell through and returned None, hiding the problem.
        raise RuntimeError('can\'t translate subscript %s (line %s)'
                           % (node.slice, getattr(node, 'lineno', '?')))

    def visit_Attribute(self, node):
        assert isinstance(node.ctx, ast.Load)
        l = self.flatten_node(node.value)
        attr = syntax.Attribute(l, syntax.StringConst(node.attr))
        return attr

    def visit_Call(self, node):
        """Lower a call to Call(fn, args-tuple, kwargs-dict)."""
        fn = self.flatten_node(node.func)

        if node.starargs:
            # Only the bare f(*args) form is supported; keywords would
            # otherwise be silently dropped.
            assert not node.args
            assert not node.keywords
            assert not node.kwargs
            # NOTE(review): a single flattened node (not a list) is passed to
            # syntax.Tuple here; presumably it converts the iterable -- confirm.
            args = syntax.Tuple(self.flatten_node(node.starargs))
            args = args.flatten(self)
            kwargs = syntax.Dict([], [])
        else:
            args = syntax.Tuple([self.flatten_node(a) for a in node.args])
            args = args.flatten(self)

            keys = [syntax.StringConst(i.arg) for i in node.keywords]
            values = [self.flatten_node(i.value) for i in node.keywords]
            kwargs = syntax.Dict(keys, values)

        kwargs = kwargs.flatten(self)
        return syntax.Call(fn, args, kwargs)

    def visit_Assign(self, node):
        """Lower a single-target assignment to a list of store statements."""
        assert len(node.targets) == 1
        target = node.targets[0]
        value = self.flatten_node(node.value)
        if isinstance(target, ast.Name):
            return [syntax.Store(target.id, value, self.get_binding(target.id))]
        elif isinstance(target, ast.Tuple):
            # Unpack `a, b = value` as a = value[0]; b = value[1]
            assert all(isinstance(t, ast.Name) for t in target.elts)
            return [syntax.Store(t.id, syntax.Subscript(value, syntax.IntConst(i)),
                                 self.get_binding(t.id))
                    for i, t in enumerate(target.elts)]
        elif isinstance(target, ast.Attribute):
            base = self.flatten_node(target.value)
            return [syntax.StoreAttr(base, syntax.StringConst(target.attr), value)]
        elif isinstance(target, ast.Subscript):
            assert isinstance(target.slice, ast.Index)
            base = self.flatten_node(target.value)
            index = self.flatten_node(target.slice.value)
            return [syntax.StoreSubscript(base, index, value)]
        else:
            assert False

    def visit_AugAssign(self, node):
        op = self.visit(node.op)
        value = self.flatten_node(node.value)
        if isinstance(node.target, ast.Name):
            target = node.target.id
            # XXX HACK: doesn't modify in place
            binop = syntax.BinaryOp(op, syntax.Load(target, self.get_binding(target)), value)
            return [syntax.Store(target, binop, self.get_binding(target))]
        elif isinstance(node.target, ast.Attribute):
            l = self.flatten_node(node.target.value)
            attr_name = syntax.StringConst(node.target.attr)
            attr = syntax.Attribute(l, attr_name)
            binop = syntax.BinaryOp(op, attr, value)
            return [syntax.StoreAttr(l, attr_name, binop)]
        elif isinstance(node.target, ast.Subscript):
            assert isinstance(node.target.slice, ast.Index)
            base = self.flatten_node(node.target.value)
            index = self.flatten_node(node.target.slice.value)
            old = syntax.Subscript(base, index)
            binop = syntax.BinaryOp(op, old, value)
            return [syntax.StoreSubscript(base, index, binop)]
        else:
            assert False

    def visit_Delete(self, node):
        """Lower ``del x[i]``; only single subscript targets are supported."""
        assert len(node.targets) == 1
        target = node.targets[0]
        assert isinstance(target, ast.Subscript)
        assert isinstance(target.slice, ast.Index)

        name = self.flatten_node(target.value)
        value = self.flatten_node(target.slice.value)
        return [syntax.DeleteSubscript(name, value)]

    def visit_If(self, node):
        expr = self.flatten_node(node.test)
        stmts = self.flatten_list(node.body)
        if node.orelse:
            else_block = self.flatten_list(node.orelse)
        else:
            else_block = None
        return syntax.If(expr, stmts, else_block)

    def visit_Break(self, node):
        return syntax.Break()

    def visit_Continue(self, node):
        return syntax.Continue()

    def _loop_target(self, target):
        """Return ``(name, binding)`` for a name target, or a list of such
        pairs for a tuple-of-names target (shared by for loops and
        comprehensions)."""
        if isinstance(target, ast.Name):
            return (target.id, self.get_binding(target.id))
        elif isinstance(target, ast.Tuple):
            return [(t.id, self.get_binding(t.id)) for t in target.elts]
        raise RuntimeError('can\'t translate loop target %s' % target)

    def visit_For(self, node):
        assert not node.orelse
        iter_expr = self.flatten_node(node.iter)
        stmts = self.flatten_list(node.body)
        target = self._loop_target(node.target)
        # HACK: node.iter_temp gets set when enumerating symbols
        for_loop = syntax.For(target, iter_expr, stmts, node.iter_temp,
                              self.get_binding(node.iter_temp))
        return for_loop.flatten(self)

    def visit_While(self, node):
        assert not node.orelse
        # The test's setup statements are kept separate, presumably so they
        # can be re-run on every iteration.
        test_stmts = []
        test = self.flatten_node(node.test, statements=test_stmts)
        stmts = self.flatten_list(node.body)
        return syntax.While(test_stmts, test, stmts)

    def visit_Comprehension(self, node, comp_type):
        """Lower a single-generator comprehension.

        ``comp_type`` is one of 'list', 'set', 'dict' or 'generator'.  Only a
        single ``for`` clause with at most one ``if`` is supported.
        """
        assert len(node.generators) == 1
        gen = node.generators[0]
        assert len(gen.ifs) <= 1

        target = self._loop_target(gen.target)

        iter_expr = self.flatten_node(gen.iter)
        cond_stmts = []
        expr_stmts = []
        cond = None
        if gen.ifs:
            cond = self.flatten_node(gen.ifs[0], statements=cond_stmts)
        if comp_type == 'dict':
            expr = self.flatten_node(node.key, statements=expr_stmts)
            expr2 = self.flatten_node(node.value, statements=expr_stmts)
        else:
            expr = self.flatten_node(node.elt, statements=expr_stmts)
            expr2 = None
        comp = syntax.Comprehension(comp_type, target, iter_expr, node.iter_temp,
                                    self.get_binding(node.iter_temp), cond_stmts, cond,
                                    expr_stmts, expr, expr2)
        return comp.flatten(self)

    def visit_ListComp(self, node):
        return self.visit_Comprehension(node, 'list')

    def visit_SetComp(self, node):
        return self.visit_Comprehension(node, 'set')

    def visit_DictComp(self, node):
        return self.visit_Comprehension(node, 'dict')

    def visit_GeneratorExp(self, node):
        return self.visit_Comprehension(node, 'generator')

    def visit_Return(self, node):
        if node.value is not None:
            expr = self.flatten_node(node.value)
            return syntax.Return(expr)
        else:
            return syntax.Return(None)

    def visit_Assert(self, node):
        expr = self.flatten_node(node.test)
        return syntax.Assert(expr, node.lineno)

    def visit_arguments(self, node):
        """Lower a parameter list; *args and **kwargs are unsupported."""
        assert not node.vararg
        assert not node.kwarg

        args = [a.arg for a in node.args]
        binding = [self.get_binding(a) for a in args]
        # NOTE(review): defaults go through flatten_list, so any setup
        # statements a default expression queues end up in this list next to
        # the default values -- confirm syntax.Arguments expects that.
        defaults = self.flatten_list(node.defaults)
        args = syntax.Arguments(args, binding, defaults)
        return args.flatten(self)

    def visit_FunctionDef(self, node):
        """Lower a (non-nested) function definition."""
        assert not self.in_function

        # Get bindings of all variables. Globals are the variables that have "global x"
        # somewhere in the function, or are never written in the function.
        globals_set = set()
        locals_set = set()
        all_vars_set = set()
        self.get_globals(node, globals_set, locals_set, all_vars_set)
        globals_set |= (all_vars_set - locals_set)

        # Sorted so local slot numbering is deterministic.
        self.symbol_idx['local'] = {symbol: idx for idx, symbol in enumerate(sorted(locals_set))}

        # Set some state and recursively visit child nodes, then restore state
        self.globals_set = globals_set
        self.in_function = True
        try:
            args = self.visit(node.args)
            body = self.flatten_list(node.body)
        finally:
            self.globals_set = None
            self.in_function = False

        # Methods get a class-qualified export name from visit_ClassDef.
        exp_name = getattr(node, 'exp_name', None)
        fn = syntax.FunctionDef(node.name, args, body, exp_name,
                                self.get_binding(node.name), len(locals_set))
        return fn.flatten(self)

    def visit_ClassDef(self, node):
        """Lower a plain class (no bases, decorators or metaclass)."""
        assert not node.bases
        assert not node.keywords
        assert not node.starargs
        assert not node.kwargs
        assert not node.decorator_list
        assert not self.in_class
        assert not self.in_function

        for fn in node.body:
            if isinstance(fn, ast.FunctionDef):
                fn.exp_name = '_%s_%s' % (node.name, fn.name)

        self.in_class = True
        try:
            body = self.flatten_list(node.body)
        finally:
            self.in_class = False

        c = syntax.ClassDef(node.name, self.get_binding(node.name), body)
        return c.flatten(self)

    def visit_Expr(self, node):
        return self.visit(node.value)

    def visit_Module(self, node):
        """Build the global/class symbol tables, then lower the module body."""
        # Set up an index of all possible global/class symbols
        all_global_syms = set()
        all_class_syms = set()
        self.index_global_class_symbols(node, all_global_syms, all_class_syms)

        all_global_syms |= set(builtin_symbols)

        # Sorted so slot numbering is deterministic across runs.
        self.symbol_idx = {
            scope: {symbol: idx for idx, symbol in enumerate(sorted(symbols))}
            for scope, symbols in (('class', all_class_syms), ('global', all_global_syms))
        }
        self.global_sym_count = len(all_global_syms)
        self.class_sym_count = len(all_class_syms)

        return self.flatten_list(node.body)

    # Nodes that produce no code of their own.
    def visit_Pass(self, node): pass
    def visit_Load(self, node): pass
    def visit_Store(self, node): pass
    def visit_Global(self, node): pass
# Usage: transform.py <input.py> <output>
if len(sys.argv) < 3:
    sys.exit('usage: %s <input.py> <output>' % sys.argv[0])

# Read raw bytes so ast.parse decodes the source as the interpreter would
# (PEP 263 coding cookie, UTF-8 by default) rather than with the locale's
# encoding.  Passing the path makes syntax errors name the offending file.
with open(sys.argv[1], 'rb') as f:
    node = ast.parse(f.read(), sys.argv[1])

# visit_Module returns the flattened list of top-level statements.
transformer = Transformer()
node = transformer.visit(node)

with open(sys.argv[2], 'w') as f:
    # X-macro lists of the builtins, plus the fixed global slot of each one.
    f.write('#define LIST_BUILTIN_FUNCTIONS(x) %s\n' % ' '.join(
        'x(%s)' % x for x in builtin_functions))
    f.write('#define LIST_BUILTIN_CLASSES(x) %s\n' % ' '.join(
        'x(%s)' % x for x in builtin_classes))
    for x in builtin_symbols:
        f.write('#define sym_id_%s %s\n' % (x, transformer.symbol_idx['global'][x]))
    f.write('#include "backend.cpp"\n')
    syntax.export_consts(f)

    # Everything collected in transformer.functions goes before main().
    for func in transformer.functions:
        f.write('%s\n' % func)

    # Module-level statements become the body of main().
    f.write('int main(int argc, char **argv) {\n')
    f.write(' node *global_syms[%s];\n' % (transformer.global_sym_count))
    f.write(' context ctx(%s, global_syms), *globals = &ctx;\n' % (transformer.global_sym_count))
    f.write(' init_context(&ctx, argc, argv);\n')

    for stmt in node:
        f.write(' %s;\n' % stmt)

    f.write('}\n')