load() needs to check for undefined variables
[pythonc.git] / transform.py
blob363503374c9671d771e83dc80e2ba190de1157b3
1 ################################################################################
2 ##
3 ## Pythonc--Python to C++ translator
4 ##
5 ## Copyright 2011 Zach Wegner
6 ##
7 ## This file is part of Pythonc.
8 ##
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
24 import ast
25 import sys
27 import syntax
# Names of the builtin functions the translator knows about. The output
# script emits one x(name) entry per item in LIST_BUILTIN_FUNCTIONS.
builtin_functions = [
    'all',
    'any',
    'isinstance',
    'len',
    'open',
    'ord',
    'print',
    'print_nonl',
    'repr',
    'sorted',
]

# Names of the builtin classes; emitted the same way in LIST_BUILTIN_CLASSES.
builtin_classes = [
    'bool',
    'dict',
    'enumerate',
    'int',
    'list',
    'range',
    'reversed',
    'set',
    'str',
    'tuple',
    'zip',
]

# Every symbol that is always present in the global symbol table: all the
# builtins above plus the two special module-level names.
builtin_symbols = builtin_functions + builtin_classes + [
    '__name__',
    '__args__',
]
class Transformer(ast.NodeTransformer):
    """Translate a Python AST into Pythonc ``syntax`` nodes.

    Compound expressions are flattened: each non-atomic sub-expression is
    assigned to a fresh temporary, and those assignments are collected in
    ``self.statements`` until the enclosing statement is emitted by
    ``flatten_list``.
    """

    def __init__(self):
        # Counter behind the generated temp_NN names
        self.temp_id = 0
        # Assignments that must run before the expression currently being built
        self.statements = []
        # Written out one per line by the script after translation; nothing in
        # this class appends to it (presumably syntax nodes' flatten() do)
        self.functions = []
        self.in_class = False
        self.in_function = False
        # Names bound globally inside the function currently being translated
        self.globals_set = None
        # scope name -> {symbol: slot index}; filled in by visit_Module and
        # visit_FunctionDef
        self.symbol_idx = {}
        self.global_sym_count = 0
        self.class_sym_count = 0

    def get_temp_name(self):
        """Return the next unused temporary name ('temp_01', 'temp_02', ...)."""
        self.temp_id += 1
        return 'temp_%02i' % self.temp_id

    def get_temp(self):
        """Return a syntax.Identifier wrapping a fresh temporary name."""
        return syntax.Identifier(self.get_temp_name())

    def flatten_node(self, node, statements=None):
        """Visit `node` and return an atom that stands for its value.

        If the translated node is not an atom it is assigned to a new temporary
        and the temporary is returned. The assignment goes to `statements` when
        given, otherwise to the current self.statements.
        """
        old_stmts = self.statements
        if statements is not None:
            self.statements = statements
        node = self.visit(node)
        if node.is_atom():
            r = node
        else:
            temp = self.get_temp()
            self.statements.append(syntax.Assign(temp, node))
            r = temp
        self.statements = old_stmts
        return r

    def flatten_list(self, node_list):
        """Translate a list of statements into one flat list of syntax nodes.

        Each statement is visited with an empty self.statements; whatever
        temporaries it needed are placed immediately before its own output.
        Statements whose visitor returns a falsy value produce nothing.
        """
        old_stmts = self.statements
        statements = []
        for stmt in node_list:
            self.statements = []
            stmts = self.visit(stmt)
            if stmts:
                if isinstance(stmts, list):
                    statements += self.statements + stmts
                else:
                    statements += self.statements + [stmts]
        self.statements = old_stmts
        return statements

    def index_global_class_symbols(self, node, globals_set, class_set):
        """Recursively collect every name that may need a global/class slot.

        Fills `globals_set` and `class_set` in place. Loop and comprehension
        nodes also get an `iter_temp` attribute naming a reserved temporary.
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        # XXX make this check scope
        elif isinstance(node, ast.Name) and isinstance(node.ctx,
                (ast.Store, ast.AugStore)):
            globals_set.add(node.id)
            class_set.add(node.id)
        elif isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            globals_set.add(node.name)
            class_set.add(node.name)
        elif isinstance(node, ast.Import):
            globals_set.update(alias.name for alias in node.names)
        elif isinstance(node, (ast.For, ast.ListComp, ast.DictComp, ast.SetComp,
                ast.GeneratorExp)):
            # HACK: set node.iter_temp for the space in the symbol table
            node.iter_temp = self.get_temp_name()
            globals_set.add(node.iter_temp)
        for i in ast.iter_child_nodes(node):
            self.index_global_class_symbols(i, globals_set, class_set)

    def get_globals(self, node, globals_set, locals_set, all_vars_set):
        """Recursively classify the names used under `node` (a function).

        `globals_set` receives names declared with `global`, `locals_set` the
        names that are stored to (including arguments and loop temporaries),
        and `all_vars_set` every name or argument seen.
        """
        if isinstance(node, ast.Global):
            globals_set.update(node.names)
        elif isinstance(node, ast.Name):
            all_vars_set.add(node.id)
            if isinstance(node.ctx, (ast.Store, ast.AugStore)):
                locals_set.add(node.id)
        elif isinstance(node, ast.arg):
            all_vars_set.add(node.arg)
            locals_set.add(node.arg)
        elif isinstance(node, (ast.For, ast.ListComp, ast.DictComp, ast.SetComp,
                ast.GeneratorExp)):
            # iter_temp was attached earlier by index_global_class_symbols
            locals_set.add(node.iter_temp)
        for i in ast.iter_child_nodes(node):
            self.get_globals(i, globals_set, locals_set, all_vars_set)

    def get_binding(self, name):
        """Return (scope, index) for `name` in the current translation context.

        Raises KeyError naming the scope and symbol when the name has no slot
        in that scope's symbol table.
        """
        if self.in_function:
            scope = 'global' if name in self.globals_set else 'local'
        elif self.in_class:
            scope = 'class'
        else:
            scope = 'global'
        try:
            return (scope, self.symbol_idx[scope][name])
        except KeyError:
            # Same exception type as the plain lookup, but say what was missing
            raise KeyError('undefined %s symbol %r' % (scope, name))

    def generic_visit(self, node):
        """Reject any node type that has no dedicated visit_* method."""
        # Not every node carries position info (operators, contexts, ...);
        # print the line only when there is one so the RuntimeError below is
        # what the user actually sees.
        lineno = getattr(node, 'lineno', None)
        if lineno is not None:
            print(lineno)
        raise RuntimeError('can\'t translate %s' % node)

    def visit_children(self, node):
        """Visit every direct child of `node` and return the results as a list."""
        return [self.visit(i) for i in ast.iter_child_nodes(node)]

    def visit_Name(self, node):
        """Translate a name read; True/False/None become constants."""
        assert isinstance(node.ctx, ast.Load)
        if node.id in ['True', 'False']:
            return syntax.BoolConst(node.id == 'True')
        elif node.id == 'None':
            return syntax.NoneConst()
        return syntax.Load(node.id, self.get_binding(node.id))

    def visit_Num(self, node):
        """Translate an integer literal; float literals are rejected."""
        if isinstance(node.n, float):
            raise RuntimeError('Pythonc currently does not support float literals')
        assert isinstance(node.n, int)
        return syntax.IntConst(node.n)

    def visit_Str(self, node):
        """Translate a string literal."""
        assert isinstance(node.s, str)
        return syntax.StringConst(node.s)

    def visit_Bytes(self, node):
        """Bytes literals are not supported."""
        raise RuntimeError('Pythonc currently does not support bytes literals')

    # Unary Ops: each operator node maps to the name of its special method
    def visit_Invert(self, node): return '__invert__'
    def visit_Not(self, node): return '__not__'
    def visit_UAdd(self, node): return '__pos__'
    def visit_USub(self, node): return '__neg__'

    def visit_UnaryOp(self, node):
        """Translate `op operand` with the operand flattened to an atom."""
        op = self.visit(node.op)
        rhs = self.flatten_node(node.operand)
        return syntax.UnaryOp(op, rhs)

    # Binary Ops
    def visit_Add(self, node): return '__add__'
    def visit_BitAnd(self, node): return '__and__'
    def visit_BitOr(self, node): return '__or__'
    def visit_BitXor(self, node): return '__xor__'
    def visit_Div(self, node): return '__truediv__'
    def visit_FloorDiv(self, node): return '__floordiv__'
    def visit_LShift(self, node): return '__lshift__'
    def visit_Mod(self, node): return '__mod__'
    def visit_Mult(self, node): return '__mul__'
    def visit_Pow(self, node): return '__pow__'
    def visit_RShift(self, node): return '__rshift__'
    def visit_Sub(self, node): return '__sub__'

    def visit_BinOp(self, node):
        """Translate `left op right`; left is flattened before right."""
        op = self.visit(node.op)
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.right)
        return syntax.BinaryOp(op, lhs, rhs)

    # Comparisons
    def visit_Eq(self, node): return '__eq__'
    def visit_NotEq(self, node): return '__ne__'
    def visit_Lt(self, node): return '__lt__'
    def visit_LtE(self, node): return '__le__'
    def visit_Gt(self, node): return '__gt__'
    def visit_GtE(self, node): return '__ge__'
    def visit_In(self, node): return '__contains__'
    def visit_NotIn(self, node): return '__ncontains__'
    def visit_Is(self, node): return '__is__'
    def visit_IsNot(self, node): return '__isnot__'

    def visit_Compare(self, node):
        """Translate a single (non-chained) comparison."""
        assert len(node.ops) == 1
        assert len(node.comparators) == 1
        op = self.visit(node.ops[0])
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.comparators[0])
        # `a in b` is a method of the container b, so the operands swap
        if op in ['__contains__', '__ncontains__']:
            lhs, rhs = rhs, lhs
        return syntax.BinaryOp(op, lhs, rhs)

    # Bool ops
    def visit_And(self, node): return 'and'
    def visit_Or(self, node): return 'or'

    def visit_BoolOp(self, node):
        """Translate an and/or chain, folding operands from right to left.

        Each operand's temporaries are kept in their own statement list and
        handed to syntax.BoolOp along with the operand.
        """
        assert len(node.values) >= 2
        op = self.visit(node.op)
        rhs_stmts = []
        rhs_expr = self.flatten_node(node.values[-1], statements=rhs_stmts)
        for v in reversed(node.values[:-1]):
            lhs_stmts = []
            lhs = self.flatten_node(v, statements=lhs_stmts)
            bool_op = syntax.BoolOp(op, lhs, rhs_stmts, rhs_expr)
            rhs_expr = bool_op.flatten(self, lhs_stmts)
            rhs_stmts = lhs_stmts
        self.statements += rhs_stmts
        return rhs_expr

    def visit_IfExp(self, node):
        """Translate `body if test else orelse`; each arm keeps its own statements."""
        expr = self.flatten_node(node.test)
        true_stmts = []
        true_expr = self.flatten_node(node.body, statements=true_stmts)
        false_stmts = []
        false_expr = self.flatten_node(node.orelse, statements=false_stmts)
        if_exp = syntax.IfExp(expr, true_stmts, true_expr, false_stmts, false_expr)
        return if_exp.flatten(self)

    def visit_List(self, node):
        """Translate a list display."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.List(items).flatten(self)

    def visit_Tuple(self, node):
        """Translate a tuple display."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.Tuple(items).flatten(self)

    def visit_Dict(self, node):
        """Translate a dict display; all keys are flattened before any value."""
        keys = [self.flatten_node(i) for i in node.keys]
        values = [self.flatten_node(i) for i in node.values]
        return syntax.Dict(keys, values).flatten(self)

    def visit_Set(self, node):
        """Translate a set display."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.Set(items).flatten(self)

    def visit_Subscript(self, node):
        """Translate `value[index]` or `value[lower:upper:step]` (read only).

        Raises RuntimeError for any other kind of slice.
        """
        base = self.flatten_node(node.value)
        if isinstance(node.slice, ast.Index):
            index = self.flatten_node(node.slice.value)
            return syntax.Subscript(base, index)
        elif isinstance(node.slice, ast.Slice):
            # Omitted slice bounds become None constants
            [start, end, step] = [self.flatten_node(a) if a else syntax.NoneConst() for a in
                    [node.slice.lower, node.slice.upper, node.slice.step]]
            return syntax.Slice(base, start, end, step)
        # Previously fell through and returned None, which only failed later
        # and far from the cause.
        raise RuntimeError('Pythonc currently does not support this kind of subscript')

    def visit_Attribute(self, node):
        """Translate an attribute read."""
        assert isinstance(node.ctx, ast.Load)
        base = self.flatten_node(node.value)
        return syntax.Attribute(base, syntax.StringConst(node.attr))

    def visit_Call(self, node):
        """Translate a call into syntax.Call(fn, args tuple, kwargs dict).

        Either `*starargs` alone or positional/keyword arguments are accepted;
        `**kwargs` together with `*starargs` is not.
        """
        fn = self.flatten_node(node.func)

        if node.starargs:
            assert not node.args
            assert not node.kwargs
            # NOTE(review): Tuple receives a single flattened node here rather
            # than a list as in the branch below -- confirm this is intended.
            args = syntax.Tuple(self.flatten_node(node.starargs))
            args = args.flatten(self)
            kwargs = syntax.Dict([], [])
        else:
            args = syntax.Tuple([self.flatten_node(a) for a in node.args])
            args = args.flatten(self)

            keys = [syntax.StringConst(i.arg) for i in node.keywords]
            values = [self.flatten_node(i.value) for i in node.keywords]
            kwargs = syntax.Dict(keys, values)

        kwargs = kwargs.flatten(self)
        return syntax.Call(fn, args, kwargs)

    def visit_Assign(self, node):
        """Translate a single-target assignment into a list of store nodes.

        Supported targets: a name, a tuple of names (unpacked by index), an
        attribute, or an indexed subscript.
        """
        assert len(node.targets) == 1
        target = node.targets[0]
        value = self.flatten_node(node.value)
        if isinstance(target, ast.Name):
            return [syntax.Store(target.id, value, self.get_binding(target.id))]
        elif isinstance(target, ast.Tuple):
            assert all(isinstance(t, ast.Name) for t in target.elts)
            # a, b = v  becomes  a = v[0]; b = v[1]
            return [syntax.Store(t.id, syntax.Subscript(value, syntax.IntConst(i)), self.get_binding(t.id))
                    for i, t in enumerate(target.elts)]
        elif isinstance(target, ast.Attribute):
            base = self.flatten_node(target.value)
            return [syntax.StoreAttr(base, syntax.StringConst(target.attr), value)]
        elif isinstance(target, ast.Subscript):
            assert isinstance(target.slice, ast.Index)
            base = self.flatten_node(target.value)
            index = self.flatten_node(target.slice.value)
            return [syntax.StoreSubscript(base, index, value)]
        else:
            assert False

    def visit_AugAssign(self, node):
        """Translate `target op= value` as a binary op followed by a store."""
        op = self.visit(node.op)
        value = self.flatten_node(node.value)
        if isinstance(node.target, ast.Name):
            target = node.target.id
            # XXX HACK: doesn't modify in place
            binop = syntax.BinaryOp(op, syntax.Load(target, self.get_binding(target)), value)
            return [syntax.Store(target, binop, self.get_binding(target))]
        elif isinstance(node.target, ast.Attribute):
            base = self.flatten_node(node.target.value)
            attr_name = syntax.StringConst(node.target.attr)
            attr = syntax.Attribute(base, attr_name)
            binop = syntax.BinaryOp(op, attr, value)
            return [syntax.StoreAttr(base, attr_name, binop)]
        elif isinstance(node.target, ast.Subscript):
            assert isinstance(node.target.slice, ast.Index)
            base = self.flatten_node(node.target.value)
            index = self.flatten_node(node.target.slice.value)
            old = syntax.Subscript(base, index)
            binop = syntax.BinaryOp(op, old, value)
            return [syntax.StoreSubscript(base, index, binop)]
        else:
            assert False

    def visit_Delete(self, node):
        """Translate `del value[index]`; no other form of del is supported."""
        assert len(node.targets) == 1
        target = node.targets[0]
        assert isinstance(target, ast.Subscript)
        assert isinstance(target.slice, ast.Index)

        name = self.flatten_node(target.value)
        value = self.flatten_node(target.slice.value)
        return [syntax.DeleteSubscript(name, value)]

    def visit_If(self, node):
        """Translate an if statement; a missing else block is passed as None."""
        expr = self.flatten_node(node.test)
        stmts = self.flatten_list(node.body)
        if node.orelse:
            else_block = self.flatten_list(node.orelse)
        else:
            else_block = None
        return syntax.If(expr, stmts, else_block)

    def visit_Break(self, node):
        return syntax.Break()

    def visit_Continue(self, node):
        return syntax.Continue()

    def visit_For(self, node):
        """Translate a for loop (no else clause) over a name or tuple of names."""
        assert not node.orelse
        iterable = self.flatten_node(node.iter)
        stmts = self.flatten_list(node.body)

        if isinstance(node.target, ast.Name):
            target = (node.target.id, self.get_binding(node.target.id))
        elif isinstance(node.target, ast.Tuple):
            target = [(t.id, self.get_binding(t.id)) for t in node.target.elts]
        else:
            assert False
        # HACK: node.iter_temp gets set when enumerating symbols
        for_loop = syntax.For(target, iterable, stmts, node.iter_temp, self.get_binding(node.iter_temp))
        return for_loop.flatten(self)

    def visit_While(self, node):
        """Translate a while loop (no else clause).

        The test's temporaries are kept separate and passed to syntax.While.
        """
        assert not node.orelse
        test_stmts = []
        test = self.flatten_node(node.test, statements=test_stmts)
        stmts = self.flatten_list(node.body)
        return syntax.While(test_stmts, test, stmts)

    # XXX We are just flattening "with x as y:" into "y = x" (this works in some simple cases with open()).
    def visit_With(self, node):
        """Translate `with expr as name:` as a store followed by the body."""
        assert node.optional_vars
        expr = self.flatten_node(node.context_expr)
        stmts = [syntax.Store(node.optional_vars.id, expr, self.get_binding(node.optional_vars.id))]
        stmts += self.flatten_list(node.body)
        return stmts

    def visit_Comprehension(self, node, comp_type):
        """Translate a comprehension with one generator and at most one `if`.

        `comp_type` is 'list', 'set', 'dict' or 'generator'; for 'dict' both a
        key and a value expression are produced, otherwise only the element.
        """
        assert len(node.generators) == 1
        gen = node.generators[0]
        assert len(gen.ifs) <= 1

        if isinstance(gen.target, ast.Name):
            target = (gen.target.id, self.get_binding(gen.target.id))
        elif isinstance(gen.target, ast.Tuple):
            target = [(t.id, self.get_binding(t.id)) for t in gen.target.elts]
        else:
            assert False

        iterable = self.flatten_node(gen.iter)
        # The condition and element expressions get their own statement lists,
        # which are passed to syntax.Comprehension with them.
        cond_stmts = []
        expr_stmts = []
        cond = None
        if gen.ifs:
            cond = self.flatten_node(gen.ifs[0], statements=cond_stmts)
        if comp_type == 'dict':
            expr = self.flatten_node(node.key, statements=expr_stmts)
            expr2 = self.flatten_node(node.value, statements=expr_stmts)
        else:
            expr = self.flatten_node(node.elt, statements=expr_stmts)
            expr2 = None
        comp = syntax.Comprehension(comp_type, target, iterable, node.iter_temp,
                self.get_binding(node.iter_temp), cond_stmts, cond, expr_stmts,
                expr, expr2)
        return comp.flatten(self)

    def visit_ListComp(self, node):
        return self.visit_Comprehension(node, 'list')

    def visit_SetComp(self, node):
        return self.visit_Comprehension(node, 'set')

    def visit_DictComp(self, node):
        return self.visit_Comprehension(node, 'dict')

    def visit_GeneratorExp(self, node):
        return self.visit_Comprehension(node, 'generator')

    def visit_Return(self, node):
        """Translate a return statement; a bare `return` carries None."""
        if node.value is None:
            return syntax.Return(None)
        return syntax.Return(self.flatten_node(node.value))

    def visit_Assert(self, node):
        """Translate an assert, recording its source line number."""
        expr = self.flatten_node(node.test)
        return syntax.Assert(expr, node.lineno)

    def visit_Raise(self, node):
        """Translate `raise exc` (no `from` clause), recording its line number."""
        assert not node.cause
        expr = self.flatten_node(node.exc)
        return syntax.Raise(expr, node.lineno)

    def visit_arguments(self, node):
        """Translate a parameter list; *args and **kwargs are not supported."""
        assert not node.vararg
        assert not node.kwarg

        args = [a.arg for a in node.args]
        binding = [self.get_binding(a) for a in args]
        defaults = self.flatten_list(node.defaults)
        args = syntax.Arguments(args, binding, defaults)
        return args.flatten(self)

    def visit_FunctionDef(self, node):
        """Translate a (non-nested) function definition."""
        assert not self.in_function

        # Get bindings of all variables. Globals are the variables that have "global x"
        # somewhere in the function, or are never written in the function.
        globals_set = set()
        locals_set = set()
        all_vars_set = set()
        self.get_globals(node, globals_set, locals_set, all_vars_set)
        globals_set |= (all_vars_set - locals_set)

        # Local slots are numbered in sorted-name order
        self.symbol_idx['local'] = {symbol: idx for idx, symbol in enumerate(sorted(locals_set))}

        # Set some state and recursively visit child nodes, then restore state
        self.globals_set = globals_set
        self.in_function = True
        args = self.visit(node.args)
        body = self.flatten_list(node.body)
        self.globals_set = None
        self.in_function = False

        # exp_name is only present on methods, where visit_ClassDef attached it
        exp_name = getattr(node, 'exp_name', None)
        fn = syntax.FunctionDef(node.name, args, body, exp_name, self.get_binding(node.name), len(locals_set))
        return fn.flatten(self)

    def visit_ClassDef(self, node):
        """Translate a plain top-level class: no bases, keywords or decorators."""
        assert not node.bases
        assert not node.keywords
        assert not node.starargs
        assert not node.kwargs
        assert not node.decorator_list
        assert not self.in_class
        assert not self.in_function

        # Tag each method with a class-qualified name for visit_FunctionDef
        for fn in node.body:
            if isinstance(fn, ast.FunctionDef):
                fn.exp_name = '_%s_%s' % (node.name, fn.name)

        self.in_class = True
        body = self.flatten_list(node.body)
        self.in_class = False

        c = syntax.ClassDef(node.name, self.get_binding(node.name), body)
        return c.flatten(self)

    # XXX This just turns "import x" into "x = 0". It's certainly not what we really want...
    def visit_Import(self, node):
        """Translate `import a, b` into stores of the constant 0."""
        statements = []
        for name in node.names:
            assert not name.asname
            assert name.name
            statements.append(syntax.Store(name.name, syntax.IntConst(0), self.get_binding(name.name)))
        return statements

    def visit_Expr(self, node):
        """An expression statement translates to its expression."""
        return self.visit(node.value)

    def visit_Module(self, node):
        """Build the global/class symbol tables, then translate the module body."""
        # Set up an index of all possible global/class symbols
        all_global_syms = set()
        all_class_syms = set()
        self.index_global_class_symbols(node, all_global_syms, all_class_syms)

        all_global_syms |= set(builtin_symbols)

        # Slots are numbered in sorted-name order within each scope
        self.symbol_idx = {
            scope: {symbol: idx for idx, symbol in enumerate(sorted(symbols))}
            for scope, symbols in [['class', all_class_syms], ['global', all_global_syms]]
        }
        self.global_sym_count = len(all_global_syms)
        self.class_sym_count = len(all_class_syms)

        return self.flatten_list(node.body)

    # These nodes produce no output of their own
    def visit_Pass(self, node): pass
    def visit_Load(self, node): pass
    def visit_Store(self, node): pass
    def visit_Global(self, node): pass
# Script body: parse the Python file named by argv[1], translate it, and write
# the generated C++ program to the file named by argv[2].
with open(sys.argv[1]) as source_file:
    node = ast.parse(source_file.read())

transformer = Transformer()
node = transformer.visit(node)

with open(sys.argv[2], 'w') as out:
    # Preamble: the builtin lists as X-macros, then one symbol id per builtin
    function_entries = ' '.join(['x(%s)' % name for name in builtin_functions])
    class_entries = ' '.join(['x(%s)' % name for name in builtin_classes])
    out.write('#define LIST_BUILTIN_FUNCTIONS(x) %s\n' % function_entries)
    out.write('#define LIST_BUILTIN_CLASSES(x) %s\n' % class_entries)
    global_slots = transformer.symbol_idx['global']
    for name in builtin_symbols:
        out.write('#define sym_id_%s %s\n' % (name, global_slots[name]))
    out.write('#include "backend.cpp"\n')
    syntax.export_consts(out)

    # Everything accumulated in transformer.functions comes before main()
    for function_text in transformer.functions:
        out.write('%s\n' % function_text)

    # main() sets up the global context and runs the module-level statements
    sym_count = transformer.global_sym_count
    out.write('int main(int argc, char **argv) {\n')
    out.write(' node *global_syms[%s] = {0};\n' % (sym_count))
    out.write(' context ctx(%s, global_syms), *globals = &ctx;\n' % (sym_count))
    out.write(' init_context(&ctx, argc, argv);\n')

    for statement in node:
        out.write(' %s;\n' % statement)

    out.write('}\n')