1 import Cython
.Compiler
.Errors
as Errors
2 from Cython
.CodeWriter
import CodeWriter
3 from Cython
.Compiler
.TreeFragment
import TreeFragment
, strip_common_indent
4 from Cython
.Compiler
.Visitor
import TreeVisitor
, VisitorTransform
5 from Cython
.Compiler
import TreePath
12 class NodeTypeWriter(TreeVisitor
):
14 super(NodeTypeWriter
, self
).__init
__()
18 def visit_Node(self
, node
):
19 if not self
.access_path
:
22 tip
= self
.access_path
[-1]
23 if tip
[2] is not None:
24 name
= u
"%s[%d]" % tip
[1:3]
28 self
.result
.append(u
" " * self
._indents
+
29 u
"%s: %s" % (name
, node
.__class
__.__name
__))
31 self
.visitchildren(node
)
36 """Returns a string representing the tree by class names.
37 There's a leading and trailing whitespace so that it can be
38 compared by simple string comparison while still making test
42 return u
"\n".join([u
""] + w
.result
+ [u
""])
45 class CythonTest(unittest
.TestCase
):
48 self
.listing_file
= Errors
.listing_file
49 self
.echo_file
= Errors
.echo_file
50 Errors
.listing_file
= Errors
.echo_file
= None
53 Errors
.listing_file
= self
.listing_file
54 Errors
.echo_file
= self
.echo_file
56 def assertLines(self
, expected
, result
):
57 "Checks that the given strings or lists of strings are equal line by line"
58 if not isinstance(expected
, list): expected
= expected
.split(u
"\n")
59 if not isinstance(result
, list): result
= result
.split(u
"\n")
60 for idx
, (expected_line
, result_line
) in enumerate(zip(expected
, result
)):
61 self
.assertEqual(expected_line
, result_line
, "Line %d:\nExp: %s\nGot: %s" % (idx
, expected_line
, result_line
))
62 self
.assertEqual(len(expected
), len(result
),
63 "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected
), u
"\n".join(result
)))
65 def codeToLines(self
, tree
):
68 return writer
.result
.lines
70 def codeToString(self
, tree
):
71 return "\n".join(self
.codeToLines(tree
))
73 def assertCode(self
, expected
, result_tree
):
74 result_lines
= self
.codeToLines(result_tree
)
76 expected_lines
= strip_common_indent(expected
.split("\n"))
78 for idx
, (line
, expected_line
) in enumerate(zip(result_lines
, expected_lines
)):
79 self
.assertEqual(expected_line
, line
, "Line %d:\nGot: %s\nExp: %s" % (idx
, line
, expected_line
))
80 self
.assertEqual(len(result_lines
), len(expected_lines
),
81 "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines
), expected
))
83 def assertNodeExists(self
, path
, result_tree
):
84 self
.assertNotEqual(TreePath
.find_first(result_tree
, path
), None,
85 "Path '%s' not found in result tree" % path
)
87 def fragment(self
, code
, pxds
={}, pipeline
=[]):
88 "Simply create a tree fragment using the name of the test-case in parse errors."
90 if name
.startswith("__main__."): name
= name
[len("__main__."):]
91 name
= name
.replace(".", "_")
92 return TreeFragment(code
, name
, pxds
, pipeline
=pipeline
)
94 def treetypes(self
, root
):
95 return treetypes(root
)
97 def should_fail(self
, func
, exc_type
=Exception):
98 """Calls "func" and fails if it doesn't raise the right exception
99 (any exception by default). Also returns the exception in question.
103 self
.fail("Expected an exception of type %r" % exc_type
)
105 self
.assert_(isinstance(e
, exc_type
))
108 def should_not_fail(self
, func
):
109 """Calls func and succeeds if and only if no exception is raised
110 (i.e. converts exception raising into a failed testcase). Returns
111 the return value of func."""
115 self
.fail(str(sys
.exc_info()[1]))
118 class TransformTest(CythonTest
):
120 Utility base class for transform unit tests. It is based around constructing
121 test trees (either explicitly or by parsing a Cython code string); running
122 the transform, serialize it using a customized Cython serializer (with
123 special markup for nodes that cannot be represented in Cython),
124 and do a string-comparison line-by-line of the result.
126 To create a test case:
127 - Call run_pipeline. The pipeline should at least contain the transform you
128 are testing; pyx should be either a string (passed to the parser to
129 create a post-parse tree) or a node representing input to pipeline.
130 The result will be a transformed result.
132 - Check that the tree is correct. If wanted, assertCode can be used, which
133 takes a code string as expected, and a ModuleNode in result_tree
134 (it serializes the ModuleNode to a string and compares line-by-line).
136 All code strings are first stripped for whitespace lines and then common
139 Plans: One could have a pxd dictionary parameter to run_pipeline.
142 def run_pipeline(self
, pipeline
, pyx
, pxds
={}):
143 tree
= self
.fragment(pyx
, pxds
).root
150 class TreeAssertVisitor(VisitorTransform
):
151 # actually, a TreeVisitor would be enough, but this needs to run
152 # as part of the compiler pipeline
154 def visit_CompilerDirectivesNode(self
, node
):
155 directives
= node
.directives
156 if 'test_assert_path_exists' in directives
:
157 for path
in directives
['test_assert_path_exists']:
158 if TreePath
.find_first(node
, path
) is None:
161 "Expected path '%s' not found in result tree" % path
)
162 if 'test_fail_if_path_exists' in directives
:
163 for path
in directives
['test_fail_if_path_exists']:
164 if TreePath
.find_first(node
, path
) is not None:
167 "Unexpected path '%s' found in result tree" % path
)
168 self
.visitchildren(node
)
171 visit_Node
= VisitorTransform
.recurse_to_children
174 def unpack_source_tree(tree_file
, dir=None):
176 dir = tempfile
.mkdtemp()
181 lines
= f
.readlines()
187 if line
[:5] == '#####':
188 filename
= line
.strip().strip('#').strip().replace('/', os
.path
.sep
)
189 path
= os
.path
.join(dir, filename
)
190 if not os
.path
.exists(os
.path
.dirname(path
)):
191 os
.makedirs(os
.path
.dirname(path
))
192 if cur_file
is not None:
193 f
, cur_file
= cur_file
, None
195 cur_file
= open(path
, 'w')
196 elif cur_file
is not None:
198 elif line
.strip() and not line
.lstrip().startswith('#'):
199 if line
.strip() not in ('"""', "'''"):
202 if cur_file
is not None:
204 return dir, ''.join(header
)