2 --[[--------------------------------------------------------------------------
4 This file is part of lunit 0.5.
6 For Details about lunit look at: http://www.mroth.net/lunit/
8 Author: Michael Roth <mroth@nessie.de>
10 Copyright (c) 2004, 2006-2009 Michael Roth <mroth@nessie.de>
12 Permission is hereby granted, free of charge, to any person
13 obtaining a copy of this software and associated documentation
14 files (the "Software"), to deal in the Software without restriction,
15 including without limitation the rights to use, copy, modify, merge,
16 publish, distribute, sublicense, and/or sell copies of the Software,
17 and to permit persons to whom the Software is furnished to do so,
18 subject to the following conditions:
20 The above copyright notice and this permission notice shall be
21 included in all copies or substantial portions of the Software.
23 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31 --]]--------------------------------------------------------------------------
36 local orig_assert
= assert
43 local tostring = tostring
45 local string_sub
= string.sub
46 local string_format
= string.format
49 module("lunit", package
.seeall
) -- FIXME: Remove package.seeall
53 local __failure__
= {} -- Type tag for failed assertions
55 local typenames
= { "nil", "boolean", "number", "string", "table", "function", "thread", "userdata" }
59 local traceback_hide
-- Traceback function which hides lunit internals
60 local mypcall
-- Protected call to a function with own traceback
62 local _tb_hide
= setmetatable( {}, {__mode
="k"} )
64 function traceback_hide(func
)
68 local function my_traceback(errobj
)
69 if is_table(errobj
) and errobj
.type == __failure__
then
70 local info
= debug
.getinfo(5, "Sl") -- FIXME: Hardcoded integers are bad...
71 errobj
.where
= string_format( "%s:%d", info
.short_src
, info
.currentline
)
73 errobj
= { msg
= tostring(errobj
) }
77 local info
= debug
.getinfo(i
, "Snlf")
78 if not is_table(info
) then
81 if not _tb_hide
[info
.func
] then
82 local line
= {} -- Ripped from ldblib.c...
83 line
[#line
+1] = string_format("%s:", info
.short_src
)
84 if info
.currentline
> 0 then
85 line
[#line
+1] = string_format("%d:", info
.currentline
)
87 if info
.namewhat
~= "" then
88 line
[#line
+1] = string_format(" in function '%s'", info
.name
)
90 if info
.what
== "main" then
91 line
[#line
+1] = " in main chunk"
92 elseif info
.what
== "C" or info
.what
== "tail" then
95 line
[#line
+1] = string_format(" in function <%s:%d>", info
.short_src
, info
.linedefined
)
98 errobj
.tb
[#errobj
.tb
+1] = table.concat(line
)
106 function mypcall(func
)
107 orig_assert( is_function(func
) )
108 local ok
, errobj
= xpcall(func
, my_traceback
)
113 traceback_hide(mypcall
)
117 -- Type check functions
119 for _
, typename
in ipairs(typenames
) do
120 lunit
["is_"..typename
] = function(x
)
121 return type(x
) == typename
125 local is_nil
= is_nil
126 local is_boolean
= is_boolean
127 local is_number
= is_number
128 local is_string
= is_string
129 local is_table
= is_table
130 local is_function
= is_function
131 local is_thread
= is_thread
132 local is_userdata
= is_userdata
135 local function failure(name
, usermsg
, defaultmsg
, ...)
139 msg
= string_format(defaultmsg
,...),
144 traceback_hide( failure
)
147 local function format_arg(arg
)
148 local argtype
= type(arg
)
149 if argtype
== "string" then
151 elseif argtype
== "number" or argtype
== "boolean" or argtype
== "nil" then
154 return "["..tostring(arg
).."]"
160 stats
.assertions
= stats
.assertions
+ 1
161 failure( "fail", msg
, "failure" )
163 traceback_hide( fail
)
166 function assert(assertion
, msg
)
167 stats
.assertions
= stats
.assertions
+ 1
168 if not assertion
then
169 failure( "assert", msg
, "assertion failed" )
173 traceback_hide( assert )
176 function assert_true(actual
, msg
)
177 stats
.assertions
= stats
.assertions
+ 1
178 local actualtype
= type(actual
)
179 if actualtype
~= "boolean" then
180 failure( "assert_true", msg
, "true expected but was a "..actualtype
)
182 if actual
~= true then
183 failure( "assert_true", msg
, "true expected but was false" )
187 traceback_hide( assert_true
)
190 function assert_false(actual
, msg
)
191 stats
.assertions
= stats
.assertions
+ 1
192 local actualtype
= type(actual
)
193 if actualtype
~= "boolean" then
194 failure( "assert_false", msg
, "false expected but was a "..actualtype
)
196 if actual
~= false then
197 failure( "assert_false", msg
, "false expected but was true" )
201 traceback_hide( assert_false
)
204 function assert_equal(expected
, actual
, msg
)
205 stats
.assertions
= stats
.assertions
+ 1
206 if expected
~= actual
then
207 failure( "assert_equal", msg
, "expected %s but was %s", format_arg(expected
), format_arg(actual
) )
211 traceback_hide( assert_equal
)
214 function assert_not_equal(unexpected
, actual
, msg
)
215 stats
.assertions
= stats
.assertions
+ 1
216 if unexpected
== actual
then
217 failure( "assert_not_equal", msg
, "%s not expected but was one", format_arg(unexpected
) )
221 traceback_hide( assert_not_equal
)
224 function assert_match(pattern
, actual
, msg
)
225 stats
.assertions
= stats
.assertions
+ 1
226 local patterntype
= type(pattern
)
227 if patterntype
~= "string" then
228 failure( "assert_match", msg
, "expected the pattern as a string but was a "..patterntype
)
230 local actualtype
= type(actual
)
231 if actualtype
~= "string" then
232 failure( "assert_match", msg
, "expected a string to match pattern '%s' but was a %s", pattern
, actualtype
)
234 if not string.find(actual
, pattern
) then
235 failure( "assert_match", msg
, "expected '%s' to match pattern '%s' but doesn't", actual
, pattern
)
239 traceback_hide( assert_match
)
242 function assert_not_match(pattern
, actual
, msg
)
243 stats
.assertions
= stats
.assertions
+ 1
244 local patterntype
= type(pattern
)
245 if patterntype
~= "string" then
246 failure( "assert_not_match", msg
, "expected the pattern as a string but was a "..patterntype
)
248 local actualtype
= type(actual
)
249 if actualtype
~= "string" then
250 failure( "assert_not_match", msg
, "expected a string to not match pattern '%s' but was a %s", pattern
, actualtype
)
252 if string.find(actual
, pattern
) then
253 failure( "assert_not_match", msg
, "expected '%s' to not match pattern '%s' but it does", actual
, pattern
)
257 traceback_hide( assert_not_match
)
260 function assert_error(msg
, func
)
261 stats
.assertions
= stats
.assertions
+ 1
265 local functype
= type(func
)
266 if functype
~= "function" then
267 failure( "assert_error", msg
, "expected a function as last argument but was a "..functype
)
269 local ok
, errmsg
= pcall(func
)
271 failure( "assert_error", msg
, "error expected but no error occurred" )
274 traceback_hide( assert_error
)
277 function assert_error_match(msg
, pattern
, func
)
278 stats
.assertions
= stats
.assertions
+ 1
280 msg
, pattern
, func
= nil, msg
, pattern
282 local patterntype
= type(pattern
)
283 if patterntype
~= "string" then
284 failure( "assert_error_match", msg
, "expected the pattern as a string but was a "..patterntype
)
286 local functype
= type(func
)
287 if functype
~= "function" then
288 failure( "assert_error_match", msg
, "expected a function as last argument but was a "..functype
)
290 local ok
, errmsg
= pcall(func
)
292 failure( "assert_error_match", msg
, "error expected but no error occurred" )
294 local errmsgtype
= type(errmsg
)
295 if errmsgtype
~= "string" then
296 failure( "assert_error_match", msg
, "error as string expected but was a "..errmsgtype
)
298 if not string.find(errmsg
, pattern
) then
299 failure( "assert_error_match", msg
, "expected error '%s' to match pattern '%s' but doesn't", errmsg
, pattern
)
302 traceback_hide( assert_error_match
)
305 function assert_pass(msg
, func
)
306 stats
.assertions
= stats
.assertions
+ 1
310 local functype
= type(func
)
311 if functype
~= "function" then
312 failure( "assert_pass", msg
, "expected a function as last argument but was a %s", functype
)
314 local ok
, errmsg
= pcall(func
)
316 failure( "assert_pass", msg
, "no error expected but error was: '%s'", errmsg
)
319 traceback_hide( assert_pass
)
322 -- lunit.assert_typename functions
324 for _
, typename
in ipairs(typenames
) do
325 local assert_typename
= "assert_"..typename
326 lunit
[assert_typename
] = function(actual
, msg
)
327 stats
.assertions
= stats
.assertions
+ 1
328 local actualtype
= type(actual
)
329 if actualtype
~= typename
then
330 failure( assert_typename
, msg
, typename
.." expected but was a "..actualtype
)
334 traceback_hide( lunit
[assert_typename
] )
338 -- lunit.assert_not_typename functions
340 for _
, typename
in ipairs(typenames
) do
341 local assert_not_typename
= "assert_not_"..typename
342 lunit
[assert_not_typename
] = function(actual
, msg
)
343 stats
.assertions
= stats
.assertions
+ 1
344 if type(actual
) == typename
then
345 failure( assert_not_typename
, msg
, typename
.." not expected but was one" )
348 traceback_hide( lunit
[assert_not_typename
] )
352 function lunit
.clearstats()
362 local report
, reporterrobj
366 function lunit
.setrunner(newrunner
)
367 if not ( is_table(newrunner
) or is_nil(newrunner
) ) then
368 return error("lunit.setrunner: Invalid argument", 0)
370 local oldrunner
= testrunner
371 testrunner
= newrunner
375 function lunit
.loadrunner(name
)
376 if not is_string(name
) then
377 return error("lunit.loadrunner: Invalid argument", 0)
379 local ok
, runner
= pcall( require
, name
)
381 return error("lunit.loadrunner: Can't load test runner: "..runner
, 0)
383 return setrunner(runner
)
386 function report(event
, ...)
387 local f
= testrunner
and testrunner
[event
]
388 if is_function(f
) then
393 function reporterrobj(context
, tcname
, testname
, errobj
)
394 local fullname
= tcname
.. "." .. testname
395 if context
== "setup" then
396 fullname
= fullname
.. ":" .. setupname(tcname
, testname
)
397 elseif context
== "teardown" then
398 fullname
= fullname
.. ":" .. teardownname(tcname
, testname
)
400 if errobj
.type == __failure__
then
401 stats
.failed
= stats
.failed
+ 1
402 report("fail", fullname
, errobj
.where
, errobj
.msg
, errobj
.usermsg
)
404 stats
.errors
= stats
.errors
+ 1
405 report("err", fullname
, errobj
.msg
, errobj
.tb
)
412 local function key_iter(t
, k
)
419 -- Array with all registered testcases
420 local _testcases
= {}
422 -- Marks a module as a testcase.
423 -- Applied over a module from module("xyz", lunit.testcase).
424 function lunit
.testcase(m
)
425 orig_assert( is_table(m
) )
426 --orig_assert( m._M == m )
427 orig_assert( is_string(m
._NAME
) )
428 --orig_assert( is_string(m._PACKAGE) )
430 -- Register the module as a testcase
431 _testcases
[m
._NAME
] = m
433 -- Import lunit, fail, assert* and is_* function to the module/testcase
436 for funcname
, func
in pairs(lunit
) do
437 if "assert" == string_sub(funcname
, 1, 6) or "is_" == string_sub(funcname
, 1, 3) then
443 -- Iterator (testcasename) over all Testcases
444 function lunit
.testcases()
445 -- Make a copy of testcases to prevent confusing the iterator when
446 -- new testcase are defined
447 local _testcases2
= {}
448 for k
,v
in pairs(_testcases
) do
449 _testcases2
[k
] = true
451 return key_iter
, _testcases2
, nil
454 function testcase(tcname
)
455 return _testcases
[tcname
]
461 -- Finds a function in a testcase case insensitive
462 local function findfuncname(tcname
, name
)
463 for key
, value
in pairs(testcase(tcname
)) do
464 if is_string(key
) and is_function(value
) and string.lower(key
) == name
then
470 function lunit
.setupname(tcname
)
471 return findfuncname(tcname
, "setup")
474 function lunit
.teardownname(tcname
)
475 return findfuncname(tcname
, "teardown")
478 -- Iterator over all test names in a testcase.
479 -- Have to collect the names first in case one of the test
480 -- functions creates a new global and throws off the iteration.
481 function lunit
.tests(tcname
)
483 for key
, value
in pairs(testcase(tcname
)) do
484 if is_string(key
) and is_function(value
) then
485 local lfn
= string.lower(key
)
486 if string.sub(lfn
, 1, 4) == "test" or string.sub(lfn
, -4) == "test" then
487 testnames
[key
] = true
491 return key_iter
, testnames
, nil
498 function lunit
.runtest(tcname
, testname
)
499 orig_assert( is_string(tcname
) )
500 orig_assert( is_string(testname
) )
502 local function callit(context
, func
)
504 local err
= mypcall(func
)
506 reporterrobj(context
, tcname
, testname
, err
)
512 traceback_hide(callit
)
514 report("run", tcname
, testname
)
516 local tc
= testcase(tcname
)
517 local setup
= tc
[setupname(tcname
)]
518 local test
= tc
[testname
]
519 local teardown
= tc
[teardownname(tcname
)]
521 local setup_ok
= callit( "setup", setup
)
522 local test_ok
= setup_ok
and callit( "test", test
)
523 local teardown_ok
= setup_ok
and callit( "teardown", teardown
)
525 if setup_ok
and test_ok
and teardown_ok
then
526 stats
.passed
= stats
.passed
+ 1
527 report("pass", tcname
, testname
)
530 traceback_hide(runtest
)
537 for testcasename
in lunit
.testcases() do
538 -- Run tests in the testcases
539 for testname
in lunit
.tests(testcasename
) do
540 runtest(testcasename
, testname
)
549 function lunit
.loadonly()
564 local lunitpat2luapat
580 function lunitpat2luapat(str
)
581 return "^" .. string.gsub(str
, "%W", conv
) .. "$"
587 local function in_patternmap(map
, name
)
588 if map
[name
] == true then
591 for _
, pat
in ipairs(map
) do
592 if string.find(name
, pat
) then
607 -- Called from 'lunit' shell script.
612 -- FIXME: Error handling and error messages aren't nice.
614 local function checkarg(optname
, arg
)
615 if not is_string(arg
) then
616 return error("lunit.main: option "..optname
..": argument missing.", 0)
620 local function loadtestcase(filename
)
621 if not is_string(filename
) then
622 return error("lunit.main: invalid argument")
624 local chunk
, err
= loadfile(filename
)
632 local testpatterns
= nil
633 local doloadonly
= false
640 if arg
== "--loadonly" then
642 elseif arg
== "--runner" or arg
== "-r" then
643 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
644 checkarg(optname
, arg
)
646 elseif arg
== "--test" or arg
== "-t" then
647 local optname
= arg
; i
= i
+ 1; arg
= argv
[i
]
648 checkarg(optname
, arg
)
649 testpatterns
= testpatterns
or {}
650 testpatterns
[#testpatterns
+1] = arg
651 elseif arg
== "--" then
653 i
= i
+ 1; arg
= argv
[i
]
661 loadrunner(runner
or "lunit-console")
666 return run(testpatterns
)