support python -O option; compiles the C code with -O3
[pythonc.git] / transform.py
blob e03050fe51b1148a9831765e28281aa01923b3ab
1 ################################################################################
2 ##
3 ## Pythonc--Python to C++ translator
4 ##
5 ## Copyright 2011 Zach Wegner
6 ##
7 ## This file is part of Pythonc.
8 ##
9 ## Pythonc is free software: you can redistribute it and/or modify
10 ## it under the terms of the GNU General Public License as published by
11 ## the Free Software Foundation, either version 3 of the License, or
12 ## (at your option) any later version.
14 ## Pythonc is distributed in the hope that it will be useful,
15 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
16 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 ## GNU General Public License for more details.
19 ## You should have received a copy of the GNU General Public License
20 ## along with Pythonc. If not, see <http://www.gnu.org/licenses/>.
22 ################################################################################
24 import ast
25 import sys
27 import syntax
class Transformer(ast.NodeTransformer):
    """Translate a Python AST into the flattened tree defined by `syntax`.

    visit_* methods return one of:
      * a `syntax` expression or statement node,
      * a list of `syntax` statements (e.g. from assignments),
      * an operator's special-method name such as '__add__' (operator nodes),
      * None for nodes that produce no code (`pass`, expression contexts).

    Sub-expressions are flattened with flatten_node(): a non-atomic result is
    bound to a fresh temporary, and that assignment is appended to
    self.statements. self.statements holds the setup statements that must
    run before the statement currently being translated.

    Constructs the translator cannot handle raise RuntimeError. These checks
    are explicit rather than `assert`s, so running the translator under
    `python -O` cannot strip them and silently emit wrong code.
    """

    def __init__(self):
        # Counter behind get_temp(); keeps every temporary name unique.
        self.temp_id = 0
        # Setup statements for the statement currently being translated.
        self.statements = []
        # Generated function definitions; the driver writes each one out
        # before main(). NOTE(review): nothing in this class appends to it,
        # presumably the syntax.*.flatten() helpers do -- confirm.
        self.functions = []
        # NOTE(review): not read or written anywhere else in this class.
        self.in_class = False

    def _fail(self, node, msg):
        """Raise RuntimeError(msg), prefixed with node's source line if known.

        Operator, context and `arguments` nodes carry no `lineno`, so it is
        looked up with getattr instead of being accessed directly.
        """
        lineno = getattr(node, 'lineno', None)
        if lineno is not None:
            msg = 'line %s: %s' % (lineno, msg)
        raise RuntimeError(msg)

    def _require(self, cond, node, msg):
        """Reject `node` via _fail(node, msg) unless `cond` holds."""
        if not cond:
            self._fail(node, msg)

    def get_temp(self):
        """Return a fresh temporary identifier: temp_01, temp_02, ..."""
        self.temp_id += 1
        return syntax.Identifier('temp_%02i' % self.temp_id)

    def flatten_node(self, node, statements=None):
        """Translate expression `node` and return an atomic `syntax` node.

        A non-atomic translation is assigned to a new temporary, and the
        temporary is returned instead. Setup statements are appended to
        `statements` when it is given; otherwise they go to self.statements.
        Callers such as visit_BoolOp/visit_IfExp pass their own list to keep
        an operand's setup code separate.
        """
        old_stmts = self.statements
        if statements is not None:
            self.statements = statements
        node = self.visit(node)
        if node.is_atom():
            r = node
        else:
            temp = self.get_temp()
            self.statements.append(syntax.Assign(temp, node))
            r = temp
        self.statements = old_stmts
        return r

    def flatten_list(self, node_list):
        """Translate a list of statements into a flat list of `syntax` statements.

        Each statement is preceded by the setup statements its translation
        produced. Statements that translate to nothing (e.g. `pass`) are
        dropped. self.statements is restored afterwards.
        """
        old_stmts = self.statements
        statements = []
        for stmt in node_list:
            self.statements = []
            stmts = self.visit(stmt)
            if stmts:
                if isinstance(stmts, list):
                    statements += self.statements + stmts
                else:
                    statements += self.statements + [stmts]
        self.statements = old_stmts
        return statements

    def generic_visit(self, node):
        """Reject any node type that has no dedicated visit_* method."""
        self._fail(node, 'can\'t translate %s' % node)

    def visit_children(self, node):
        """Visit every direct child of `node`; return the results in order."""
        return [self.visit(i) for i in ast.iter_child_nodes(node)]

    def visit_Name(self, node):
        """Translate a name load; the names True/False/None become constants."""
        self._require(isinstance(node.ctx, ast.Load), node,
                      'unsupported use of name %r' % node.id)
        if node.id in ['True', 'False']:
            return syntax.BoolConst(node.id == 'True')
        elif node.id == 'None':
            return syntax.NoneConst()
        return syntax.Load(node.id)

    def visit_NameConstant(self, node):
        """Translate True/False/None as parsed by Python 3.4 (NameConstant)."""
        if node.value is None:
            return syntax.NoneConst()
        return syntax.BoolConst(node.value)

    def visit_Num(self, node):
        """Translate an integer literal; other numeric literals are rejected."""
        self._require(isinstance(node.n, int), node,
                      'only int literals are supported, not %r' % (node.n,))
        return syntax.IntConst(node.n)

    def visit_Str(self, node):
        """Translate a str literal."""
        self._require(isinstance(node.s, str), node,
                      'only str literals are supported')
        return syntax.StringConst(node.s)

    # Unary operators: map to the method names passed to syntax.UnaryOp.
    def visit_Invert(self, node): return '__invert__'
    def visit_Not(self, node): return '__not__'
    def visit_UAdd(self, node): return '__pos__'
    def visit_USub(self, node): return '__neg__'

    def visit_UnaryOp(self, node):
        """Translate `op operand` to syntax.UnaryOp(method_name, operand)."""
        op = self.visit(node.op)
        rhs = self.flatten_node(node.operand)
        return syntax.UnaryOp(op, rhs)

    # Binary operators: map to the method names passed to syntax.BinaryOp.
    def visit_Add(self, node): return '__add__'
    def visit_BitAnd(self, node): return '__and__'
    def visit_BitOr(self, node): return '__or__'
    def visit_BitXor(self, node): return '__xor__'
    def visit_Div(self, node): return '__truediv__'
    def visit_FloorDiv(self, node): return '__floordiv__'
    def visit_LShift(self, node): return '__lshift__'
    def visit_Mod(self, node): return '__mod__'
    def visit_Mult(self, node): return '__mul__'
    def visit_Pow(self, node): return '__pow__'
    def visit_RShift(self, node): return '__rshift__'
    def visit_Sub(self, node): return '__sub__'

    def visit_BinOp(self, node):
        """Translate `left op right`; the left operand is flattened first."""
        op = self.visit(node.op)
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.right)
        return syntax.BinaryOp(op, lhs, rhs)

    # Comparison operators (also emitted as syntax.BinaryOp).
    def visit_Eq(self, node): return '__eq__'
    def visit_NotEq(self, node): return '__ne__'
    def visit_Lt(self, node): return '__lt__'
    def visit_LtE(self, node): return '__le__'
    def visit_Gt(self, node): return '__gt__'
    def visit_GtE(self, node): return '__ge__'
    def visit_In(self, node): return '__contains__'
    def visit_NotIn(self, node): return '__ncontains__'
    def visit_Is(self, node): return '__is__'
    def visit_IsNot(self, node): return '__isnot__'

    def visit_Compare(self, node):
        """Translate a single (non-chained) comparison to syntax.BinaryOp."""
        self._require(len(node.ops) == 1 and len(node.comparators) == 1, node,
                      'chained comparisons are not supported')
        op = self.visit(node.ops[0])
        lhs = self.flatten_node(node.left)
        rhs = self.flatten_node(node.comparators[0])
        # `a in b` means b.__contains__(a), so membership tests (and their
        # negation) swap operands.
        if op in ['__contains__', '__ncontains__']:
            lhs, rhs = rhs, lhs
        return syntax.BinaryOp(op, lhs, rhs)

    # Boolean operators.
    def visit_And(self, node): return 'and'
    def visit_Or(self, node): return 'or'

    def visit_BoolOp(self, node):
        """Translate an `and`/`or` chain, folding operands right to left.

        Each operand's setup statements are collected in their own list and
        handed to syntax.BoolOp. Presumably BoolOp.flatten only runs them
        when short-circuiting reaches that operand.
        """
        # Guaranteed by the parser, not a property of user input.
        assert len(node.values) >= 2
        op = self.visit(node.op)
        rhs_stmts = []
        rhs_expr = self.flatten_node(node.values[-1], statements=rhs_stmts)
        for v in reversed(node.values[:-1]):
            lhs_stmts = []
            lhs = self.flatten_node(v, statements=lhs_stmts)
            bool_op = syntax.BoolOp(op, lhs, rhs_stmts, rhs_expr)
            rhs_expr = bool_op.flatten(self, lhs_stmts)
            rhs_stmts = lhs_stmts
        self.statements += rhs_stmts
        return rhs_expr

    def visit_IfExp(self, node):
        """Translate `body if test else orelse`.

        The setup statements for each branch are kept in separate lists, so
        syntax.IfExp can presumably run only the taken branch's code.
        """
        expr = self.flatten_node(node.test)
        true_stmts = []
        true_expr = self.flatten_node(node.body, statements=true_stmts)
        false_stmts = []
        false_expr = self.flatten_node(node.orelse, statements=false_stmts)
        if_exp = syntax.IfExp(expr, true_stmts, true_expr, false_stmts, false_expr)
        return if_exp.flatten(self)

    def visit_List(self, node):
        """Translate a list display."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.List(items).flatten(self)

    def visit_Tuple(self, node):
        """Translate a tuple display; tuples are represented as syntax.List."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.List(items).flatten(self)

    def visit_Dict(self, node):
        """Translate a dict display; all keys are flattened before any value."""
        keys = [self.flatten_node(i) for i in node.keys]
        values = [self.flatten_node(i) for i in node.values]
        return syntax.Dict(keys, values).flatten(self)

    def visit_Set(self, node):
        """Translate a set display."""
        items = [self.flatten_node(i) for i in node.elts]
        return syntax.Set(items).flatten(self)

    def visit_Subscript(self, node):
        """Translate `v[i]` to syntax.Subscript and `v[a:b:c]` to syntax.Slice.

        Omitted slice bounds become None constants. Extended slices such as
        `v[a:b, c]` are rejected instead of translating to None.
        """
        sl = node.slice
        if not isinstance(sl, (ast.Index, ast.Slice)):
            self._fail(node, 'extended slicing is not supported')
        base = self.flatten_node(node.value)
        if isinstance(sl, ast.Index):
            return syntax.Subscript(base, self.flatten_node(sl.value))
        start, end, step = [self.flatten_node(b) if b is not None else syntax.NoneConst()
                            for b in (sl.lower, sl.upper, sl.step)]
        return syntax.Slice(base, start, end, step)

    def visit_Attribute(self, node):
        """Translate an attribute load `obj.name`."""
        self._require(isinstance(node.ctx, ast.Load), node,
                      'unsupported use of attribute %r' % node.attr)
        l = self.flatten_node(node.value)
        return syntax.Attribute(l, syntax.StringConst(node.attr))

    def visit_Call(self, node):
        """Translate a call to syntax.Call(fn, args, kwargs).

        Positional arguments are packed into a list, or taken as-is from a
        lone `*args` expression. Keyword arguments are packed into a dict of
        name -> value. `**kwargs`, and `*args` combined with other arguments,
        are rejected instead of being silently dropped.
        """
        # Python <= 3.4 keeps `*a` / `**k` in Call.starargs / Call.kwargs.
        # Python 3.5+ has neither field and encodes `**k` as a keyword whose
        # arg is None (`*a` then arrives as ast.Starred and hits generic_visit).
        starargs = getattr(node, 'starargs', None)
        self._require(getattr(node, 'kwargs', None) is None and
                      all(k.arg is not None for k in node.keywords),
                      node, '**kwargs in calls is not supported')
        if starargs is not None:
            self._require(not node.args and not node.keywords, node,
                          '*args cannot be combined with other call arguments')

        fn = self.flatten_node(node.func)
        if starargs is not None:
            args = self.flatten_node(starargs)
        else:
            args = syntax.List([self.flatten_node(a) for a in node.args])
            args = args.flatten(self)
        keys = [syntax.StringConst(k.arg) for k in node.keywords]
        values = [self.flatten_node(k.value) for k in node.keywords]
        kwargs = syntax.Dict(keys, values).flatten(self)
        return syntax.Call(fn, args, kwargs)

    def visit_Assign(self, node):
        """Translate `target = value` for name, tuple-of-names, attribute and
        single-index subscript targets. The value is evaluated first."""
        self._require(len(node.targets) == 1, node,
                      'chained assignment is not supported')
        target = node.targets[0]
        value = self.flatten_node(node.value)
        if isinstance(target, ast.Name):
            return [syntax.Store(target.id, value)]
        elif isinstance(target, ast.Tuple):
            self._require(all(isinstance(t, ast.Name) for t in target.elts),
                          node, 'only names can be unpacked into')
            # `a, b = v` becomes `a = v[0]; b = v[1]`.
            # NOTE(review): no check that v has exactly len(target.elts) items.
            return [syntax.Store(t.id, syntax.Subscript(value, syntax.IntConst(i)))
                    for i, t in enumerate(target.elts)]
        elif isinstance(target, ast.Attribute):
            base = self.flatten_node(target.value)
            return [syntax.StoreAttr(base, syntax.StringConst(target.attr), value)]
        elif isinstance(target, ast.Subscript):
            self._require(isinstance(target.slice, ast.Index), node,
                          'slice assignment is not supported')
            base = self.flatten_node(target.value)
            index = self.flatten_node(target.slice.value)
            return [syntax.StoreSubscript(base, index, value)]
        self._fail(node, 'unsupported assignment target')

    def visit_AugAssign(self, node):
        """Translate `name op= v` and `obj.attr op= v` as a plain rebinding.

        NOTE(review): v is flattened before the target is read, whereas
        Python evaluates the target first. This differs only when both have
        side effects.
        """
        op = self.visit(node.op)
        value = self.flatten_node(node.value)
        if isinstance(node.target, ast.Name):
            target = node.target.id
            # XXX HACK: doesn't modify in place
            binop = syntax.BinaryOp(op, syntax.Load(target), value)
            return [syntax.Store(target, binop)]
        elif isinstance(node.target, ast.Attribute):
            l = self.flatten_node(node.target.value)
            attr_name = syntax.StringConst(node.target.attr)
            attr = syntax.Attribute(l, attr_name)
            binop = syntax.BinaryOp(op, attr, value)
            return [syntax.StoreAttr(l, attr_name, binop)]
        self._fail(node, 'unsupported augmented assignment target')

    def visit_Delete(self, node):
        """Translate `del v[i]`, the only supported form of `del`."""
        self._require(len(node.targets) == 1, node,
                      'deleting several targets at once is not supported')
        target = node.targets[0]
        self._require(isinstance(target, ast.Subscript) and
                      isinstance(target.slice, ast.Index),
                      node, 'only `del x[i]` is supported')
        name = self.flatten_node(target.value)
        value = self.flatten_node(target.slice.value)
        return [syntax.DeleteSubscript(name, value)]

    def visit_Global(self, node):
        """Translate `global a, b` to one syntax.Global per name."""
        return [syntax.Global(name) for name in node.names]

    def visit_If(self, node):
        """Translate an if statement; else_block is None when there is no else."""
        expr = self.flatten_node(node.test)
        stmts = self.flatten_list(node.body)
        if node.orelse:
            else_block = self.flatten_list(node.orelse)
        else:
            else_block = None
        return syntax.If(expr, stmts, else_block)

    def visit_Break(self, node):
        return syntax.Break()

    def visit_Continue(self, node):
        return syntax.Continue()

    def _loop_target(self, target, node):
        """Return a for/comprehension target in the form syntax expects.

        A plain name yields its string id, and a tuple yields its list of
        element AST nodes. Anything else is rejected.
        """
        if isinstance(target, ast.Name):
            return target.id
        elif isinstance(target, ast.Tuple):
            return target.elts
        self._fail(node, 'unsupported loop target')

    def visit_For(self, node):
        """Translate a for loop without an else clause."""
        self._require(not node.orelse, node, 'for/else is not supported')
        target = self._loop_target(node.target, node)
        iter = self.flatten_node(node.iter)
        stmts = self.flatten_list(node.body)
        return syntax.For(target, iter, stmts)

    def visit_While(self, node):
        """Translate a while loop without an else clause.

        The test's setup statements are kept in their own list, presumably so
        syntax.While can re-run them before each evaluation of the condition.
        """
        self._require(not node.orelse, node, 'while/else is not supported')
        test_stmts = []
        test = self.flatten_node(node.test, statements=test_stmts)
        stmts = self.flatten_list(node.body)
        return syntax.While(test_stmts, test, stmts)

    def visit_ListComp(self, node):
        """Translate `[elt for target in iter]`: one generator, no filters."""
        self._require(len(node.generators) == 1, node,
                      'nested comprehensions are not supported')
        gen = node.generators[0]
        self._require(not gen.ifs, node, 'comprehension filters are not supported')
        target = self._loop_target(gen.target, node)
        iter = self.flatten_node(gen.iter)
        # The element's setup statements run once per iteration, so keep them apart.
        stmts = []
        expr = self.flatten_node(node.elt, statements=stmts)
        comp = syntax.ListComp(target, iter, stmts, expr)
        return comp.flatten(self)

    # XXX HACK if we ever want these to differ...
    def visit_GeneratorExp(self, node):
        """Translate a generator expression eagerly, exactly like a list comprehension."""
        return self.visit_ListComp(node)

    def visit_Return(self, node):
        """Translate `return` / `return expr`; a bare return carries None."""
        if node.value is not None:
            return syntax.Return(self.flatten_node(node.value))
        return syntax.Return(None)

    def visit_Assert(self, node):
        """Translate `assert test`, keeping its source line number.

        NOTE(review): an assert message (`assert x, msg`) is ignored.
        """
        expr = self.flatten_node(node.test)
        return syntax.Assert(expr, node.lineno)

    def visit_arguments(self, node):
        """Translate a parameter list of plain positional parameters.

        Defaults are flattened like any other expression. Their setup
        statements therefore run before the function definition, which is
        when Python evaluates defaults.
        """
        self._require(not node.vararg, node, '*args parameters are not supported')
        self._require(not node.kwarg, node, '**kwargs parameters are not supported')
        self._require(not node.kwonlyargs, node,
                      'keyword-only parameters are not supported')
        args = [a.arg for a in node.args]
        defaults = [self.flatten_node(d) for d in node.defaults]
        return syntax.Arguments(args, defaults).flatten(self)

    def visit_FunctionDef(self, node):
        """Translate an undecorated function definition.

        visit_ClassDef sets `exp_name` on methods; it is None for plain functions.
        """
        self._require(not node.decorator_list, node,
                      'function decorators are not supported')
        args = self.visit(node.args)
        body = self.flatten_list(node.body)
        exp_name = getattr(node, 'exp_name', None)
        return syntax.FunctionDef(node.name, args, body, exp_name).flatten(self)

    def visit_ClassDef(self, node):
        """Translate a plain class: no bases, keywords or decorators.

        Each method is tagged with exp_name '_<Class>_<method>' before the
        body is translated.
        """
        self._require(not node.bases, node, 'base classes are not supported')
        self._require(not node.keywords, node, 'class keywords are not supported')
        # These two fields exist only on Python <= 3.4 ASTs.
        self._require(not getattr(node, 'starargs', None) and
                      not getattr(node, 'kwargs', None), node,
                      '*args/**kwargs in class definitions are not supported')
        self._require(not node.decorator_list, node,
                      'class decorators are not supported')

        for fn in node.body:
            if isinstance(fn, ast.FunctionDef):
                fn.exp_name = '_%s_%s' % (node.name, fn.name)

        body = self.flatten_list(node.body)
        return syntax.ClassDef(node.name, body).flatten(self)

    def visit_Expr(self, node):
        """An expression statement translates to its expression."""
        return self.visit(node.value)

    def visit_Module(self, node):
        """Translate the module's top-level statements into a flat list."""
        return self.flatten_list(node.body)

    # Nodes that produce no code.
    def visit_Pass(self, node): pass
    def visit_Load(self, node): pass
    def visit_Store(self, node): pass
# Command-line driver: translate INPUT.py into a C++ translation unit.
# The output contains the runtime include, the exported constants, every
# generated function, and a main() that runs the module's top-level
# statements in order.
if len(sys.argv) != 3:
    sys.exit('usage: %s input.py output.cpp' % sys.argv[0])

# Read raw bytes so ast.parse honours the source file's own encoding
# declaration (PEP 263; UTF-8 by default) rather than the locale's encoding.
# Passing the file name makes SyntaxErrors point at the real file.
with open(sys.argv[1], 'rb') as f:
    node = ast.parse(f.read(), sys.argv[1])

x = Transformer()
node = x.visit(node)

# Write UTF-8 explicitly so non-ASCII string constants cannot fail under an
# ASCII/C locale.
with open(sys.argv[2], 'w', encoding='utf-8') as f:
    f.write('#include "backend.cpp"\n')
    syntax.export_consts(f)

    for func in x.functions:
        f.write('%s\n' % func)

    f.write('int main(int argc, char **argv) {\n')
    f.write(' context *ctx = new(allocator) context();\n')
    f.write(' init_context(ctx, argc, argv);\n')

    for stmt in node:
        f.write(' %s;\n' % stmt)

    f.write('}\n')