Pin Chrome's shortcut to the Win10 Start menu on install and OS upgrade.
[chromium-blink-merge.git] / third_party / cython / src / Cython / Compiler / Tests / TestParseTreeTransforms.py
blob b7cd0ee5464c5dfdda317bc7aba753b385476aac
1 import os
3 from Cython.Compiler import CmdLine
4 from Cython.TestUtils import TransformTest
5 from Cython.Compiler.ParseTreeTransforms import *
6 from Cython.Compiler.Nodes import *
7 from Cython.Compiler import Main, Symtab
class TestNormalizeTree(TransformTest):
    """Tree-shape tests for the NormalizeTree transform.

    Each test compares ``self.treetypes(...)`` (one ``name: NodeClass`` line
    per node, children indented two spaces deeper than their parent) against
    an expected dump.
    """

    def test_parserbehaviour_is_what_we_coded_for(self):
        # Baseline: without NormalizeTree a one-line ``if`` body is a bare
        # ExprStatNode rather than a StatListNode.
        t = self.fragment(u"if x: y").root
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: ExprStatNode
        expr: NameNode
""", self.treetypes(t))

    def test_wrap_singlestat(self):
        # NormalizeTree wraps a single-statement body in a StatListNode.
        t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_wrap_multistat(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            if z:
                x
                y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
        stats[1]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_statinexpr(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            a, b = x, y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: SingleAssignmentNode
    lhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
    rhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
""", self.treetypes(t))

    def test_wrap_offagain(self):
        # Statements after the ``if`` body must return to the outer list.
        t = self.run_pipeline([NormalizeTree(None)], u"""
            x
            y
            if z:
                x
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: ExprStatNode
    expr: NameNode
  stats[1]: ExprStatNode
    expr: NameNode
  stats[2]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_pass_eliminated(self):
        t = self.run_pipeline([NormalizeTree(None)], u"pass")
        # assert_ is a deprecated alias removed in Python 3.12; assertEqual
        # also reports the actual length on failure.
        self.assertEqual(len(t.stats), 0)
93 class TestWithTransform(object): # (TransformTest): # Disabled!
95 def test_simplified(self):
96 t = self.run_pipeline([WithTransform(None)], u"""
97 with x:
98 y = z ** 3
99 """)
101 self.assertCode(u"""
103 $0_0 = x
104 $0_2 = $0_0.__exit__
105 $0_0.__enter__()
106 $0_1 = True
107 try:
108 try:
109 $1_0 = None
110 y = z ** 3
111 except:
112 $0_1 = False
113 if (not $0_2($1_0)):
114 raise
115 finally:
116 if $0_1:
117 $0_2(None, None, None)
119 """, t)
121 def test_basic(self):
122 t = self.run_pipeline([WithTransform(None)], u"""
123 with x as y:
124 y = z ** 3
125 """)
126 self.assertCode(u"""
128 $0_0 = x
129 $0_2 = $0_0.__exit__
130 $0_3 = $0_0.__enter__()
131 $0_1 = True
132 try:
133 try:
134 $1_0 = None
135 y = $0_3
136 y = z ** 3
137 except:
138 $0_1 = False
139 if (not $0_2($1_0)):
140 raise
141 finally:
142 if $0_1:
143 $0_2(None, None, None)
145 """, t)
class TestInterpretCompilerDirectives(TransformTest):
    """
    This class tests the parallel directives AST-rewriting and importing.
    """

    # Test the parallel directives (c)importing

    import_code = u"""
        cimport cython.parallel
        cimport cython.parallel as par
        from cython cimport parallel as par2
        from cython cimport parallel

        from cython.parallel cimport threadid as tid
        from cython.parallel cimport threadavailable as tavail
        from cython.parallel cimport prange
    """

    # Every name bound by import_code, mapped to the directive it refers to.
    expected_directives_dict = {
        u'cython.parallel': u'cython.parallel',
        u'par': u'cython.parallel',
        u'par2': u'cython.parallel',
        u'parallel': u'cython.parallel',

        u"tid": u"cython.parallel.threadid",
        u"tavail": u"cython.parallel.threadavailable",
        u"prange": u"cython.parallel.prange",
    }

    def setUp(self):
        super(TestInterpretCompilerDirectives, self).setUp()

        compilation_options = Main.CompilationOptions(Main.default_options)
        ctx = compilation_options.create_context()

        transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
        transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
        self.pipeline = [transform]

        # Remember the global flag so tearDown can put it back.
        self.debug_exception_on_error = DebugFlags.debug_exception_on_error

    def tearDown(self):
        DebugFlags.debug_exception_on_error = self.debug_exception_on_error
        # setUp chains to the base class, so tearDown must too; otherwise
        # whatever the base setUp saved or changed is never restored.
        super(TestInterpretCompilerDirectives, self).tearDown()

    def _parallel_directives_for(self, code):
        # Run the directive pipeline over ``code`` and return the
        # parallel_directives mapping recorded by the transform.
        self.run_pipeline(self.pipeline, code)
        return self.pipeline[0].parallel_directives

    def test_parallel_directives_cimports(self):
        parallel_directives = self._parallel_directives_for(self.import_code)
        self.assertEqual(parallel_directives, self.expected_directives_dict)

    def test_parallel_directives_imports(self):
        # Plain ``import`` must be recognised exactly like ``cimport``.
        parallel_directives = self._parallel_directives_for(
            self.import_code.replace(u'cimport', u'import'))
        self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust.
# Import sys explicitly instead of relying on it leaking in through the
# wildcard imports at the top of the module.
import sys

# Kill switch for the debugger tests (previously a hidden ``and False``).
_ENABLE_DEBUGGER_TESTS = False

if _ENABLE_DEBUGGER_TESTS and sys.version_info[:2] >= (2, 5):
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # skip test, don't let it inherit unittest.TestCase
    DebuggerTestCase = object
class TestDebugTransform(DebuggerTestCase):
    """Checks the XML debug information written to ``self.debug_dest``.

    ``self.debug_dest`` is not set in this class; presumably DebuggerTestCase
    provides it (TODO confirm).  When DebuggerTestCase is the plain ``object``
    fallback, these methods are never collected.
    """

    def elem_hasattrs(self, elem, attrs):
        """Return True if ``elem`` has every attribute named in ``attrs``."""
        return all(attr in elem.attrib for attr in attrs)

    def _find_children(self, node, path):
        # Return the children of node.find(path).  Fail with a readable
        # assertion, rather than list(None)'s TypeError, if it is missing.
        found = node.find(path)
        assert found is not None, 'no element at %r' % (path,)
        return list(found)

    def _dump_debug_file(self):
        # Print the debug file to help diagnose a failure.  Skip this when the
        # file does not exist, so an IOError cannot mask the original error.
        if not os.path.exists(self.debug_dest):
            return
        f = open(self.debug_dest)
        try:
            print(f.read())
        finally:
            f.close()

    def _check_debug_info(self):
        # All assertions on the debug file; raises on the first failure.
        assert os.path.exists(self.debug_dest)

        tree = DebugWriter.etree.parse(self.debug_dest)
        # ElementTree's XPath support is primitive, so use plain child paths.
        # A leading '/' is resolved relative to the root anyway (newer
        # ElementTree/lxml versions warn about it), so write './' explicitly.
        global_elems = self._find_children(tree, './Module/Globals')
        # Use a plain assert so an empty list fails loudly.
        assert global_elems
        xml_globals = dict((e.attrib['name'], e.attrib['type'])
                           for e in global_elems)
        # Equal sizes mean no global name was listed twice.
        self.assertEqual(len(global_elems), len(xml_globals))

        func_elems = self._find_children(tree, './Module/Functions')
        assert func_elems
        xml_funcs = dict((e.attrib['qualified_name'], e) for e in func_elems)
        self.assertEqual(len(func_elems), len(xml_funcs))

        # test globals
        self.assertEqual('CObject', xml_globals.get('c_var'))
        self.assertEqual('PythonObject', xml_globals.get('python_var'))

        # test functions
        funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
                     'codefile.closure', 'codefile.inner')
        required_xml_attrs = 'name', 'cname', 'qualified_name'
        assert all(f in xml_funcs for f in funcnames)
        # Only the first three functions are inspected further.  Unpacking
        # all five names into three targets always raised ValueError.
        spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames[:3]]

        self.assertEqual(spam.attrib['name'], 'spam')
        self.assertNotEqual('spam', spam.attrib['cname'])
        assert self.elem_hasattrs(spam, required_xml_attrs)

        # test locals of functions
        spam_locals = self._find_children(spam, 'Locals')
        assert spam_locals
        spam_locals.sort(key=lambda e: e.attrib['name'])
        names = [e.attrib['name'] for e in spam_locals]
        self.assertEqual(list('abcd'), names)
        assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)

        # test arguments of functions
        spam_arguments = self._find_children(spam, 'Arguments')
        assert spam_arguments
        self.assertEqual(1, len(spam_arguments))

        # test step-into functions
        spam_stepinto = [x.attrib['name']
                         for x in self._find_children(spam, 'StepIntoFunctions')]
        assert spam_stepinto
        self.assertEqual(2, len(spam_stepinto))
        assert 'puts' in spam_stepinto
        assert 'some_c_function' in spam_stepinto

    def test_debug_info(self):
        try:
            self._check_debug_info()
        except Exception:
            # Show the generated file, then propagate the original failure.
            self._dump_debug_file()
            raise
if __name__ == "__main__":
    # Only when executed as a script; importing this module runs nothing.
    from unittest import main as run_tests
    run_tests()