3 from Cython
.Compiler
import CmdLine
4 from Cython
.TestUtils
import TransformTest
5 from Cython
.Compiler
.ParseTreeTransforms
import *
6 from Cython
.Compiler
.Nodes
import *
7 from Cython
.Compiler
import Main
, Symtab
10 class TestNormalizeTree(TransformTest
):
11 def test_parserbehaviour_is_what_we_coded_for(self
):
12 t
= self
.fragment(u
"if x: y").root
16 if_clauses[0]: IfClauseNode
20 """, self
.treetypes(t
))
22 def test_wrap_singlestat(self
):
23 t
= self
.run_pipeline([NormalizeTree(None)], u
"if x: y")
27 if_clauses[0]: IfClauseNode
30 stats[0]: ExprStatNode
32 """, self
.treetypes(t
))
34 def test_wrap_multistat(self
):
35 t
= self
.run_pipeline([NormalizeTree(None)], u
"""
43 if_clauses[0]: IfClauseNode
46 stats[0]: ExprStatNode
48 stats[1]: ExprStatNode
50 """, self
.treetypes(t
))
52 def test_statinexpr(self
):
53 t
= self
.run_pipeline([NormalizeTree(None)], u
"""
58 stats[0]: SingleAssignmentNode
65 """, self
.treetypes(t
))
67 def test_wrap_offagain(self
):
68 t
= self
.run_pipeline([NormalizeTree(None)], u
"""
76 stats[0]: ExprStatNode
78 stats[1]: ExprStatNode
81 if_clauses[0]: IfClauseNode
84 stats[0]: ExprStatNode
86 """, self
.treetypes(t
))
def test_pass_eliminated(self):
    """Check that a lone ``pass`` leaves no statements after NormalizeTree.

    The source ``u"pass"`` is run through a pipeline containing only
    ``NormalizeTree(None)``; the resulting tree's ``stats`` list must be
    empty.
    """
    t = self.run_pipeline([NormalizeTree(None)], u"pass")
    # assertEqual instead of the deprecated ``assert_`` alias (removed from
    # unittest in Python 3.12); it also reports the actual length on failure.
    self.assertEqual(0, len(t.stats))
93 class TestWithTransform(object): # (TransformTest): # Disabled!
95 def test_simplified(self
):
96 t
= self
.run_pipeline([WithTransform(None)], u
"""
117 $0_2(None, None, None)
121 def test_basic(self
):
122 t
= self
.run_pipeline([WithTransform(None)], u
"""
130 $0_3 = $0_0.__enter__()
143 $0_2(None, None, None)
148 class TestInterpretCompilerDirectives(TransformTest
):
150 This class tests the parallel directives AST-rewriting and importing.
153 # Test the parallel directives (c)importing
156 cimport cython.parallel
157 cimport cython.parallel as par
158 from cython cimport parallel as par2
159 from cython cimport parallel
161 from cython.parallel cimport threadid as tid
162 from cython.parallel cimport threadavailable as tavail
163 from cython.parallel cimport prange
166 expected_directives_dict
= {
167 u
'cython.parallel': u
'cython.parallel',
168 u
'par': u
'cython.parallel',
169 u
'par2': u
'cython.parallel',
170 u
'parallel': u
'cython.parallel',
172 u
"tid": u
"cython.parallel.threadid",
173 u
"tavail": u
"cython.parallel.threadavailable",
174 u
"prange": u
"cython.parallel.prange",
179 super(TestInterpretCompilerDirectives
, self
).setUp()
181 compilation_options
= Main
.CompilationOptions(Main
.default_options
)
182 ctx
= compilation_options
.create_context()
184 transform
= InterpretCompilerDirectives(ctx
, ctx
.compiler_directives
)
185 transform
.module_scope
= Symtab
.ModuleScope('__main__', None, ctx
)
186 self
.pipeline
= [transform
]
188 self
.debug_exception_on_error
= DebugFlags
.debug_exception_on_error
191 DebugFlags
.debug_exception_on_error
= self
.debug_exception_on_error
193 def test_parallel_directives_cimports(self
):
194 self
.run_pipeline(self
.pipeline
, self
.import_code
)
195 parallel_directives
= self
.pipeline
[0].parallel_directives
196 self
.assertEqual(parallel_directives
, self
.expected_directives_dict
)
198 def test_parallel_directives_imports(self
):
199 self
.run_pipeline(self
.pipeline
,
200 self
.import_code
.replace(u
'cimport', u
'import'))
201 parallel_directives
= self
.pipeline
[0].parallel_directives
202 self
.assertEqual(parallel_directives
, self
.expected_directives_dict
)
205 # TODO: Re-enable once they're more robust.
206 if sys
.version_info
[:2] >= (2, 5) and False:
207 from Cython
.Debugger
import DebugWriter
208 from Cython
.Debugger
.Tests
.TestLibCython
import DebuggerTestCase
210 # skip test, don't let it inherit unittest.TestCase
211 DebuggerTestCase
= object
213 class TestDebugTransform(DebuggerTestCase
):
def elem_hasattrs(self, elem, attrs):
    """Return True if every name in ``attrs`` is a key of ``elem.attrib``.

    ``elem`` is any object exposing an ``attrib`` container that supports
    ``in``; ``attrs`` is an iterable of attribute names. An empty
    ``attrs`` yields True.
    """
    # A generator lets all() stop at the first missing attribute instead
    # of building a throwaway list first.
    return all(attr in elem.attrib for attr in attrs)
219 def test_debug_info(self
):
221 assert os
.path
.exists(self
.debug_dest
)
223 t
= DebugWriter
.etree
.parse(self
.debug_dest
)
224 # the xpath of the standard ElementTree is primitive, don't use
226 L
= list(t
.find('/Module/Globals'))
227 # use the plain assert statement here rather than assertTrue
230 [(e
.attrib
['name'], e
.attrib
['type']) for e
in L
])
231 self
.assertEqual(len(L
), len(xml_globals
))
233 L
= list(t
.find('/Module/Functions'))
235 xml_funcs
= dict([(e
.attrib
['qualified_name'], e
) for e
in L
])
236 self
.assertEqual(len(L
), len(xml_funcs
))
239 self
.assertEqual('CObject', xml_globals
.get('c_var'))
240 self
.assertEqual('PythonObject', xml_globals
.get('python_var'))
243 funcnames
= ('codefile.spam', 'codefile.ham', 'codefile.eggs',
244 'codefile.closure', 'codefile.inner')
245 required_xml_attrs
= 'name', 'cname', 'qualified_name'
246 assert all([f
in xml_funcs
for f
in funcnames
])
247 spam
, ham
, eggs
= [xml_funcs
[funcname
] for funcname
in funcnames
]
249 self
.assertEqual(spam
.attrib
['name'], 'spam')
250 self
.assertNotEqual('spam', spam
.attrib
['cname'])
251 assert self
.elem_hasattrs(spam
, required_xml_attrs
)
253 # test locals of functions
254 spam_locals
= list(spam
.find('Locals'))
256 spam_locals
.sort(key
=lambda e
: e
.attrib
['name'])
257 names
= [e
.attrib
['name'] for e
in spam_locals
]
258 self
.assertEqual(list('abcd'), names
)
259 assert self
.elem_hasattrs(spam_locals
[0], required_xml_attrs
)
261 # test arguments of functions
262 spam_arguments
= list(spam
.find('Arguments'))
263 assert spam_arguments
264 self
.assertEqual(1, len(list(spam_arguments
)))
266 # test step-into functions
267 step_into
= spam
.find('StepIntoFunctions')
268 spam_stepinto
= [x
.attrib
['name'] for x
in step_into
]
270 self
.assertEqual(2, len(spam_stepinto
))
271 assert 'puts' in spam_stepinto
272 assert 'some_c_function' in spam_stepinto
274 f
= open(self
.debug_dest
)
283 if __name__
== "__main__":