1 """Tests for consumer handling of association responses
3 This duplicates some things that are covered by test_consumer, but
6 from openid
import oidutil
7 from openid
.test
.test_consumer
import CatchLogs
8 from openid
.message
import Message
, OPENID2_NS
, OPENID_NS
, no_default
9 from openid
.server
.server
import DiffieHellmanSHA1ServerSession
10 from openid
.consumer
.consumer
import GenericConsumer
, \
11 DiffieHellmanSHA1ConsumerSession
, ProtocolError
12 from openid
.consumer
.discover
import OpenIDServiceEndpoint
, OPENID_1_1_TYPE
, OPENID_2_0_TYPE
13 from openid
.store
import memstore
16 # Some values we can use for convenience (see mkAssocResponse)
17 association_response_values
= {
19 'assoc_handle':'a handle',
20 'assoc_type':'a type',
21 'session_type':'a session type',
def mkAssocResponse(*keys):
    """Return an association-response Message holding only `keys`.

    Each requested key is looked up in `association_response_values`
    and copied, with its canned value, into the message arguments.
    Handy for missing-field tests and anywhere else the concrete values
    are irrelevant.  An unknown key raises KeyError from the lookup."""
    response_args = {}
    for field_name in keys:
        response_args[field_name] = association_response_values[field_name]
    return Message.fromOpenIDArgs(response_args)
35 class BaseAssocTest(CatchLogs
, unittest
.TestCase
):
38 self
.store
= memstore
.MemoryStore()
39 self
.consumer
= GenericConsumer(self
.store
)
40 self
.endpoint
= OpenIDServiceEndpoint()
42 def failUnlessProtocolError(self
, str_prefix
, func
, *args
, **kwargs
):
44 result
= func(*args
, **kwargs
)
45 except ProtocolError
, e
:
46 message
= 'Expected prefix %r, got %r' % (str_prefix
, e
[0])
47 self
.failUnless(e
[0].startswith(str_prefix
), message
)
49 self
.fail('Expected ProtocolError, got %r' % (result
,))
51 def mkExtractAssocMissingTest(keys
):
52 """Factory function for creating test methods for generating
55 Make a test that ensures that an association response that
56 is missing required fields will short-circuit return None.
58 According to 'Association Session Response' subsection 'Common
59 Response Parameters', the following fields are required for OpenID
68 If 'ns' is missing, it will fall back to OpenID 1 checking. In
69 OpenID 1, everything except 'session_type' and 'ns' are required.
73 msg
= mkAssocResponse(*keys
)
75 self
.failUnlessRaises(KeyError,
76 self
.consumer
._extractAssociation
, msg
, None)
class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest):
    """Missing-field checks for OpenID 2 association responses.

    Every field set below includes 'ns', so the response is handled as
    OpenID 2.  Each generated test leaves out one or more of the other
    required fields and expects association extraction to raise
    KeyError (see mkExtractAssocMissingTest)."""

    # Only the namespace is present; everything else is missing.
    test_noFields_openid2 = mkExtractAssocMissingTest(['ns'])

    # No 'expires_in'.
    test_missingExpires_openid2 = mkExtractAssocMissingTest([
        'assoc_handle', 'assoc_type', 'session_type', 'ns',
    ])

    # No 'assoc_handle'.
    test_missingHandle_openid2 = mkExtractAssocMissingTest([
        'expires_in', 'assoc_type', 'session_type', 'ns',
    ])

    # No 'assoc_type'.
    test_missingAssocType_openid2 = mkExtractAssocMissingTest([
        'expires_in', 'assoc_handle', 'session_type', 'ns',
    ])

    # No 'session_type'.
    test_missingSessionType_openid2 = mkExtractAssocMissingTest([
        'expires_in', 'assoc_handle', 'assoc_type', 'ns',
    ])
class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest):
    """Test that association responses for OpenID 1 raise an error when
    required fields are missing.

    None of the field sets below include 'ns', so extraction falls back
    to OpenID 1 checking, where every field except 'session_type' and
    'ns' is required."""

    # No fields at all.
    test_noFields_openid1 = mkExtractAssocMissingTest([])

    # Missing 'expires_in'.
    test_missingExpires_openid1 = mkExtractAssocMissingTest(
        ['assoc_handle', 'assoc_type'])

    # Missing 'assoc_handle'.
    test_missingHandle_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_type'])

    # Missing 'assoc_type'.
    test_missingAssocType_openid1 = mkExtractAssocMissingTest(
        ['expires_in', 'assoc_handle'])
class DummyAssocationSession(object):
    """Minimal stand-in for a consumer association session.

    Only records the session type that was requested and the
    association types it claims to allow; it performs no crypto."""

    def __init__(self, session_type, allowed_assoc_types=()):
        # Store both attributes exactly as given (no copying/validation).
        self.session_type, self.allowed_assoc_types = (
            session_type, allowed_assoc_types)
118 class ExtractAssociationSessionTypeMismatch(BaseAssocTest
):
119 def mkTest(requested_session_type
, response_session_type
, openid1
=False):
121 assoc_session
= DummyAssocationSession(requested_session_type
)
122 keys
= association_response_values
.keys()
125 msg
= mkAssocResponse(*keys
)
126 msg
.setArg(OPENID_NS
, 'session_type', response_session_type
)
127 self
.failUnlessProtocolError('Session type mismatch',
128 self
.consumer
._extractAssociation
, msg
, assoc_session
)
132 test_typeMismatchNoEncBlank_openid2
= mkTest(
133 requested_session_type
='no-encryption',
134 response_session_type
='',
137 test_typeMismatchDHSHA1NoEnc_openid2
= mkTest(
138 requested_session_type
='DH-SHA1',
139 response_session_type
='no-encryption',
142 test_typeMismatchDHSHA256NoEnc_openid2
= mkTest(
143 requested_session_type
='DH-SHA256',
144 response_session_type
='no-encryption',
147 test_typeMismatchNoEncDHSHA1_openid2
= mkTest(
148 requested_session_type
='no-encryption',
149 response_session_type
='DH-SHA1',
152 test_typeMismatchDHSHA1NoEnc_openid1
= mkTest(
153 requested_session_type
='DH-SHA1',
154 response_session_type
='DH-SHA256',
158 test_typeMismatchDHSHA256NoEnc_openid1
= mkTest(
159 requested_session_type
='DH-SHA256',
160 response_session_type
='DH-SHA1',
164 test_typeMismatchNoEncDHSHA1_openid1
= mkTest(
165 requested_session_type
='no-encryption',
166 response_session_type
='DH-SHA1',
171 class TestOpenID1AssociationResponseSessionType(BaseAssocTest
):
172 def mkTest(expected_session_type
, session_type_value
):
173 """Return a test method that will check what session type will
174 be used if the OpenID 1 response to an associate call sets the
175 'session_type' field to `session_type_value`
178 self
._doTest
(expected_session_type
, session_type_value
)
179 self
.failUnlessEqual(0, len(self
.messages
))
183 def _doTest(self
, expected_session_type
, session_type_value
):
184 # Create a Message with just 'session_type' in it, since
185 # that's all this function will use. 'session_type' may be
186 # absent if it's set to None.
188 if session_type_value
is not None:
189 args
['session_type'] = session_type_value
190 message
= Message
.fromOpenIDArgs(args
)
191 self
.failUnless(message
.isOpenID1())
193 actual_session_type
= self
.consumer
._getOpenID
1SessionType
(message
)
194 error_message
= ('Returned sesion type parameter %r was expected '
195 'to yield session type %r, but yielded %r' %
196 (session_type_value
, expected_session_type
,
197 actual_session_type
))
198 self
.failUnlessEqual(
199 expected_session_type
, actual_session_type
, error_message
)
202 session_type_value
=None,
203 expected_session_type
='no-encryption',
207 session_type_value
='',
208 expected_session_type
='no-encryption',
211 # This one's different because it expects log messages
212 def test_explicitNoEncryption(self
):
214 session_type_value
='no-encryption',
215 expected_session_type
='no-encryption',
217 self
.failUnlessEqual(1, len(self
.messages
))
218 self
.failUnless(self
.messages
[0].startswith(
219 'WARNING: OpenID server sent "no-encryption"'))
221 test_dhSHA1
= mkTest(
222 session_type_value
='DH-SHA1',
223 expected_session_type
='DH-SHA1',
226 # DH-SHA256 is not a valid session type for OpenID1, but this
227 # function does not test that. This is mostly just to make sure
228 # that it will pass-through stuff that is not explicitly handled,
229 # so it will get handled the same way as it is handled for OpenID
231 test_dhSHA256
= mkTest(
232 session_type_value
='DH-SHA256',
233 expected_session_type
='DH-SHA256',
236 class DummyAssociationSession(object):
237 secret
= "shh! don't tell!"
238 extract_secret_called
= False
242 allowed_assoc_types
= None
244 def extractSecret(self
, message
):
245 self
.extract_secret_called
= True
248 class TestInvalidFields(BaseAssocTest
):
250 BaseAssocTest
.setUp(self
)
251 self
.session_type
= 'testing-session'
        # This must be something that works for Association.fromExpiresIn
254 self
.assoc_type
= 'HMAC-SHA1'
256 self
.assoc_handle
= 'testing-assoc-handle'
258 # These arguments should all be valid
259 self
.assoc_response
= Message
.fromOpenIDArgs({
260 'expires_in': '1000',
261 'assoc_handle':self
.assoc_handle
,
262 'assoc_type':self
.assoc_type
,
263 'session_type':self
.session_type
,
267 self
.assoc_session
= DummyAssociationSession()
269 # Make the session for the response's session type
270 self
.assoc_session
.session_type
= self
.session_type
271 self
.assoc_session
.allowed_assoc_types
= [self
.assoc_type
]
273 def test_worksWithGoodFields(self
):
274 """Handle a full successful association response"""
275 assoc
= self
.consumer
._extractAssociation
(
276 self
.assoc_response
, self
.assoc_session
)
277 self
.failUnless(self
.assoc_session
.extract_secret_called
)
278 self
.failUnlessEqual(self
.assoc_session
.secret
, assoc
.secret
)
279 self
.failUnlessEqual(1000, assoc
.lifetime
)
280 self
.failUnlessEqual(self
.assoc_handle
, assoc
.handle
)
281 self
.failUnlessEqual(self
.assoc_type
, assoc
.assoc_type
)
283 def test_badAssocType(self
):
284 # Make sure that the assoc type in the response is not valid
285 # for the given session.
286 self
.assoc_session
.allowed_assoc_types
= []
287 self
.failUnlessProtocolError('Unsupported assoc_type for session',
288 self
.consumer
._extractAssociation
,
289 self
.assoc_response
, self
.assoc_session
)
291 def test_badExpiresIn(self
):
292 # Invalid value for expires_in should cause failure
293 self
.assoc_response
.setArg(OPENID_NS
, 'expires_in', 'forever')
294 self
.failUnlessProtocolError('Invalid expires_in',
295 self
.consumer
._extractAssociation
,
296 self
.assoc_response
, self
.assoc_session
)
299 # XXX: This is what causes most of the imports in this file. It is
300 # sort of a unit test and sort of a functional test. I'm not terribly
302 class TestExtractAssociationDiffieHellman(BaseAssocTest
):
306 sess
, message
= self
.consumer
._createAssociateRequest
(
307 self
.endpoint
, 'HMAC-SHA1', 'DH-SHA1')
309 # XXX: this is testing _createAssociateRequest
310 self
.failUnlessEqual(self
.endpoint
.compatibilityMode(),
313 server_sess
= DiffieHellmanSHA1ServerSession
.fromMessage(message
)
314 server_resp
= server_sess
.answer(self
.secret
)
315 server_resp
['assoc_type'] = 'HMAC-SHA1'
316 server_resp
['assoc_handle'] = 'handle'
317 server_resp
['expires_in'] = '1000'
318 server_resp
['session_type'] = 'DH-SHA1'
319 return sess
, Message
.fromOpenIDArgs(server_resp
)
321 def test_success(self
):
322 sess
, server_resp
= self
._setUpDH
()
323 ret
= self
.consumer
._extractAssociation
(server_resp
, sess
)
324 self
.failIf(ret
is None)
325 self
.failUnlessEqual(ret
.assoc_type
, 'HMAC-SHA1')
326 self
.failUnlessEqual(ret
.secret
, self
.secret
)
327 self
.failUnlessEqual(ret
.handle
, 'handle')
328 self
.failUnlessEqual(ret
.lifetime
, 1000)
330 def test_openid2success(self
):
331 # Use openid 2 type in endpoint so _setUpDH checks
332 # compatibility mode state properly
333 self
.endpoint
.type_uris
= [OPENID_2_0_TYPE
, OPENID_1_1_TYPE
]
336 def test_badDHValues(self
):
337 sess
, server_resp
= self
._setUpDH
()
338 server_resp
.setArg(OPENID_NS
, 'enc_mac_key', '\x00\x00\x00')
339 self
.failUnlessProtocolError('Malformed response for',
340 self
.consumer
._extractAssociation
, server_resp
, sess
)