Implement OCSP stapling in Windows BoringSSL port.
[chromium-blink-merge.git] / third_party / protobuf / python / mox.py
blobce80ba505c93698ae91766ea9ec8d83876479d6f
1 #!/usr/bin/python2.4
3 # Copyright 2008 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 # This file is used for testing. The original is at:
18 # http://code.google.com/p/pymox/
20 """Mox, an object-mocking framework for Python.
22 Mox works in the record-replay-verify paradigm. When you first create
23 a mock object, it is in record mode. You then programmatically set
24 the expected behavior of the mock object (what methods are to be
25 called on it, with what parameters, what they should return, and in
26 what order).
28 Once you have set up the expected mock behavior, you put it in replay
29 mode. Now the mock responds to method calls just as you told it to.
30 If an unexpected method (or an expected method with unexpected
31 parameters) is called, then an exception will be raised.
33 Once you are done interacting with the mock, you need to verify that
34 all the expected interactions occurred. (Maybe your code exited
35 prematurely without calling some cleanup method!) The verify phase
36 ensures that every expected method was called; otherwise, an exception
37 will be raised.
39 Suggested usage / workflow:
41 # Create Mox factory
42 my_mox = Mox()
44 # Create a mock data access object
45 mock_dao = my_mox.CreateMock(DAOClass)
47 # Set up expected behavior
48 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
49 mock_dao.DeletePerson(person)
51 # Put mocks in replay mode
52 my_mox.ReplayAll()
54 # Inject mock object and run test
55 controller.SetDao(mock_dao)
56 controller.DeletePersonById('1')
58 # Verify all methods were called as expected
59 my_mox.VerifyAll()
60 """
62 from collections import deque
63 import re
64 import types
65 import unittest
67 import stubout
class Error(AssertionError):
  """Base exception for this module.

  Every mox error derives from this class and, through it, from
  AssertionError, so a handler for AssertionError also catches mox errors.
  """
class ExpectedMethodCallsError(Error):
  """Raised when Verify() is called before all expected methods were called."""

  def __init__(self, expected_methods):
    """Init exception.

    Args:
      # expected_methods: A sequence of MockMethod objects that should have
      #     been called but were not.
      expected_methods: [MockMethod]

    Raises:
      ValueError: if expected_methods contains no methods.
    """
    if not expected_methods:
      raise ValueError("There must be at least one expected method")
    Error.__init__(self)
    self._expected_methods = expected_methods

  def __str__(self):
    # One line per outstanding expectation, numbered from 0.
    numbered = []
    for index, method in enumerate(self._expected_methods):
      numbered.append("%3d. %s" % (index, method))
    return ("Verify: Expected methods never called:\n%s"
            % ("\n".join(numbered),))
class UnexpectedMethodCallError(Error):
  """Raised when an unexpected method is called.

  This can occur if a method is called with incorrect parameters, or out of
  the specified order.
  """

  def __init__(self, unexpected_method, expected):
    """Init exception.

    Args:
      # unexpected_method: MockMethod that was called but was not at the head
      #     of the expected_method queue.
      # expected: MockMethod or UnorderedGroup the method should have been in
      #     (MockMethod._PopNextMethod passes None when the queue is empty).
      unexpected_method: MockMethod
      expected: MockMethod or UnorderedGroup
    """
    Error.__init__(self)
    self._unexpected_method = unexpected_method
    self._expected = expected

  def __str__(self):
    template = "Unexpected method call: %s. Expecting: %s"
    return template % (self._unexpected_method, self._expected)
class UnknownMethodCallError(Error):
  """Raised if an unknown method is requested of the mock object."""

  def __init__(self, unknown_method_name):
    """Init exception.

    Args:
      # unknown_method_name: Method call that is not part of the mocked
      #     class's public interface.
      unknown_method_name: str
    """
    Error.__init__(self)
    self._unknown_method_name = unknown_method_name

  def __str__(self):
    prefix = "Method called is not a member of the object: %s"
    return prefix % self._unknown_method_name
class Mox(object):
  """Mox: a factory for creating mock objects and managing them as a set."""

  # Attribute types that StubOutWithMock replaces with a MockObject (as
  # opposed to a MockAnything).
  _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
                      types.ObjectType, types.TypeType]

  def __init__(self):
    """Initialize a new Mox with no tracked mocks and a fresh stub registry."""
    self._mock_objects = []
    self.stubs = stubout.StubOutForTesting()

  def _Track(self, new_mock):
    """Remember new_mock for ReplayAll/VerifyAll/ResetAll and return it."""
    self._mock_objects.append(new_mock)
    return new_mock

  def CreateMock(self, class_to_mock):
    """Create a new mock object.

    Args:
      # class_to_mock: the class to be mocked
      class_to_mock: class

    Returns:
      MockObject that can be used as the class_to_mock would be.
    """
    return self._Track(MockObject(class_to_mock))

  def CreateMockAnything(self):
    """Create a mock that will accept any method calls.

    This does not enforce an interface.
    """
    return self._Track(MockAnything())

  def ReplayAll(self):
    """Set all mock objects to replay mode."""
    for tracked in self._mock_objects:
      tracked._Replay()

  def VerifyAll(self):
    """Call verify on all mock objects created."""
    for tracked in self._mock_objects:
      tracked._Verify()

  def ResetAll(self):
    """Call reset on all mock objects. This does not unset stubs."""
    for tracked in self._mock_objects:
      tracked._Reset()

  def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
    """Replace a method, attribute, etc. with a Mock.

    This will replace a class or module with a MockObject, and everything else
    (method, function, etc) with a MockAnything. This can be overridden to
    always use a MockAnything by setting use_mock_anything to True.

    Args:
      obj: A Python object (class, module, instance, callable).
      attr_name: str. The name of the attribute to replace with a mock.
      use_mock_anything: bool. True if a MockAnything should be used regardless
        of the type of attribute.
    """
    current = getattr(obj, attr_name)
    if not use_mock_anything and type(current) in self._USE_MOCK_OBJECT:
      replacement = self.CreateMock(current)
    else:
      replacement = self.CreateMockAnything()
    self.stubs.Set(obj, attr_name, replacement)

  def UnsetStubs(self):
    """Restore stubs to their original state."""
    self.stubs.UnsetAll()
def Replay(*args):
  """Switch each given mock into replay mode, in argument order.

  Args:
    # args is any number of mocks to put into replay mode.
  """
  for mock_obj in args:
    mock_obj._Replay()
def Verify(*args):
  """Verify each given mock, in argument order.

  The first failing mock's exception propagates and later mocks are not
  verified.

  Args:
    # args is any number of mocks to be verified.
  """
  for mock_obj in args:
    mock_obj._Verify()
def Reset(*args):
  """Reset each given mock, in argument order.

  Args:
    # args is any number of mocks to be reset.
  """
  for mock_obj in args:
    mock_obj._Reset()
class MockAnything:
  """A mock that can be used to mock anything.

  Any attribute that is not found normally is turned into a MockMethod by
  __getattr__, so no interface is enforced. This is helpful for mocking
  classes that do not provide a public interface.
  """

  def __init__(self):
    """Start in record mode with an empty expected-call queue."""
    self._Reset()

  def __getattr__(self, method_name):
    """Intercept method calls on this object.

    A new MockMethod is returned that is aware of the MockAnything's
    state (record or replay). The call will be recorded or replayed
    by the MockMethod's __call__.

    Args:
      # method name: the name of the method being called.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """
    return self._CreateMockMethod(method_name)

  def _CreateMockMethod(self, method_name):
    """Create a new mock method call bound to this mock's queue and mode.

    Args:
      # method name: the name of the method being called.
      method_name: str

    Returns:
      A new MockMethod aware of MockAnything's state (record or replay).
    """
    queue = self._expected_calls_queue
    return MockMethod(method_name, queue, self._replay_mode)

  def __nonzero__(self):
    """Return 1 for nonzero so the mock can be used as a conditional."""
    return 1

  def __eq__(self, rhs):
    """Equal when rhs is a MockAnything in the same mode with an equal queue."""
    if not isinstance(rhs, MockAnything):
      return False
    return (self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __ne__(self, rhs):
    """Negation of __eq__ (dispatched through the == operator)."""
    return not (self == rhs)

  def _Replay(self):
    """Start replaying expected method calls."""
    self._replay_mode = True

  def _Verify(self):
    """Verify that all of the expected calls have been made.

    Raises:
      ExpectedMethodCallsError: if there are still more method calls in the
        expected queue.
    """
    pending = self._expected_calls_queue
    if not pending:
      return

    # The last MultipleTimesGroup is not popped from the queue, so a lone,
    # satisfied MultipleTimesGroup does not count as an outstanding call.
    only_satisfied_group = (len(pending) == 1 and
                            isinstance(pending[0], MultipleTimesGroup) and
                            pending[0].IsSatisfied())
    if not only_satisfied_group:
      raise ExpectedMethodCallsError(pending)

  def _Reset(self):
    """Reset the state of this mock to record mode with an empty queue."""
    # Method calls we are expecting, in order.
    self._expected_calls_queue = deque()
    # Setup (record) mode, not replay mode.
    self._replay_mode = False
class MockObject(MockAnything, object):
  """A mock object that simulates the public/protected interface of a class."""

  def __init__(self, class_to_mock):
    """Initialize a mock object.

    This determines the methods and properties of the class and stores them.

    Args:
      # class_to_mock: class to be mocked
      class_to_mock: class
    """
    # This is used to hack around the mixin/inheritance of MockAnything, which
    # is not a proper object (it can be anything. :-)
    MockAnything.__dict__['__init__'](self)

    # Split every name dir() reports for class_to_mock (which includes names
    # inherited from base classes) into callables, which get mocked, and plain
    # values, which __getattr__ passes through from the real class.
    self._known_methods = set()
    self._known_vars = set()
    self._class_to_mock = class_to_mock
    for method in dir(class_to_mock):
      if callable(getattr(class_to_mock, method)):
        self._known_methods.add(method)
      else:
        self._known_vars.add(method)

  def __getattr__(self, name):
    """Intercept attribute request on this object.

    If the attribute is a public class variable, it will be returned and not
    recorded as a call.

    If the attribute is not a variable, it is handled like a method
    call. The method name is checked against the set of mockable
    methods, and a new MockMethod is returned that is aware of the
    MockObject's state (record or replay). The call will be recorded
    or replayed by the MockMethod's __call__.

    Args:
      # name: the name of the attribute being requested.
      name: str

    Returns:
      Either a class variable or a new MockMethod that is aware of the state
      of the mock (record or replay).

    Raises:
      UnknownMethodCallError if the MockObject does not mock the requested
      method.
    """
    if name in self._known_vars:
      return getattr(self._class_to_mock, name)

    if name in self._known_methods:
      return self._CreateMockMethod(name)

    raise UnknownMethodCallError(name)

  def __eq__(self, rhs):
    """Equal when rhs mocks the same class, in the same mode, with an equal
    expected-call queue."""
    return (isinstance(rhs, MockObject) and
            self._class_to_mock == rhs._class_to_mock and
            self._replay_mode == rhs._replay_mode and
            self._expected_calls_queue == rhs._expected_calls_queue)

  def __setitem__(self, key, value):
    """Provide custom logic for mocking classes that support item assignment.

    Args:
      key: Key to set the value for.
      value: Value to set.

    Returns:
      Expected return value in replay mode. A MockMethod object for the
      __setitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class does not support item assignment.
      UnexpectedMethodCallError if the object does not expect the call to
        __setitem__.
    """
    # Verify the class supports item assignment. _known_methods comes from
    # dir(), so a __setitem__ inherited from a base class is accepted too;
    # looking only at class_to_mock.__dict__ would wrongly reject it.
    if '__setitem__' not in self._known_methods:
      raise TypeError('object does not support item assignment')

    # One MockMethod serves both modes: in record mode it queues the expected
    # call and returns itself; in replay mode it verifies the call against
    # the queue and returns the recorded value.
    return self._CreateMockMethod('__setitem__')(key, value)

  def __getitem__(self, key):
    """Provide custom logic for mocking classes that are subscriptable.

    Args:
      key: Key to return the value for.

    Returns:
      Expected return value in replay mode. A MockMethod object for the
      __getitem__ method that has already been called if not in replay mode.

    Raises:
      TypeError if the underlying class is not subscriptable.
      UnexpectedMethodCallError if the object does not expect the call to
        __getitem__.
    """
    # Verify the class is subscriptable (directly or via a base class).
    if '__getitem__' not in self._known_methods:
      raise TypeError('unsubscriptable object')

    # Records in record mode, verifies and returns the value in replay mode.
    return self._CreateMockMethod('__getitem__')(key)

  def __call__(self, *params, **named_params):
    """Provide custom logic for mocking classes that are callable.

    Raises:
      TypeError if the mocked class (or one of its bases) defines no __call__.
    """
    # Verify the class we are mocking is callable.
    if '__call__' not in self._known_methods:
      raise TypeError('Not callable')

    # Because the call is happening directly on this object instead of a
    # method, the call on the mock method is made right here.
    mock_method = self._CreateMockMethod('__call__')
    return mock_method(*params, **named_params)

  @property
  def __class__(self):
    """Return the class that is being mocked."""
    return self._class_to_mock
class MockMethod(object):
  """Callable mock method.

  A MockMethod should act exactly like the method it mocks, accepting
  parameters and returning a value, or throwing an exception (as specified).
  In record mode a call appends the MockMethod to its owner's expected-call
  queue; in replay mode a call is checked against the head of that queue.
  """

  def __init__(self, method_name, call_queue, replay_mode):
    """Construct a new mock method.

    Args:
      # method_name: the name of the method
      # call_queue: deque of calls, verify this call against the head, or add
      #     this call to the queue. Any other iterable is copied into a new
      #     deque, which is then no longer shared with the caller.
      # replay_mode: False if we are recording, True if we are verifying calls
      #     against the call queue.
      method_name: str
      call_queue: list or deque
      replay_mode: bool
    """
    self._name = method_name
    if isinstance(call_queue, deque):
      self._call_queue = call_queue
    else:
      self._call_queue = deque(call_queue)
    self._replay_mode = replay_mode

    # Filled in by __call__ (the arguments) and by AndReturn / AndRaise /
    # WithSideEffects (the configured response).
    self._params = None
    self._named_params = None
    self._return_value = None
    self._exception = None
    self._side_effects = None

  def __call__(self, *params, **named_params):
    """Log parameters and return the specified return value.

    In record mode this MockMethod is appended to the expected-call queue and
    returned, so expectations can be chained onto it. In replay mode the next
    expected call is popped and compared with this one; then the expected
    call's side effects run, and its exception is raised or its return value
    is returned.

    Raises:
      UnexpectedMethodCallError if this call is supposed to match an expected
      method call and it does not.
    """
    self._params = params
    self._named_params = named_params

    if self._replay_mode:
      expected_method = self._VerifyMethodCall()
      side_effects = expected_method._side_effects
      if side_effects:
        side_effects(*params, **named_params)
      if expected_method._exception:
        raise expected_method._exception
      return expected_method._return_value

    # Record mode: queue this call and hand back self for chaining.
    self._call_queue.append(self)
    return self

  def __getattr__(self, name):
    """Raise an AttributeError with a helpful message."""
    raise AttributeError('MockMethod has no attribute "%s". '
                         'Did you remember to put your mocks in replay mode?'
                         % name)

  def _PopNextMethod(self):
    """Pop the next method from our call queue.

    Raises:
      UnexpectedMethodCallError (with expected=None) if the queue is empty.
    """
    if not self._call_queue:
      raise UnexpectedMethodCallError(self, None)
    return self._call_queue.popleft()

  def _VerifyMethodCall(self):
    """Verify the called method is expected.

    This can be an ordered method, or part of an unordered set.

    Returns:
      The expected mock method.

    Raises:
      UnexpectedMethodCallError if the method called was not expected.
    """
    expected = self._PopNextMethod()

    # A group may hand back another group to try next, so keep going until a
    # group yields the matching method or a plain MockMethod is reached.
    while isinstance(expected, MethodGroup):
      expected, matched = expected.MethodCalled(self)
      if matched is not None:
        return matched

    # A plain mock method: it must compare equal to this call.
    if expected != self:
      raise UnexpectedMethodCallError(self, expected)
    return expected

  def __str__(self):
    positional = [repr(p) for p in self._params or []]
    keyword = ['%s=%r' % pair
               for pair in sorted((self._named_params or {}).items())]
    rendered_args = ', '.join(positional + keyword)
    return "%s(%s) -> %r" % (self._name, rendered_args, self._return_value)

  def __eq__(self, rhs):
    """Test whether this MockMethod is equivalent to another MockMethod.

    Two mock methods are equivalent when they share a name and were called
    with equal positional and keyword arguments.

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """
    if not isinstance(rhs, MockMethod):
      return False
    return (self._name == rhs._name and
            self._params == rhs._params and
            self._named_params == rhs._named_params)

  def __ne__(self, rhs):
    """Test whether this MockMethod is not equivalent to another MockMethod.

    Args:
      # rhs: the right hand side of the test
      rhs: MockMethod
    """
    return not (self == rhs)

  def GetPossibleGroup(self):
    """Detach this method from the queue tail and return the new tail.

    Returns:
      The element now at the end of the call queue (possibly a group), or
      None if the queue is empty.
    """
    # Remove this method from the tail of the queue so we can add it to a group.
    this_method = self._call_queue.pop()
    assert this_method == self

    if self._call_queue:
      return self._call_queue[-1]
    return None

  def _CheckAndCreateNewGroup(self, group_name, group_class):
    """Add this method to a trailing group_class group named group_name.

    If the queue tail is not such a group, a new one is created, this method
    is added to it, and it is appended to the queue.

    Args:
      group_name: the name of the group.
      group_class: the class used to create instance of this new group

    Returns:
      self
    """
    group = self.GetPossibleGroup()
    reuse_existing = (isinstance(group, group_class) and
                      group.group_name() == group_name)
    if not reuse_existing:
      group = group_class(group_name)
    group.AddMethod(self)
    if not reuse_existing:
      self._call_queue.append(group)
    return self

  def InAnyOrder(self, group_name="default"):
    """Move this method into a group of unordered calls.

    A group of unordered calls must be defined together, and must be executed
    in full before the next expected method can be called. There can be
    multiple groups that are expected serially, if they are given
    different group names. The same group name can be reused if there is a
    standard method call, or a group with a different name, spliced between
    usages.

    Args:
      group_name: the name of the unordered group.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

  def MultipleTimes(self, group_name="default"):
    """Move this method into group of calls which may be called multiple times.

    A group of repeating calls must be defined together, and must be executed
    in full before the next expected method can be called.

    Args:
      group_name: the name of the group of repeating calls.

    Returns:
      self
    """
    return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

  def AndReturn(self, return_value):
    """Set the value to return when this method is called.

    Args:
      # return_value can be anything.

    Returns:
      return_value itself (not self).
    """
    self._return_value = return_value
    return return_value

  def AndRaise(self, exception):
    """Set the exception to raise when this method is called.

    Args:
      # exception: the exception to raise when this method is called.
      exception: Exception
    """
    self._exception = exception

  def WithSideEffects(self, side_effects):
    """Set the side effects that are simulated when this method is called.

    Args:
      side_effects: A callable which modifies the parameters or other relevant
        state which a given test case depends on. It receives the same
        arguments as the mocked call.

    Returns:
      Self for chaining with AndReturn and AndRaise.
    """
    self._side_effects = side_effects
    return self
class Comparator:
  """Base class for all Mox comparators.

  A Comparator can be used as a parameter to a mocked method when the exact
  value is not known. For example, the code you are testing might build up a
  long SQL string that is passed to your mock DAO. You're only interested that
  the IN clause contains the proper primary keys, so you can set your mock
  up as follows:

  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

  Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.

  A Comparator may replace one or more parameters, for example:
  # return at most 10 rows
  mock_dao.RunQuery(StrContains('SELECT'), 10)

  or

  # Return some non-deterministic number of rows
  mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

  Subclasses implement equals(); == and != delegate to it.
  """

  def equals(self, rhs):
    """Special equals method that all comparators must implement.

    Args:
      rhs: any python object

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    # Call form rather than the Python-2-only "raise E, msg" statement, which
    # is a syntax error under Python 3; the raised exception is identical.
    raise NotImplementedError('method must be implemented by a subclass.')

  def __eq__(self, rhs):
    return self.equals(rhs)

  def __ne__(self, rhs):
    return not self.equals(rhs)
class IsA(Comparator):
  """Matches a parameter that is an instance of a given type or class.

  Example:
  mock_dao.Connect(IsA(DbConnectInfo))
  """

  def __init__(self, class_name):
    """Initialize IsA.

    Args:
      class_name: basic python type or a class
    """
    self._class_name = class_name

  def equals(self, rhs):
    """Check to see if the RHS is an instance of class_name.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool
    """
    try:
      matched = isinstance(rhs, self._class_name)
    except TypeError:
      # isinstance rejected class_name, so compare raw types instead. This is
      # helpful for things like cStringIO.StringIO.
      matched = type(rhs) == type(self._class_name)
    return matched

  def __repr__(self):
    return str(self._class_name)
class IsAlmost(Comparator):
  """Matches a number that rounds to the same value as a reference number.

  Generally useful for floating point numbers.

  Example mock_dao.SetTimeout((IsAlmost(3.9)))
  """

  def __init__(self, float_value, places=7):
    """Initialize IsAlmost.

    Args:
      float_value: The value for making the comparison.
      places: The number of decimal places to round to.
    """
    self._float_value = float_value
    self._places = places

  def equals(self, rhs):
    """Check whether rhs - float_value rounds to zero at `places` decimals.

    Args:
      rhs: the value to compare to float_value

    Returns:
      bool; False when the subtraction or rounding raises TypeError (for
      example because rhs is not a number).
    """
    try:
      difference = rhs - self._float_value
      return round(difference, self._places) == 0
    except TypeError:
      return False

  def __repr__(self):
    return str(self._float_value)
class StrContains(Comparator):
  """Matches a string parameter that contains a given substring.

  This can be useful in mocking a database with SQL passed in as a string
  parameter, for example.

  Example:
  mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
  """

  def __init__(self, search_string):
    """Initialize.

    Args:
      # search_string: the string you are searching for
      search_string: str
    """
    self._search_string = search_string

  def equals(self, rhs):
    """Check to see if the search_string is contained in the rhs string.

    Args:
      # rhs: the right hand side of the test
      rhs: object

    Returns:
      bool; False if rhs has no usable find() method or the check raises.
    """
    try:
      found = rhs.find(self._search_string) > -1
    except Exception:
      found = False
    return found

  def __repr__(self):
    return '<str containing \'%s\'>' % self._search_string
class Regex(Comparator):
  """Checks if a string matches a regular expression.

  This uses a given regular expression to determine equality. Values that
  cannot be searched (non-strings) never match.
  """

  def __init__(self, pattern, flags=0):
    """Initialize.

    Args:
      # pattern is the regular expression to search for
      pattern: str
      # flags passed to re.compile function as the second argument
      flags: int
    """
    self.regex = re.compile(pattern, flags=flags)

  def equals(self, rhs):
    """Check to see if rhs matches regular expression pattern.

    Uses search(), so the pattern may match anywhere in rhs.

    Args:
      rhs: the value passed to the mocked method.

    Returns:
      bool; False when re rejects rhs with a TypeError (e.g. rhs is None or a
      number).
    """
    try:
      return self.regex.search(rhs) is not None
    except TypeError:
      # Like StrContains and ContainsKeyValue, treat an argument that cannot
      # be searched as a mismatch instead of letting the error escape from
      # MockMethod's parameter comparison.
      return False

  def __repr__(self):
    s = '<regular expression \'%s\'' % self.regex.pattern
    if self.regex.flags:
      s += ', flags=%d' % self.regex.flags
    s += '>'
    return s
class In(Comparator):
  """Checks whether an item (or key) is in a list (or dict) parameter.

  Example:
  mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
  """

  def __init__(self, key):
    """Initialize.

    Args:
      # key is any thing that could be in a list or a key in a dict
    """
    self._key = key

  def equals(self, rhs):
    """Check to see whether key is in rhs.

    Args:
      rhs: dict or other container

    Returns:
      bool; False when the membership test raises TypeError (rhs is not a
      container, or key is unhashable for a dict/set rhs).
    """
    try:
      return self._key in rhs
    except TypeError:
      return False

  def __repr__(self):
    # Wrap the key in a 1-tuple so a tuple key is formatted as one value
    # instead of being taken as the %-argument list (which raised TypeError).
    return '<sequence or map containing \'%s\'>' % (self._key,)
class ContainsKeyValue(Comparator):
  """Matches a mapping parameter that holds a given key with a given value.

  Example:
  mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
  """

  def __init__(self, key, value):
    """Initialize.

    Args:
      # key: a key in a dict
      # value: the corresponding value
    """
    self._key = key
    self._value = value

  def equals(self, rhs):
    """Check whether rhs[key] equals value.

    Returns:
      bool; False if the lookup or the comparison raises any exception.
    """
    try:
      matches = rhs[self._key] == self._value
    except Exception:
      matches = False
    return matches

  def __repr__(self):
    return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
class SameElementsAs(Comparator):
  """Checks whether iterables contain the same elements (ignoring order).

  Example:
  mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
  """

  def __init__(self, expected_seq):
    """Initialize.

    Args:
      expected_seq: a sequence
    """
    self._expected_seq = expected_seq

  def equals(self, actual_seq):
    """Check to see whether actual_seq has same elements as expected_seq.

    Args:
      actual_seq: sequence

    Returns:
      bool
    """
    try:
      # NOTE: when all elements are hashable only the distinct elements are
      # compared, so duplicates are ignored ([1, 1, 2] matches [1, 2]).
      expected = set(self._expected_seq)
      actual = set(actual_seq)
    except TypeError:
      # Fall back to slower list-compare if any of the objects are unhashable.
      # This path does take duplicate elements into account.
      expected = sorted(self._expected_seq)
      actual = sorted(actual_seq)
    return expected == actual

  def __repr__(self):
    # Wrap in a 1-tuple so a tuple expected_seq is printed whole instead of
    # being taken as the %-argument list (which raised TypeError).
    return '<sequence with same elements as \'%s\'>' % (self._expected_seq,)
class And(Comparator):
  """Matches when every wrapped Comparator matches (short-circuits)."""

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Comparator
    """
    self._comparators = args

  def equals(self, rhs):
    """Checks whether all Comparators are equal to rhs.

    Stops at the first comparator that does not match.

    Args:
      # rhs: can be anything

    Returns:
      bool
    """
    for comparator in self._comparators:
      if not comparator.equals(rhs):
        break
    else:
      return True
    return False

  def __repr__(self):
    return '<AND %s>' % (self._comparators,)
class Or(Comparator):
  """Matches when any wrapped Comparator matches (short-circuits)."""

  def __init__(self, *args):
    """Initialize.

    Args:
      *args: One or more Mox comparators
    """
    self._comparators = args

  def equals(self, rhs):
    """Checks whether any Comparator is equal to rhs.

    Stops at the first comparator that matches.

    Args:
      # rhs: can be anything

    Returns:
      bool
    """
    for comparator in self._comparators:
      if comparator.equals(rhs):
        break
    else:
      return False
    return True

  def __repr__(self):
    return '<OR %s>' % (self._comparators,)
class Func(Comparator):
  """Call a function that should verify the parameter passed in is correct.

  You may need the ability to perform more advanced operations on the
  parameter in order to validate it. You can use this to have a callable
  validate any parameter. The callable should return either True or False.

  Example:

  def myParamValidator(param):
    # Advanced logic here
    return True

  mock_dao.DoSomething(Func(myParamValidator), True)
  """

  def __init__(self, func):
    """Initialize.

    Args:
      func: callable that takes one parameter and returns a bool
    """
    self._func = func

  def equals(self, rhs):
    """Test whether rhs passes the function test.

    Args:
      rhs: any python object

    Returns:
      the result of func(rhs), unchanged
    """
    validator = self._func
    return validator(rhs)

  def __repr__(self):
    return str(self._func)
class IgnoreArg(Comparator):
  """Matches any argument value.

  Use it in positions whose value the test does not care about.

  Example:
  # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
  mymock.CastMagic(3, IgnoreArg(), 'disappear')
  """

  def equals(self, unused_rhs):
    """Accept unused_rhs unconditionally.

    Args:
      unused_rhs: any python object

    Returns:
      True, always.
    """
    return True

  def __repr__(self):
    return '<IgnoreArg>'
class MethodGroup(object):
  """Base class containing common behaviour for MethodGroups.

  Subclasses must implement AddMethod, MethodCalled and IsSatisfied.
  """

  def __init__(self, group_name):
    self._group_name = group_name

  def group_name(self):
    """Return the name this group was created with."""
    return self._group_name

  def __str__(self):
    kind = self.__class__.__name__
    return '<%s "%s">' % (kind, self._group_name)

  def AddMethod(self, mock_method):
    """Abstract: add an expected mock method to the group."""
    raise NotImplementedError

  def MethodCalled(self, mock_method):
    """Abstract: match an actual call against the group."""
    raise NotImplementedError

  def IsSatisfied(self):
    """Abstract: report whether the group's expectations are met."""
    raise NotImplementedError
class UnorderedGroup(MethodGroup):
  """A group of expected method calls that may arrive in any order.

  This construct is helpful for non-deterministic events, such as iterating
  over the keys of a dict.
  """

  def __init__(self, group_name):
    super(UnorderedGroup, self).__init__(group_name)
    # Expected mock methods that have not been called yet, in insertion order.
    self._methods = []

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    self._methods.append(mock_method)

  def MethodCalled(self, mock_method):
    """Match an actual call against the group and consume the expectation.

    If no method in the group matches, an UnexpectedMethodCallError is raised.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      A (group, method) pair: this group and the matching method from it.

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group.
    """
    for candidate in self._methods:
      if not candidate == mock_method:
        continue

      # Remove by the called mock_method rather than the stored candidate:
      # the called method carries concrete arguments, so the equality checks
      # made during removal match it against any comparators, whereas the
      # stored method could end up passing a comparator to another comparator.
      self._methods.remove(mock_method)

      # Calls are still outstanding, so put the group back at the head of
      # the queue.
      if not self.IsSatisfied():
        mock_method._call_queue.appendleft(self)

      return self, candidate

    raise UnexpectedMethodCallError(mock_method, self)

  def IsSatisfied(self):
    """Return True if there are not any methods in this group."""
    return not self._methods
class MultipleTimesGroup(MethodGroup):
  """MultipleTimesGroup holds methods that may be called any number of times.

  Note: Each method must be called at least once.

  This is helpful, if you don't know or care how many times a method is called.
  """

  def __init__(self, group_name):
    super(MultipleTimesGroup, self).__init__(group_name)
    # Expected methods of this group. Stored in a set, so mock methods must
    # be hashable.
    self._methods = set()
    # Calls that matched some expected method of this group.
    self._methods_called = set()

  def AddMethod(self, mock_method):
    """Add a method to this group.

    Args:
      mock_method: A mock method to be added to this group.
    """
    self._methods.add(mock_method)

  def MethodCalled(self, mock_method):
    """Record a call against the group.

    If the call equals a method in the group, it is recorded as called. This
    group is then pushed back onto the head of mock_method's call queue,
    because it may be called again. If the call matches nothing but the group
    is already satisfied, the group steps aside and the call is handed on via
    mock_method._PopNextMethod(). Otherwise an UnexpectedMethodCallError is
    raised.

    Args:
      mock_method: a mock method that should be equal to a method in the group.

    Returns:
      (self, matching expected method) when the call matched this group, or
      (mock_method._PopNextMethod(), None) when the group was already
      satisfied. NOTE(review): the latter is presumably the next expected
      call in the queue -- confirm against _PopNextMethod.

    Raises:
      UnexpectedMethodCallError if the mock_method was not in the group and
      the group is not yet satisfied.
    """
    # Check to see if this method exists, and if so add it to the set of
    # called methods.
    for method in self._methods:
      if method == mock_method:
        self._methods_called.add(mock_method)
        # Always put this group back on top of the queue, because we don't know
        # when we are done.
        mock_method._call_queue.appendleft(self)
        return self, method

    if not self.IsSatisfied():
      raise UnexpectedMethodCallError(mock_method, self)

    next_method = mock_method._PopNextMethod()
    return next_method, None

  def IsSatisfied(self):
    """Return True if all methods in this group are called at least once.

    A group with no methods is vacuously satisfied.
    """
    # NOTE(psycho): We can't use the simple set difference here because we want
    # to match different parameters which are considered the same e.g. IsA(str)
    # and some string. This solution is O(n^2) but n should be small.
    unmatched = self._methods.copy()
    if not unmatched:
      return True
    for called in self._methods_called:
      for expected in unmatched:
        if called == expected:
          # Safe to mutate the set: we break out of its iteration right away.
          unmatched.remove(expected)
          if not unmatched:
            return True
          break
    return False
class MoxMetaTestBase(type):
  """Metaclass to add mox cleanup and verification to every test.

  As the mox unit testing class is being constructed (MoxTestBase or a
  subclass), this metaclass wraps every callable attribute whose name starts
  with 'test'. After the test body runs, the wrapper unsets the stubs of the
  test's Mox instance (if it has one) and verifies all of its mocks. This
  means that unstubbing and verifying will happen for every test with no
  additional code, and any failures will result in test failures as opposed
  to errors.
  """

  def __init__(cls, name, bases, d):
    type.__init__(cls, name, bases, d)

    # Also get all the attributes from the base classes to account for a
    # case when the test class is not the immediate child of MoxTestBase.
    # Names defined by the class being created take precedence. Copying a
    # base attribute over one of them would replace the subclass's override
    # with the base version when the wrapped methods are re-set below.
    for base in bases:
      for attr_name in dir(base):
        if attr_name not in d:
          d[attr_name] = getattr(base, attr_name)

    for func_name, func in d.items():
      if func_name.startswith('test') and callable(func):
        setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

  @staticmethod
  def CleanUpTest(cls, func):
    """Adds Mox cleanup code to any MoxTestBase method.

    Always unsets stubs after a test. Will verify all mocks for tests that
    otherwise pass.

    Args:
      cls: MoxTestBase or subclass; the class whose test method we are altering.
        (Not used by the wrapper itself.)
      func: method; the method of the MoxTestBase test class we wish to alter.

    Returns:
      The modified method.
    """
    def new_method(self, *args, **kwargs):
      # Cleanup only applies when the test instance carries a Mox object in
      # its 'mox' attribute (MoxTestBase.setUp creates one).
      mox_obj = getattr(self, 'mox', None)
      cleanup_mox = False
      if mox_obj and isinstance(mox_obj, Mox):
        cleanup_mox = True
      try:
        func(self, *args, **kwargs)
      finally:
        # Stubs are removed even when the test raised.
        if cleanup_mox:
          mox_obj.UnsetStubs()
      # Reached only when the test body returned normally, so a verification
      # failure never masks the test's own exception.
      if cleanup_mox:
        mox_obj.VerifyAll()
    # Copy identifying metadata by hand (functools.wraps needs Python 2.5+).
    new_method.__name__ = func.__name__
    new_method.__doc__ = func.__doc__
    new_method.__module__ = func.__module__
    return new_method
1389 class MoxTestBase(unittest.TestCase):
1390 """Convenience test class to make stubbing easier.
1392 Sets up a "mox" attribute which is an instance of Mox - any mox tests will
1393 want this. Also automatically unsets any stubs and verifies that all mock
1394 methods have been called at the end of each test, eliminating boilerplate
1395 code.
1398 __metaclass__ = MoxMetaTestBase
1400 def setUp(self):
1401 self.mox = Mox()