3 # Copyright 2008 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 # This file is used for testing. The original is at:
18 # http://code.google.com/p/pymox/
20 """Mox, an object-mocking framework for Python.
22 Mox works in the record-replay-verify paradigm. When you first create
23 a mock object, it is in record mode. You then programmatically set
24 the expected behavior of the mock object (what methods are to be
25 called on it, with what parameters, what they should return, and in
28 Once you have set up the expected mock behavior, you put it in replay
29 mode. Now the mock responds to method calls just as you told it to.
30 If an unexpected method (or an expected method with unexpected
31 parameters) is called, then an exception will be raised.
33 Once you are done interacting with the mock, you need to verify that
34 all the expected interactions occured. (Maybe your code exited
35 prematurely without calling some cleanup method!) The verify phase
36 ensures that every expected method was called; otherwise, an exception
39 Suggested usage / workflow:
44 # Create a mock data access object
45 mock_dao = my_mox.CreateMock(DAOClass)
47 # Set up expected behavior
48 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49 mock_dao.DeletePerson(person)
51 # Put mocks in replay mode
54 # Inject mock object and run test
55 controller.SetDao(mock_dao)
56 controller.DeletePersonById('1')
58 # Verify all methods were called as expected
62 from collections
import deque
69 class Error(AssertionError):
70 """Base exception for this module."""
75 class ExpectedMethodCallsError(Error
):
76 """Raised when Verify() is called before all expected methods have been called
79 def __init__(self
, expected_methods
):
83 # expected_methods: A sequence of MockMethod objects that should have been
85 expected_methods: [MockMethod]
88 ValueError: if expected_methods contains no methods.
91 if not expected_methods
:
92 raise ValueError("There must be at least one expected method")
94 self
._expected
_methods
= expected_methods
97 calls
= "\n".join(["%3d. %s" % (i
, m
)
98 for i
, m
in enumerate(self
._expected
_methods
)])
99 return "Verify: Expected methods never called:\n%s" % (calls
,)
102 class UnexpectedMethodCallError(Error
):
103 """Raised when an unexpected method is called.
105 This can occur if a method is called with incorrect parameters, or out of the
109 def __init__(self
, unexpected_method
, expected
):
113 # unexpected_method: MockMethod that was called but was not at the head of
114 # the expected_method queue.
115 # expected: MockMethod or UnorderedGroup the method should have
117 unexpected_method: MockMethod
118 expected: MockMethod or UnorderedGroup
122 self
._unexpected
_method
= unexpected_method
123 self
._expected
= expected
126 return "Unexpected method call: %s. Expecting: %s" % \
127 (self
._unexpected
_method
, self
._expected
)
130 class UnknownMethodCallError(Error
):
131 """Raised if an unknown method is requested of the mock object."""
133 def __init__(self
, unknown_method_name
):
137 # unknown_method_name: Method call that is not part of the mocked class's
139 unknown_method_name: str
143 self
._unknown
_method
_name
= unknown_method_name
146 return "Method called is not a member of the object: %s" % \
147 self
._unknown
_method
_name
151 """Mox: a factory for creating mock objects."""
153 # A list of types that should be stubbed out with MockObjects (as
154 # opposed to MockAnythings).
155 _USE_MOCK_OBJECT
= [types
.ClassType
, types
.InstanceType
, types
.ModuleType
,
156 types
.ObjectType
, types
.TypeType
]
159 """Initialize a new Mox."""
161 self
._mock
_objects
= []
162 self
.stubs
= stubout
.StubOutForTesting()
164 def CreateMock(self
, class_to_mock
):
165 """Create a new mock object.
168 # class_to_mock: the class to be mocked
172 MockObject that can be used as the class_to_mock would be.
175 new_mock
= MockObject(class_to_mock
)
176 self
._mock
_objects
.append(new_mock
)
179 def CreateMockAnything(self
):
180 """Create a mock that will accept any method calls.
182 This does not enforce an interface.
185 new_mock
= MockAnything()
186 self
._mock
_objects
.append(new_mock
)
190 """Set all mock objects to replay mode."""
192 for mock_obj
in self
._mock
_objects
:
197 """Call verify on all mock objects created."""
199 for mock_obj
in self
._mock
_objects
:
203 """Call reset on all mock objects. This does not unset stubs."""
205 for mock_obj
in self
._mock
_objects
:
208 def StubOutWithMock(self
, obj
, attr_name
, use_mock_anything
=False):
209 """Replace a method, attribute, etc. with a Mock.
211 This will replace a class or module with a MockObject, and everything else
212 (method, function, etc) with a MockAnything. This can be overridden to
213 always use a MockAnything by setting use_mock_anything to True.
216 obj: A Python object (class, module, instance, callable).
217 attr_name: str. The name of the attribute to replace with a mock.
218 use_mock_anything: bool. True if a MockAnything should be used regardless
219 of the type of attribute.
222 attr_to_replace
= getattr(obj
, attr_name
)
223 if type(attr_to_replace
) in self
._USE
_MOCK
_OBJECT
and not use_mock_anything
:
224 stub
= self
.CreateMock(attr_to_replace
)
226 stub
= self
.CreateMockAnything()
228 self
.stubs
.Set(obj
, attr_name
, stub
)
230 def UnsetStubs(self
):
231 """Restore stubs to their original state."""
233 self
.stubs
.UnsetAll()
236 """Put mocks into Replay mode.
239 # args is any number of mocks to put into replay mode.
250 # args is any number of mocks to be verified.
261 # args is any number of mocks to be reset.
269 """A mock that can be used to mock anything.
271 This is helpful for mocking classes that do not provide a public interface.
278 def __getattr__(self
, method_name
):
279 """Intercept method calls on this object.
281 A new MockMethod is returned that is aware of the MockAnything's
282 state (record or replay). The call will be recorded or replayed
283 by the MockMethod's __call__.
286 # method name: the name of the method being called.
290 A new MockMethod aware of MockAnything's state (record or replay).
293 return self
._CreateMockMethod
(method_name
)
295 def _CreateMockMethod(self
, method_name
):
296 """Create a new mock method call and return it.
299 # method name: the name of the method being called.
303 A new MockMethod aware of MockAnything's state (record or replay).
306 return MockMethod(method_name
, self
._expected
_calls
_queue
,
309 def __nonzero__(self
):
310 """Return 1 for nonzero so the mock can be used as a conditional."""
314 def __eq__(self
, rhs
):
315 """Provide custom logic to compare objects."""
317 return (isinstance(rhs
, MockAnything
) and
318 self
._replay
_mode
== rhs
._replay
_mode
and
319 self
._expected
_calls
_queue
== rhs
._expected
_calls
_queue
)
321 def __ne__(self
, rhs
):
322 """Provide custom logic to compare objects."""
324 return not self
== rhs
327 """Start replaying expected method calls."""
329 self
._replay
_mode
= True
332 """Verify that all of the expected calls have been made.
335 ExpectedMethodCallsError: if there are still more method calls in the
339 # If the list of expected calls is not empty, raise an exception
340 if self
._expected
_calls
_queue
:
341 # The last MultipleTimesGroup is not popped from the queue.
342 if (len(self
._expected
_calls
_queue
) == 1 and
343 isinstance(self
._expected
_calls
_queue
[0], MultipleTimesGroup
) and
344 self
._expected
_calls
_queue
[0].IsSatisfied()):
347 raise ExpectedMethodCallsError(self
._expected
_calls
_queue
)
350 """Reset the state of this mock to record mode with an empty queue."""
352 # Maintain a list of method calls we are expecting
353 self
._expected
_calls
_queue
= deque()
355 # Make sure we are in setup mode, not replay mode
356 self
._replay
_mode
= False
359 class MockObject(MockAnything
, object):
360 """A mock object that simulates the public/protected interface of a class."""
362 def __init__(self
, class_to_mock
):
363 """Initialize a mock object.
365 This determines the methods and properties of the class and stores them.
368 # class_to_mock: class to be mocked
372 # This is used to hack around the mixin/inheritance of MockAnything, which
373 # is not a proper object (it can be anything. :-)
374 MockAnything
.__dict
__['__init__'](self
)
376 # Get a list of all the public and special methods we should mock.
377 self
._known
_methods
= set()
378 self
._known
_vars
= set()
379 self
._class
_to
_mock
= class_to_mock
380 for method
in dir(class_to_mock
):
381 if callable(getattr(class_to_mock
, method
)):
382 self
._known
_methods
.add(method
)
384 self
._known
_vars
.add(method
)
386 def __getattr__(self
, name
):
387 """Intercept attribute request on this object.
389 If the attribute is a public class variable, it will be returned and not
392 If the attribute is not a variable, it is handled like a method
393 call. The method name is checked against the set of mockable
394 methods, and a new MockMethod is returned that is aware of the
395 MockObject's state (record or replay). The call will be recorded
396 or replayed by the MockMethod's __call__.
399 # name: the name of the attribute being requested.
403 Either a class variable or a new MockMethod that is aware of the state
404 of the mock (record or replay).
407 UnknownMethodCallError if the MockObject does not mock the requested
411 if name
in self
._known
_vars
:
412 return getattr(self
._class
_to
_mock
, name
)
414 if name
in self
._known
_methods
:
415 return self
._CreateMockMethod
(name
)
417 raise UnknownMethodCallError(name
)
419 def __eq__(self
, rhs
):
420 """Provide custom logic to compare objects."""
422 return (isinstance(rhs
, MockObject
) and
423 self
._class
_to
_mock
== rhs
._class
_to
_mock
and
424 self
._replay
_mode
== rhs
._replay
_mode
and
425 self
._expected
_calls
_queue
== rhs
._expected
_calls
_queue
)
427 def __setitem__(self
, key
, value
):
428 """Provide custom logic for mocking classes that support item assignment.
431 key: Key to set the value for.
435 Expected return value in replay mode. A MockMethod object for the
436 __setitem__ method that has already been called if not in replay mode.
439 TypeError if the underlying class does not support item assignment.
440 UnexpectedMethodCallError if the object does not expect the call to
444 setitem
= self
._class
_to
_mock
.__dict
__.get('__setitem__', None)
446 # Verify the class supports item assignment.
448 raise TypeError('object does not support item assignment')
450 # If we are in replay mode then simply call the mock __setitem__ method.
451 if self
._replay
_mode
:
452 return MockMethod('__setitem__', self
._expected
_calls
_queue
,
453 self
._replay
_mode
)(key
, value
)
456 # Otherwise, create a mock method __setitem__.
457 return self
._CreateMockMethod
('__setitem__')(key
, value
)
459 def __getitem__(self
, key
):
460 """Provide custom logic for mocking classes that are subscriptable.
463 key: Key to return the value for.
466 Expected return value in replay mode. A MockMethod object for the
467 __getitem__ method that has already been called if not in replay mode.
470 TypeError if the underlying class is not subscriptable.
471 UnexpectedMethodCallError if the object does not expect the call to
475 getitem
= self
._class
_to
_mock
.__dict
__.get('__getitem__', None)
477 # Verify the class supports item assignment.
479 raise TypeError('unsubscriptable object')
481 # If we are in replay mode then simply call the mock __getitem__ method.
482 if self
._replay
_mode
:
483 return MockMethod('__getitem__', self
._expected
_calls
_queue
,
484 self
._replay
_mode
)(key
)
487 # Otherwise, create a mock method __getitem__.
488 return self
._CreateMockMethod
('__getitem__')(key
)
490 def __call__(self
, *params
, **named_params
):
491 """Provide custom logic for mocking classes that are callable."""
493 # Verify the class we are mocking is callable
494 callable = self
._class
_to
_mock
.__dict
__.get('__call__', None)
496 raise TypeError('Not callable')
498 # Because the call is happening directly on this object instead of a method,
499 # the call on the mock method is made right here
500 mock_method
= self
._CreateMockMethod
('__call__')
501 return mock_method(*params
, **named_params
)
505 """Return the class that is being mocked."""
507 return self
._class
_to
_mock
510 class MockMethod(object):
511 """Callable mock method.
513 A MockMethod should act exactly like the method it mocks, accepting parameters
514 and returning a value, or throwing an exception (as specified). When this
515 method is called, it can optionally verify whether the called method (name and
516 signature) matches the expected method.
519 def __init__(self
, method_name
, call_queue
, replay_mode
):
520 """Construct a new mock method.
523 # method_name: the name of the method
524 # call_queue: deque of calls, verify this call against the head, or add
525 # this call to the queue.
526 # replay_mode: False if we are recording, True if we are verifying calls
527 # against the call queue.
529 call_queue: list or deque
533 self
._name
= method_name
534 self
._call
_queue
= call_queue
535 if not isinstance(call_queue
, deque
):
536 self
._call
_queue
= deque(self
._call
_queue
)
537 self
._replay
_mode
= replay_mode
540 self
._named
_params
= None
541 self
._return
_value
= None
542 self
._exception
= None
543 self
._side
_effects
= None
545 def __call__(self
, *params
, **named_params
):
546 """Log parameters and return the specified return value.
548 If the Mock(Anything/Object) associated with this call is in record mode,
549 this MockMethod will be pushed onto the expected call queue. If the mock
550 is in replay mode, this will pop a MockMethod off the top of the queue and
551 verify this call is equal to the expected call.
554 UnexpectedMethodCall if this call is supposed to match an expected method
555 call and it does not.
558 self
._params
= params
559 self
._named
_params
= named_params
561 if not self
._replay
_mode
:
562 self
._call
_queue
.append(self
)
565 expected_method
= self
._VerifyMethodCall
()
567 if expected_method
._side
_effects
:
568 expected_method
._side
_effects
(*params
, **named_params
)
570 if expected_method
._exception
:
571 raise expected_method
._exception
573 return expected_method
._return
_value
575 def __getattr__(self
, name
):
576 """Raise an AttributeError with a helpful message."""
578 raise AttributeError('MockMethod has no attribute "%s". '
579 'Did you remember to put your mocks in replay mode?' % name
)
581 def _PopNextMethod(self
):
582 """Pop the next method from our call queue."""
584 return self
._call
_queue
.popleft()
586 raise UnexpectedMethodCallError(self
, None)
588 def _VerifyMethodCall(self
):
589 """Verify the called method is expected.
591 This can be an ordered method, or part of an unordered set.
594 The expected mock method.
597 UnexpectedMethodCall if the method called was not expected.
600 expected
= self
._PopNextMethod
()
602 # Loop here, because we might have a MethodGroup followed by another
604 while isinstance(expected
, MethodGroup
):
605 expected
, method
= expected
.MethodCalled(self
)
606 if method
is not None:
609 # This is a mock method, so just check equality.
611 raise UnexpectedMethodCallError(self
, expected
)
617 [repr(p
) for p
in self
._params
or []] +
618 ['%s=%r' % x
for x
in sorted((self
._named
_params
or {}).items())])
619 desc
= "%s(%s) -> %r" % (self
._name
, params
, self
._return
_value
)
622 def __eq__(self
, rhs
):
623 """Test whether this MockMethod is equivalent to another MockMethod.
626 # rhs: the right hand side of the test
630 return (isinstance(rhs
, MockMethod
) and
631 self
._name
== rhs
._name
and
632 self
._params
== rhs
._params
and
633 self
._named
_params
== rhs
._named
_params
)
635 def __ne__(self
, rhs
):
636 """Test whether this MockMethod is not equivalent to another MockMethod.
639 # rhs: the right hand side of the test
643 return not self
== rhs
645 def GetPossibleGroup(self
):
646 """Returns a possible group from the end of the call queue or None if no
647 other methods are on the stack.
650 # Remove this method from the tail of the queue so we can add it to a group.
651 this_method
= self
._call
_queue
.pop()
652 assert this_method
== self
654 # Determine if the tail of the queue is a group, or just a regular ordered
658 group
= self
._call
_queue
[-1]
664 def _CheckAndCreateNewGroup(self
, group_name
, group_class
):
665 """Checks if the last method (a possible group) is an instance of our
666 group_class. Adds the current method to this group or creates a new one.
670 group_name: the name of the group.
671 group_class: the class used to create instance of this new group
673 group
= self
.GetPossibleGroup()
675 # If this is a group, and it is the correct group, add the method.
676 if isinstance(group
, group_class
) and group
.group_name() == group_name
:
677 group
.AddMethod(self
)
680 # Create a new group and add the method.
681 new_group
= group_class(group_name
)
682 new_group
.AddMethod(self
)
683 self
._call
_queue
.append(new_group
)
686 def InAnyOrder(self
, group_name
="default"):
687 """Move this method into a group of unordered calls.
689 A group of unordered calls must be defined together, and must be executed
690 in full before the next expected method can be called. There can be
691 multiple groups that are expected serially, if they are given
692 different group names. The same group name can be reused if there is a
693 standard method call, or a group with a different name, spliced between
697 group_name: the name of the unordered group.
702 return self
._CheckAndCreateNewGroup
(group_name
, UnorderedGroup
)
704 def MultipleTimes(self
, group_name
="default"):
705 """Move this method into group of calls which may be called multiple times.
707 A group of repeating calls must be defined together, and must be executed in
708 full before the next expected mehtod can be called.
711 group_name: the name of the unordered group.
716 return self
._CheckAndCreateNewGroup
(group_name
, MultipleTimesGroup
)
718 def AndReturn(self
, return_value
):
719 """Set the value to return when this method is called.
722 # return_value can be anything.
725 self
._return
_value
= return_value
728 def AndRaise(self
, exception
):
729 """Set the exception to raise when this method is called.
732 # exception: the exception to raise when this method is called.
736 self
._exception
= exception
738 def WithSideEffects(self
, side_effects
):
739 """Set the side effects that are simulated when this method is called.
742 side_effects: A callable which modifies the parameters or other relevant
743 state which a given test case depends on.
746 Self for chaining with AndReturn and AndRaise.
748 self
._side
_effects
= side_effects
752 """Base class for all Mox comparators.
754 A Comparator can be used as a parameter to a mocked method when the exact
755 value is not known. For example, the code you are testing might build up a
756 long SQL string that is passed to your mock DAO. You're only interested that
757 the IN clause contains the proper primary keys, so you can set your mock
760 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
762 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
764 A Comparator may replace one or more parameters, for example:
765 # return at most 10 rows
766 mock_dao.RunQuery(StrContains('SELECT'), 10)
770 # Return some non-deterministic number of rows
771 mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
774 def equals(self
, rhs
):
775 """Special equals method that all comparators must implement.
778 rhs: any python object
781 raise NotImplementedError, 'method must be implemented by a subclass.'
783 def __eq__(self
, rhs
):
784 return self
.equals(rhs
)
786 def __ne__(self
, rhs
):
787 return not self
.equals(rhs
)
790 class IsA(Comparator
):
791 """This class wraps a basic Python type or class. It is used to verify
792 that a parameter is of the given type or class.
795 mock_dao.Connect(IsA(DbConnectInfo))
798 def __init__(self
, class_name
):
802 class_name: basic python type or a class
805 self
._class
_name
= class_name
807 def equals(self
, rhs
):
808 """Check to see if the RHS is an instance of class_name.
811 # rhs: the right hand side of the test
819 return isinstance(rhs
, self
._class
_name
)
821 # Check raw types if there was a type error. This is helpful for
822 # things like cStringIO.StringIO.
823 return type(rhs
) == type(self
._class
_name
)
826 return str(self
._class
_name
)
828 class IsAlmost(Comparator
):
829 """Comparison class used to check whether a parameter is nearly equal
830 to a given value. Generally useful for floating point numbers.
832 Example mock_dao.SetTimeout((IsAlmost(3.9)))
835 def __init__(self
, float_value
, places
=7):
836 """Initialize IsAlmost.
839 float_value: The value for making the comparison.
840 places: The number of decimal places to round to.
843 self
._float
_value
= float_value
844 self
._places
= places
846 def equals(self
, rhs
):
847 """Check to see if RHS is almost equal to float_value
850 rhs: the value to compare to float_value
857 return round(rhs
-self
._float
_value
, self
._places
) == 0
859 # This is probably because either float_value or rhs is not a number.
863 return str(self
._float
_value
)
865 class StrContains(Comparator
):
866 """Comparison class used to check whether a substring exists in a
867 string parameter. This can be useful in mocking a database with SQL
868 passed in as a string parameter, for example.
871 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
874 def __init__(self
, search_string
):
878 # search_string: the string you are searching for
882 self
._search
_string
= search_string
884 def equals(self
, rhs
):
885 """Check to see if the search_string is contained in the rhs string.
888 # rhs: the right hand side of the test
896 return rhs
.find(self
._search
_string
) > -1
901 return '<str containing \'%s\'>' % self
._search
_string
904 class Regex(Comparator
):
905 """Checks if a string matches a regular expression.
907 This uses a given regular expression to determine equality.
910 def __init__(self
, pattern
, flags
=0):
914 # pattern is the regular expression to search for
916 # flags passed to re.compile function as the second argument
920 self
.regex
= re
.compile(pattern
, flags
=flags
)
922 def equals(self
, rhs
):
923 """Check to see if rhs matches regular expression pattern.
929 return self
.regex
.search(rhs
) is not None
932 s
= '<regular expression \'%s\'' % self
.regex
.pattern
934 s
+= ', flags=%d' % self
.regex
.flags
939 class In(Comparator
):
940 """Checks whether an item (or key) is in a list (or dict) parameter.
943 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
946 def __init__(self
, key
):
950 # key is any thing that could be in a list or a key in a dict
955 def equals(self
, rhs
):
956 """Check to see whether key is in rhs.
965 return self
._key
in rhs
968 return '<sequence or map containing \'%s\'>' % self
._key
971 class ContainsKeyValue(Comparator
):
972 """Checks whether a key/value pair is in a dict parameter.
975 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
978 def __init__(self
, key
, value
):
982 # key: a key in a dict
983 # value: the corresponding value
989 def equals(self
, rhs
):
990 """Check whether the given key/value pair is in the rhs dict.
997 return rhs
[self
._key
] == self
._value
1002 return '<map containing the entry \'%s: %s\'>' % (self
._key
, self
._value
)
1005 class SameElementsAs(Comparator
):
1006 """Checks whether iterables contain the same elements (ignoring order).
1009 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
1012 def __init__(self
, expected_seq
):
1016 expected_seq: a sequence
1019 self
._expected
_seq
= expected_seq
1021 def equals(self
, actual_seq
):
1022 """Check to see whether actual_seq has same elements as expected_seq.
1025 actual_seq: sequence
1032 expected
= dict([(element
, None) for element
in self
._expected
_seq
])
1033 actual
= dict([(element
, None) for element
in actual_seq
])
1035 # Fall back to slower list-compare if any of the objects are unhashable.
1036 expected
= list(self
._expected
_seq
)
1037 actual
= list(actual_seq
)
1040 return expected
== actual
1043 return '<sequence with same elements as \'%s\'>' % self
._expected
_seq
1046 class And(Comparator
):
1047 """Evaluates one or more Comparators on RHS and returns an AND of the results.
1050 def __init__(self
, *args
):
1054 *args: One or more Comparator
1057 self
._comparators
= args
1059 def equals(self
, rhs
):
1060 """Checks whether all Comparators are equal to rhs.
1063 # rhs: can be anything
1069 for comparator
in self
._comparators
:
1070 if not comparator
.equals(rhs
):
1076 return '<AND %s>' % str(self
._comparators
)
1079 class Or(Comparator
):
1080 """Evaluates one or more Comparators on RHS and returns an OR of the results.
1083 def __init__(self
, *args
):
1087 *args: One or more Mox comparators
1090 self
._comparators
= args
1092 def equals(self
, rhs
):
1093 """Checks whether any Comparator is equal to rhs.
1096 # rhs: can be anything
1102 for comparator
in self
._comparators
:
1103 if comparator
.equals(rhs
):
1109 return '<OR %s>' % str(self
._comparators
)
1112 class Func(Comparator
):
1113 """Call a function that should verify the parameter passed in is correct.
1115 You may need the ability to perform more advanced operations on the parameter
1116 in order to validate it. You can use this to have a callable validate any
1117 parameter. The callable should return either True or False.
1122 def myParamValidator(param):
1123 # Advanced logic here
1126 mock_dao.DoSomething(Func(myParamValidator), true)
1129 def __init__(self
, func
):
1133 func: callable that takes one parameter and returns a bool
1138 def equals(self
, rhs
):
1139 """Test whether rhs passes the function test.
1141 rhs is passed into func.
1144 rhs: any python object
1147 the result of func(rhs)
1150 return self
._func
(rhs
)
1153 return str(self
._func
)
1156 class IgnoreArg(Comparator
):
1157 """Ignore an argument.
1159 This can be used when we don't care about an argument of a method call.
1162 # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
1163 mymock.CastMagic(3, IgnoreArg(), 'disappear')
1166 def equals(self
, unused_rhs
):
1167 """Ignores arguments and returns True.
1170 unused_rhs: any python object
1179 return '<IgnoreArg>'
1182 class MethodGroup(object):
1183 """Base class containing common behaviour for MethodGroups."""
1185 def __init__(self
, group_name
):
1186 self
._group
_name
= group_name
1188 def group_name(self
):
1189 return self
._group
_name
1192 return '<%s "%s">' % (self
.__class
__.__name
__, self
._group
_name
)
1194 def AddMethod(self
, mock_method
):
1195 raise NotImplementedError
1197 def MethodCalled(self
, mock_method
):
1198 raise NotImplementedError
1200 def IsSatisfied(self
):
1201 raise NotImplementedError
1203 class UnorderedGroup(MethodGroup
):
1204 """UnorderedGroup holds a set of method calls that may occur in any order.
1206 This construct is helpful for non-deterministic events, such as iterating
1207 over the keys of a dict.
1210 def __init__(self
, group_name
):
1211 super(UnorderedGroup
, self
).__init
__(group_name
)
1214 def AddMethod(self
, mock_method
):
1215 """Add a method to this group.
1218 mock_method: A mock method to be added to this group.
1221 self
._methods
.append(mock_method
)
1223 def MethodCalled(self
, mock_method
):
1224 """Remove a method call from the group.
1226 If the method is not in the set, an UnexpectedMethodCallError will be
1230 mock_method: a mock method that should be equal to a method in the group.
1233 The mock method from the group
1236 UnexpectedMethodCallError if the mock_method was not in the group.
1239 # Check to see if this method exists, and if so, remove it from the set
1241 for method
in self
._methods
:
1242 if method
== mock_method
:
1243 # Remove the called mock_method instead of the method in the group.
1244 # The called method will match any comparators when equality is checked
1245 # during removal. The method in the group could pass a comparator to
1246 # another comparator during the equality check.
1247 self
._methods
.remove(mock_method
)
1249 # If this group is not empty, put it back at the head of the queue.
1250 if not self
.IsSatisfied():
1251 mock_method
._call
_queue
.appendleft(self
)
1255 raise UnexpectedMethodCallError(mock_method
, self
)
1257 def IsSatisfied(self
):
1258 """Return True if there are not any methods in this group."""
1260 return len(self
._methods
) == 0
1263 class MultipleTimesGroup(MethodGroup
):
1264 """MultipleTimesGroup holds methods that may be called any number of times.
1266 Note: Each method must be called at least once.
1268 This is helpful, if you don't know or care how many times a method is called.
1271 def __init__(self
, group_name
):
1272 super(MultipleTimesGroup
, self
).__init
__(group_name
)
1273 self
._methods
= set()
1274 self
._methods
_called
= set()
1276 def AddMethod(self
, mock_method
):
1277 """Add a method to this group.
1280 mock_method: A mock method to be added to this group.
1283 self
._methods
.add(mock_method
)
1285 def MethodCalled(self
, mock_method
):
1286 """Remove a method call from the group.
1288 If the method is not in the set, an UnexpectedMethodCallError will be
1292 mock_method: a mock method that should be equal to a method in the group.
1295 The mock method from the group
1298 UnexpectedMethodCallError if the mock_method was not in the group.
1301 # Check to see if this method exists, and if so add it to the set of
1304 for method
in self
._methods
:
1305 if method
== mock_method
:
1306 self
._methods
_called
.add(mock_method
)
1307 # Always put this group back on top of the queue, because we don't know
1309 mock_method
._call
_queue
.appendleft(self
)
1312 if self
.IsSatisfied():
1313 next_method
= mock_method
._PopNextMethod
();
1314 return next_method
, None
1316 raise UnexpectedMethodCallError(mock_method
, self
)
1318 def IsSatisfied(self
):
1319 """Return True if all methods in this group are called at least once."""
1320 # NOTE(psycho): We can't use the simple set difference here because we want
1321 # to match different parameters which are considered the same e.g. IsA(str)
1322 # and some string. This solution is O(n^2) but n should be small.
1323 tmp
= self
._methods
.copy()
1324 for called
in self
._methods
_called
:
1325 for expected
in tmp
:
1326 if called
== expected
:
1327 tmp
.remove(expected
)
1334 class MoxMetaTestBase(type):
1335 """Metaclass to add mox cleanup and verification to every test.
1337 As the mox unit testing class is being constructed (MoxTestBase or a
1338 subclass), this metaclass will modify all test functions to call the
1339 CleanUpMox method of the test class after they finish. This means that
1340 unstubbing and verifying will happen for every test with no additional code,
1341 and any failures will result in test failures as opposed to errors.
1344 def __init__(cls
, name
, bases
, d
):
1345 type.__init
__(cls
, name
, bases
, d
)
1347 # also get all the attributes from the base classes to account
1348 # for a case when test class is not the immediate child of MoxTestBase
1350 for attr_name
in dir(base
):
1351 d
[attr_name
] = getattr(base
, attr_name
)
1353 for func_name
, func
in d
.items():
1354 if func_name
.startswith('test') and callable(func
):
1355 setattr(cls
, func_name
, MoxMetaTestBase
.CleanUpTest(cls
, func
))
1358 def CleanUpTest(cls
, func
):
1359 """Adds Mox cleanup code to any MoxTestBase method.
1361 Always unsets stubs after a test. Will verify all mocks for tests that
1365 cls: MoxTestBase or subclass; the class whose test method we are altering.
1366 func: method; the method of the MoxTestBase test class we wish to alter.
1369 The modified method.
1371 def new_method(self
, *args
, **kwargs
):
1372 mox_obj
= getattr(self
, 'mox', None)
1374 if mox_obj
and isinstance(mox_obj
, Mox
):
1377 func(self
, *args
, **kwargs
)
1380 mox_obj
.UnsetStubs()
1383 new_method
.__name
__ = func
.__name
__
1384 new_method
.__doc
__ = func
.__doc
__
1385 new_method
.__module
__ = func
.__module
__
1389 class MoxTestBase(unittest
.TestCase
):
1390 """Convenience test class to make stubbing easier.
1392 Sets up a "mox" attribute which is an instance of Mox - any mox tests will
1393 want this. Also automatically unsets any stubs and verifies that all mock
1394 methods have been called at the end of each test, eliminating boilerplate
1398 __metaclass__
= MoxMetaTestBase