import re
import sys
import unittest
from urlparse import urlsplit, urlunsplit
from xml.dom.minidom import parseString, Node

from django.conf import settings
from django.core import mail
from django.core.management import call_command
from django.core.urlresolvers import clear_url_caches
from django.db import transaction
from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
from django.utils import simplejson
# A run of digits followed by "L" that is not glued to other word characters
# on either side, e.g. the "22L" in "[22L]" but not in "x22L" or "22Lx".
_LONG_INT_RE = re.compile(r'(?<![\w])(\d+)L(?![\w])')


def normalize_long_ints(s):
    """Strip the ``L`` suffix from long-integer literals in ``s``.

    Python 2 prints long integers with a trailing ``L`` ("22L") while plain
    integers have none ("22"), so doctest output can differ only by that
    suffix. Only standalone ``<digits>L`` tokens are rewritten; identifiers
    such as ``x22L`` or ``22Lx`` are left untouched.
    """
    return _LONG_INT_RE.sub('\\1', s)
def to_list(value):
    """Return ``value`` as a list.

    ``None`` becomes a new empty list, an existing list is returned as-is
    (the same object, not a copy), and anything else -- tuples and strings
    included -- is wrapped in a one-element list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]
class OutputChecker(doctest.OutputChecker):
    """Doctest output checker that tolerates several benign differences.

    Expected and actual output are accepted as equal if any one of these
    comparisons succeeds: doctest's own comparison, a comparison that ignores
    the ``L`` suffix of long integers, a structural XML comparison, or a
    comparison of the decoded JSON values.
    """

    # Collapses runs of two or more whitespace characters into one space
    # (used by the XML comparison). Compiled once instead of on every call.
    _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')

    def check_output(self, want, got, optionflags):
        "The entry method for doctest output checking. Defers to a sequence of child checkers"
        checks = (self.check_output_default,
                  self.check_output_long,
                  self.check_output_xml,
                  self.check_output_json)
        for check in checks:
            if check(want, got, optionflags):
                return True
        return False

    def check_output_default(self, want, got, optionflags):
        "The default comparator provided by doctest - not perfect, but good for most purposes"
        return doctest.OutputChecker.check_output(self, want, got, optionflags)

    def check_output_long(self, want, got, optionflags):
        """Doctest does an exact string comparison of output, which means long
        integers aren't equal to normal integers ("22L" vs. "22"). The
        following code normalizes long integers so that they equal normal
        integers.
        """
        return normalize_long_ints(want) == normalize_long_ints(got)

    def check_output_xml(self, want, got, optionsflags):
        """Tries to do an 'xml-comparison' of want and got. Plain string
        comparison doesn't always work because, for example, attribute
        ordering should not be important.

        Elements are compared by tag name, by their direct text content
        (with whitespace runs collapsed), by their attributes as a dict, and
        recursively by their child elements in order. Returns False when
        either value cannot be parsed as XML.

        Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
        """
        norm_whitespace_re = self._norm_whitespace_re

        def norm_whitespace(v):
            return norm_whitespace_re.sub(' ', v)

        def child_text(element):
            return ''.join([c.data for c in element.childNodes
                            if c.nodeType == Node.TEXT_NODE])

        def children(element):
            return [c for c in element.childNodes
                    if c.nodeType == Node.ELEMENT_NODE]

        def norm_child_text(element):
            return norm_whitespace(child_text(element))

        def attrs_dict(element):
            return dict(element.attributes.items())

        def check_element(want_element, got_element):
            if want_element.tagName != got_element.tagName:
                return False
            if norm_child_text(want_element) != norm_child_text(got_element):
                return False
            if attrs_dict(want_element) != attrs_dict(got_element):
                return False
            want_children = children(want_element)
            got_children = children(got_element)
            if len(want_children) != len(got_children):
                return False
            for want_child, got_child in zip(want_children, got_children):
                if not check_element(want_child, got_child):
                    return False
            return True

        want, got = self._strip_quotes(want, got)
        # Output captured from repr() of a string shows newlines as "\n".
        want = want.replace('\\n', '\n')
        got = got.replace('\\n', '\n')

        # If the string is not a complete xml document, we may need to add a
        # root element. This allows us to compare fragments, like "<foo/><bar/>"
        if not want.startswith('<?xml'):
            wrapper = '<root>%s</root>'
            want = wrapper % want
            got = wrapper % got

        # Parse the want and got strings, and compare the parsings. A parse
        # failure just means the values are not comparable as XML.
        try:
            # documentElement is the root element even when a comment,
            # processing instruction or DOCTYPE precedes it (firstChild
            # would be that non-element node and has no tagName).
            want_root = parseString(want).documentElement
            got_root = parseString(got).documentElement
        except Exception:
            return False
        return check_element(want_root, got_root)

    def check_output_json(self, want, got, optionsflags):
        """Tries to compare want and got as if they were JSON-encoded data.

        Returns False when either value is not valid JSON.
        """
        want, got = self._strip_quotes(want, got)
        try:
            want_json = simplejson.loads(want)
            got_json = simplejson.loads(got)
        except ValueError:
            # Raised for input that is not a valid JSON document.
            return False
        return want_json == got_json

    def _strip_quotes(self, want, got):
        """
        Strip quotes of doctests output values.

        Quotes are removed only when both values are plain quoted strings, or
        both are ``u``-prefixed quoted strings; otherwise both values are
        returned unchanged. Returns the ``(want, got)`` pair.

        >>> o = OutputChecker()
        >>> o._strip_quotes("'foo'", "'bar'")
        ('foo', 'bar')
        >>> o._strip_quotes('"foo"', '"bar"')
        ('foo', 'bar')
        >>> o._strip_quotes("u'foo'", 'u"bar"')
        ('foo', 'bar')
        >>> o._strip_quotes("'foo'", "bar")
        ("'foo'", 'bar')
        """
        def is_quoted_string(s):
            s = s.strip()
            return (len(s) >= 2
                    and s[0] == s[-1]
                    and s[0] in ('"', "'"))

        def is_quoted_unicode(s):
            s = s.strip()
            return (len(s) >= 3
                    and s[0] == 'u'
                    and s[1] == s[-1]
                    and s[1] in ('"', "'"))

        if is_quoted_string(want) and is_quoted_string(got):
            want = want.strip()[1:-1]
            got = got.strip()[1:-1]
        elif is_quoted_unicode(want) and is_quoted_unicode(got):
            want = want.strip()[2:-1]
            got = got.strip()[2:-1]
        return want, got
class DocTestRunner(doctest.DocTestRunner):
    """Doctest runner that always enables ELLIPSIS and, after reporting an
    unexpected exception, calls ``transaction.rollback_unless_managed()``.
    """

    def __init__(self, *args, **kwargs):
        doctest.DocTestRunner.__init__(self, *args, **kwargs)
        # Add ELLIPSIS to the flags the caller asked for instead of replacing
        # them, so options passed via ``optionflags`` are not silently lost.
        self.optionflags |= doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        doctest.DocTestRunner.report_unexpected_exception(self, out, test,
                                                          example, exc_info)
        # Rollback, in case of database errors. Otherwise they'd have
        # side effects on other tests.
        transaction.rollback_unless_managed()
class TestCase(unittest.TestCase):
    """unittest.TestCase subclass that performs Django-specific set-up and
    tear-down around every test and adds assertions for inspecting
    test-client responses.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database.
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * If the Test Case class has a 'urls' member, replace the
              ROOT_URLCONF with it.
            * Clearing the mail test outbox.
        """
        call_command('flush', verbosity=0, interactive=False)
        if hasattr(self, 'fixtures'):
            # We have to use this slightly awkward syntax due to the fact
            # that we're using *args and **kwargs together.
            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
        if hasattr(self, 'urls'):
            # Remember the original URLconf so _post_teardown can restore it.
            self._old_root_urlconf = settings.ROOT_URLCONF
            settings.ROOT_URLCONF = self.urls
            clear_url_caches()
        mail.outbox = []

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().

        An exception from _pre_setup or _post_teardown is recorded on
        ``result`` via addError (KeyboardInterrupt and SystemExit propagate).
        If _pre_setup fails, the test itself is not run.
        """
        self.client = Client()
        # NOTE(review): the error paths below assume ``result`` is not None.
        try:
            self._pre_setup()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            result.addError(self, sys.exc_info())
            return
        super(TestCase, self).__call__(result)
        try:
            self._post_teardown()
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            result.addError(self, sys.exc_info())

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Putting back the original ROOT_URLCONF if it was changed.
        """
        if hasattr(self, '_old_root_urlconf'):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        A relative ``expected_url`` is made absolute using the "http" scheme
        and ``host`` (default 'testserver') before comparison.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        self.assertEqual(response.status_code, status_code,
            ("Response didn't redirect as expected: Response code was %d"
             " (expected %d)" % (response.status_code, status_code)))
        url = response['Location']
        scheme, netloc, path, query, fragment = urlsplit(url)
        e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
        if not (e_scheme or e_netloc):
            expected_url = urlunsplit(('http', host or 'testserver', e_path,
                                       e_query, e_fragment))
        self.assertEqual(url, expected_url,
            "Response redirected to '%s', expected '%s'" % (url, expected_url))

        # Get the redirection page, using the same client that was used
        # to obtain the original response.
        redirect_response = response.client.get(path, QueryDict(query))
        self.assertEqual(redirect_response.status_code, target_status_code,
            ("Couldn't retrieve redirection page '%s': response code was %d"
             " (expected %d)") %
                (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=None, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                "Found %d instances of '%s' in response (expected %d)" %
                    (real_count, text, count))
        else:
            self.failUnless(real_count != 0,
                            "Couldn't find '%s' in response" % text)

    def assertNotContains(self, response, text, status_code=200):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` doesn't occur in the content of the response.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        self.assertEqual(response.content.count(text), 0,
                         "Response should not contain '%s'" % text)

    def assertFormError(self, response, form, field, errors):
        """
        Asserts that a form used to render the response has a specific field
        error.

        ``form`` is the name of the form in the template context(s). A false
        ``field`` checks the form's non-field errors instead. ``errors`` may
        be a single error or a list of errors; every one must be present.
        """
        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail('Response did not use any contexts to render the'
                      ' response')

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.failUnless(err in field_errors,
                                        "The field '%s' on form '%s' in"
                                        " context %d does not contain the"
                                        " error '%s' (actual errors: %s)" %
                                            (field, form, i, err,
                                             repr(field_errors)))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d"
                                  " contains no errors" % (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not"
                                  " contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.failUnless(err in non_field_errors,
                        "The form '%s' in context %d does not contain the"
                        " non-field error '%s' (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" %
                          form)

    def assertTemplateUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was used in rendering
        the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail('No templates used to render the response')
        self.failUnless(template_name in template_names,
            (u"Template '%s' was not a template used to render the response."
             u" Actual template(s) used: %s") % (template_name,
                                                 u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        template_names = [t.name for t in to_list(response.template)]
        self.failIf(template_name in template_names,
            (u"Template '%s' was used unexpectedly in rendering the"
             u" response") % template_name)