Remove obsolete entries from .gitignore.
[chromium-blink-merge.git] / third_party / cython / src / Cython / TestUtils.py
blobbf5886aa72f710182544febbe5a6be8d14baef20
1 import Cython.Compiler.Errors as Errors
2 from Cython.CodeWriter import CodeWriter
3 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
4 from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
5 from Cython.Compiler import TreePath
7 import unittest
8 import os, sys
9 import tempfile
class NodeTypeWriter(TreeVisitor):
    """Tree visitor that records one text line per visited node.

    Each line has the form ``"<name>: <ClassName>"`` and is indented by one
    space per nesting level; the lines accumulate in ``self.result``.
    """

    def __init__(self):
        super(NodeTypeWriter, self).__init__()
        self._indents = 0   # current nesting depth (one space per level)
        self.result = []    # collected output lines, in visit order

    def visit_Node(self, node):
        path = self.access_path
        if path:
            tip = path[-1]
            # tip[1] is used as the attribute name; tip[2], when not None,
            # is an index that gets appended as "[n]".
            label = tip[1] if tip[2] is None else u"%s[%d]" % tip[1:3]
        else:
            label = u"(root)"
        line = u"%s: %s" % (label, node.__class__.__name__)
        self.result.append(u" " * self._indents + line)
        self._indents += 1
        self.visitchildren(node)
        self._indents -= 1
def treetypes(root):
    """Render the tree below *root* as a listing of node class names.

    The returned string starts and ends with a newline so that it can be
    compared with plain string equality against a triple-quoted literal,
    which keeps test cases readable.
    """
    writer = NodeTypeWriter()
    writer.visit(root)
    lines = [u""]
    lines.extend(writer.result)
    lines.append(u"")
    return u"\n".join(lines)
class CythonTest(unittest.TestCase):
    """Base class for Cython compiler unit tests.

    Clears ``Errors.listing_file`` / ``Errors.echo_file`` for the duration of
    each test (restoring them afterwards) and provides helpers for building
    tree fragments and comparing generated code line by line.
    """

    def setUp(self):
        # Save the current values so tearDown() can restore them; presumably
        # disabling them keeps compiler error output out of test runs.
        self.listing_file = Errors.listing_file
        self.echo_file = Errors.echo_file
        Errors.listing_file = Errors.echo_file = None

    def tearDown(self):
        Errors.listing_file = self.listing_file
        Errors.echo_file = self.echo_file

    def assertLines(self, expected, result):
        """Checks that the given strings or lists of strings are equal line by line.

        String arguments are split on newlines first. Fails on the first
        differing line, then on a differing number of lines.
        """
        if not isinstance(expected, list):
            expected = expected.split(u"\n")
        if not isinstance(result, list):
            result = result.split(u"\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        # `result` is what we got; report it first, matching assertCode().
        self.assertEqual(len(expected), len(result),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s"
                         % (u"\n".join(result), "\n".join(expected)))

    def codeToLines(self, tree):
        """Serialize *tree* with CodeWriter and return the list of output lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialize *tree* with CodeWriter and return it as one newline-joined string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Assert that serializing *result_tree* yields the code string *expected*.

        *expected* is split into lines and passed through strip_common_indent
        before the line-by-line comparison.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line,
                             "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Assert that the TreePath expression *path* matches a node in *result_tree*."""
        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
                            "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        """Simply create a tree fragment using the name of the test-case in parse errors.

        *pxds* and *pipeline* default to fresh empty containers on every call
        (rather than shared mutable defaults) and are forwarded to TreeFragment.
        """
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Return the module-level treetypes() rendering of *root*."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Calls "func" and fails if it doesn't raise the right exception
        (any exception by default). Also returns the exception in question.

        *exc_type* may be a single exception class or a tuple of classes.
        """
        try:
            func()
        except exc_type:
            # sys.exc_info() works with both Python 2 and Python 3 syntax.
            return sys.exc_info()[1]
        # Fail outside the try block: self.fail() raises AssertionError, which
        # the default exc_type=Exception would otherwise catch and return.
        # Wrap exc_type in a 1-tuple so a tuple of classes formats correctly.
        self.fail("Expected an exception of type %r" % (exc_type,))

    def should_not_fail(self, func):
        """Calls func and succeeds if and only if no exception is raised
        (i.e. converts exception raising into a failed testcase). Returns
        the return value of func."""
        try:
            return func()
        except Exception:
            # Catch Exception (not everything) so KeyboardInterrupt/SystemExit
            # still propagate instead of being reported as test failures.
            self.fail(str(sys.exc_info()[1]))
class TransformTest(CythonTest):
    """
    Utility base class for transform unit tests. It is based around constructing
    test trees (either explicitly or by parsing a Cython code string), running
    the transform, serializing the result using a customized Cython serializer
    (with special markup for nodes that cannot be represented in Cython),
    and doing a line-by-line string comparison of the result.

    To create a test case:
     - Call run_pipeline. The pipeline should at least contain the transform you
       are testing; pyx should be either a string (passed to the parser to
       create a post-parse tree) or a node representing input to the pipeline.
       The return value is the transformed tree.

     - Check that the tree is correct. If wanted, assertCode can be used, which
       takes a code string as expected, and a ModuleNode in result_tree
       (it serializes the ModuleNode to a string and compares line-by-line).

    All code strings are first stripped of whitespace-only lines and then of
    common indentation.

    run_pipeline also accepts an optional pxds dictionary, which is forwarded
    to fragment().
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Build a tree from *pyx* and pass it through each stage of *pipeline*.

        :param pipeline: sequence of callables; each receives the current tree
            and returns the tree handed to the next stage.
        :param pyx: input forwarded to self.fragment(); its ``.root`` is the
            starting tree.
        :param pxds: optional dict forwarded to self.fragment(); a fresh empty
            dict is used when omitted, instead of a shared mutable default.
        :returns: the tree returned by the last stage (the starting tree if
            *pipeline* is empty).
        """
        if pxds is None:
            pxds = {}
        tree = self.fragment(pyx, pxds).root
        for transform in pipeline:
            tree = transform(tree)
        return tree
class TreeAssertVisitor(VisitorTransform):
    # actually, a TreeVisitor would be enough, but this needs to run
    # as part of the compiler pipeline

    def visit_CompilerDirectivesNode(self, node):
        """Check the test_* path directives of *node* against its subtree.

        For 'test_assert_path_exists', every listed TreePath expression must
        match somewhere under *node*; for 'test_fail_if_path_exists', none may.
        Each violation is reported through Errors.error at node.pos. The
        'exists' checks run before the 'fail if exists' checks, and children
        are visited afterwards.
        """
        directives = node.directives
        # (directive name, whether the path must be present, error template)
        checks = (
            ('test_assert_path_exists', True,
             "Expected path '%s' not found in result tree"),
            ('test_fail_if_path_exists', False,
             "Unexpected path '%s' found in result tree"),
        )
        for directive_name, must_exist, template in checks:
            if directive_name not in directives:
                continue
            for path in directives[directive_name]:
                present = TreePath.find_first(node, path) is not None
                if present != must_exist:
                    Errors.error(node.pos, template % path)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children
def unpack_source_tree(tree_file, dir=None):
    """Explode a combined "source tree" file into separate files under *dir*.

    A line starting with ``#####`` names a new output file: surrounding
    whitespace and ``#`` characters are stripped from it and ``/`` is mapped
    to ``os.path.sep`` to form a path relative to *dir* (missing parent
    directories are created). Every following line is copied verbatim into
    that file until the next such marker line.

    Lines before the first marker make up the header. Blank lines, lines
    whose first non-blank character is ``#``, and lines consisting only of a
    triple quote (double or single) are left out of it.

    :param tree_file: path of the combined source-tree file to read.
    :param dir: target directory; a new tempfile.mkdtemp() directory is
        created when it is None.
    :returns: ``(dir, header)`` where *header* is the concatenation of the
        retained header lines, line endings included.
    """
    if dir is None:
        dir = tempfile.mkdtemp()

    tree = open(tree_file)
    try:
        source_lines = tree.readlines()
    finally:
        tree.close()

    header_lines = []
    target = None  # file object currently receiving lines, if any
    try:
        for text in source_lines:
            if text.startswith('#####'):
                relative = text.strip().strip('#').strip().replace('/', os.path.sep)
                full_path = os.path.join(dir, relative)
                parent = os.path.dirname(full_path)
                if not os.path.exists(parent):
                    os.makedirs(parent)
                if target is not None:
                    # Drop the reference before closing so the cleanup below
                    # never closes the same file twice.
                    finished, target = target, None
                    finished.close()
                target = open(full_path, 'w')
            elif target is not None:
                target.write(text)
            else:
                stripped = text.strip()
                if stripped and not stripped.startswith('#') and stripped not in ('"""', "'''"):
                    header_lines.append(text)
    finally:
        if target is not None:
            target.close()
    return dir, ''.join(header_lines)