1 """Tests for the attribute exchange extension module
5 from openid
.extensions
import ax
6 from openid
.message
import NamespaceMap
, Message
, OPENID2_NS
7 from openid
.consumer
.consumer
import SuccessResponse
9 class BogusAXMessage(ax
.AXMessage
):
12 getExtensionArgs
= ax
.AXMessage
._newArgs
class DummyRequest(object):
    """Minimal stand-in for an OpenID request object.

    It carries nothing but the wrapped ``message``. The tests pass it to
    ``ax.FetchRequest.fromOpenIDRequest``, which presumably reads only
    that attribute (assumption: confirm against the ax module).
    """

    def __init__(self, message):
        # Store the message unchanged; no copying or validation.
        self.message = message
18 class AXMessageTest(unittest
.TestCase
):
20 self
.bax
= BogusAXMessage()
def test_checkMode(self):
    """_checkMode rejects a missing or foreign mode and accepts its own."""
    validate = self.bax._checkMode
    # An empty dict (no 'mode' at all) is not an AX message.
    self.failUnlessRaises(ax.NotAXMessage, validate, {})
    # A mode other than this message's own is an AX error.
    self.failUnlessRaises(ax.AXError, validate, {'mode': 'fetch_request'})

    # The message's own mode must be accepted without raising.
    validate({'mode': self.bax.mode})
def test_checkMode_newArgs(self):
    """Arguments produced by _newArgs carry a mode that _checkMode accepts."""
    fresh_args = self.bax._newArgs()
    # _checkMode raises AXError if the mode _newArgs produced is wrong.
    self.bax._checkMode(fresh_args)
36 class AttrInfoTest(unittest
.TestCase
):
37 def test_construct(self
):
38 self
.failUnlessRaises(TypeError, ax
.AttrInfo
)
40 ainfo
= ax
.AttrInfo(type_uri
)
42 self
.failUnlessEqual(type_uri
, ainfo
.type_uri
)
43 self
.failUnlessEqual(1, ainfo
.count
)
44 self
.failIf(ainfo
.required
)
45 self
.failUnless(ainfo
.alias
is None)
48 class ToTypeURIsTest(unittest
.TestCase
):
50 self
.aliases
= NamespaceMap()
53 for empty
in [None, '']:
54 uris
= ax
.toTypeURIs(self
.aliases
, empty
)
55 self
.failUnlessEqual([], uris
)
57 def test_undefined(self
):
58 self
.failUnlessRaises(
60 ax
.toTypeURIs
, self
.aliases
, 'http://janrain.com/')
63 uri
= 'http://janrain.com/'
64 alias
= 'openid_hackers'
65 self
.aliases
.addAlias(uri
, alias
)
66 uris
= ax
.toTypeURIs(self
.aliases
, alias
)
67 self
.failUnlessEqual([uri
], uris
)
70 uri1
= 'http://janrain.com/'
71 alias1
= 'openid_hackers'
72 self
.aliases
.addAlias(uri1
, alias1
)
74 uri2
= 'http://jyte.com/'
75 alias2
= 'openid_hack'
76 self
.aliases
.addAlias(uri2
, alias2
)
78 uris
= ax
.toTypeURIs(self
.aliases
, ','.join([alias1
, alias2
]))
79 self
.failUnlessEqual([uri1
, uri2
], uris
)
81 class ParseAXValuesTest(unittest
.TestCase
):
82 """Testing AXKeyValueMessage.parseExtensionArgs."""
def failUnlessAXKeyError(self, ax_args):
    """Fail unless parsing ax_args into a fresh AXKeyValueMessage raises KeyError."""
    parse = ax.AXKeyValueMessage().parseExtensionArgs
    self.failUnlessRaises(KeyError, parse, ax_args)
def failUnlessAXValues(self, ax_args, expected_args):
    """Fail unless parseExtensionArgs(ax_args) == expected_args.

    ax_args is parsed into a brand-new AXKeyValueMessage, and that
    message's ``data`` mapping is compared with expected_args.
    """
    kv_message = ax.AXKeyValueMessage()
    kv_message.parseExtensionArgs(ax_args)
    parsed = kv_message.data
    self.failUnlessEqual(expected_args, parsed)
def test_emptyIsValid(self):
    """An empty argument dict parses to empty data."""
    no_args = {}
    self.failUnlessAXValues(no_args, {})
def test_missingValueForAliasExplodes(self):
    """A type declared for an alias with no value for it is a KeyError."""
    type_only = {'type.foo': 'urn:foo'}
    self.failUnlessAXKeyError(type_only)
100 def test_countPresentButNotValue(self
):
101 self
.failUnlessAXKeyError({'type.foo':'urn:foo',
def test_invalidCountValue(self):
    """A count of 'bogus' makes FetchRequest.parseExtensionArgs raise AXError."""
    request = ax.FetchRequest()
    bad_count_args = {
        'type.foo': 'urn:foo',
        'count.foo': 'bogus',
    }
    self.failUnlessRaises(ax.AXError, request.parseExtensionArgs,
                          bad_count_args)
111 def test_requestUnlimitedValues(self
):
112 msg
= ax
.FetchRequest()
114 msg
.parseExtensionArgs(
115 {'mode':'fetch_request',
117 'type.foo':'urn:foo',
118 'count.foo':ax
.UNLIMITED_VALUES
})
120 attrs
= list(msg
.iterAttrs())
123 self
.failUnless(foo
.count
== ax
.UNLIMITED_VALUES
)
124 self
.failUnless(foo
.wantsUnlimitedValues())
126 def test_longAlias(self
):
127 # Spec minimum length is 32 characters. This is a silly test
128 # for this library, but it's here for completeness.
129 alias
= 'x' * ax
.MINIMUM_SUPPORTED_ALIAS_LENGTH
131 msg
= ax
.AXKeyValueMessage()
132 msg
.parseExtensionArgs(
133 {'type.%s' % (alias
,): 'urn:foo',
134 'count.%s' % (alias
,): '1',
135 'value.%s.1' % (alias
,): 'first'}
138 def test_invalidAlias(self
):
140 ax
.AXKeyValueMessage
,
145 {'type.a.b':'urn:foo',
147 {'type.a,b':'urn:foo',
154 self
.failUnlessRaises(ax
.AXError
, msg
.parseExtensionArgs
,
157 def test_countPresentAndIsZero(self
):
158 self
.failUnlessAXValues(
159 {'type.foo':'urn:foo',
163 def test_singletonEmpty(self
):
164 self
.failUnlessAXValues(
165 {'type.foo':'urn:foo',
169 def test_doubleAlias(self
):
170 self
.failUnlessAXKeyError(
171 {'type.foo':'urn:foo',
173 'type.bar':'urn:foo',
177 def test_doubleSingleton(self
):
178 self
.failUnlessAXValues(
179 {'type.foo':'urn:foo',
181 'type.bar':'urn:bar',
183 }, {'urn:foo':[], 'urn:bar':[]})
def test_singletonValue(self):
    """A bare 'value.<alias>' with no count parses to a one-element list."""
    args = {
        'type.foo': 'urn:foo',
        'value.foo': 'Westfall',
    }
    expected = {'urn:foo': ['Westfall']}
    self.failUnlessAXValues(args, expected)
192 class FetchRequestTest(unittest
.TestCase
):
194 self
.msg
= ax
.FetchRequest()
195 self
.type_a
= 'http://janrain.example.com/a'
200 self
.failUnlessEqual(self
.msg
.mode
, 'fetch_request')
def test_construct(self):
    """A FetchRequest starts with no requested attributes; update_url is optional."""
    # self.msg is a default-constructed FetchRequest.
    default_req = self.msg
    self.failUnlessEqual({}, default_req.requested_attributes)
    self.failUnlessEqual(None, default_req.update_url)

    # A positional argument is stored as the update URL.
    with_url = ax.FetchRequest('hailstorm')
    self.failUnlessEqual({}, with_url.requested_attributes)
    self.failUnlessEqual('hailstorm', with_url.update_url)
214 self
.failIf(uri
in self
.msg
)
216 attr
= ax
.AttrInfo(uri
)
219 # Present after adding
220 self
.failUnless(uri
in self
.msg
)
222 def test_addTwice(self
):
223 uri
= 'lightning://storm'
225 attr
= ax
.AttrInfo(uri
)
227 self
.failUnlessRaises(KeyError, self
.msg
.add
, attr
)
229 def test_getExtensionArgs_empty(self
):
231 'mode':'fetch_request',
233 self
.failUnlessEqual(expected_args
, self
.msg
.getExtensionArgs())
235 def test_getExtensionArgs_noAlias(self
):
237 type_uri
= 'type://of.transportation',
240 ax_args
= self
.msg
.getExtensionArgs()
241 for k
, v
in ax_args
.iteritems():
242 if v
== attr
.type_uri
and k
.startswith('type.'):
246 self
.fail("Didn't find the type definition")
248 self
.failUnlessExtensionArgs({
249 'type.' + alias
:attr
.type_uri
,
250 'if_available':alias
,
253 def test_getExtensionArgs_alias_if_available(self
):
255 type_uri
= 'type://of.transportation',
259 self
.failUnlessExtensionArgs({
260 'type.' + attr
.alias
:attr
.type_uri
,
261 'if_available':attr
.alias
,
264 def test_getExtensionArgs_alias_req(self
):
266 type_uri
= 'type://of.transportation',
271 self
.failUnlessExtensionArgs({
272 'type.' + attr
.alias
:attr
.type_uri
,
273 'required':attr
.alias
,
276 def failUnlessExtensionArgs(self
, expected_args
):
277 """Make sure that getExtensionArgs has the expected result
279 This method will fill in the mode.
281 expected_args
= dict(expected_args
)
282 expected_args
['mode'] = self
.msg
.mode
283 self
.failUnlessEqual(expected_args
, self
.msg
.getExtensionArgs())
def test_isIterable(self):
    """A fresh FetchRequest iterates as empty, both directly and via iterAttrs."""
    nothing = []
    self.failUnlessEqual(nothing, list(self.msg))
    self.failUnlessEqual(nothing, list(self.msg.iterAttrs()))
def test_getRequiredAttrs_empty(self):
    """With nothing requested, getRequiredAttrs returns an empty list."""
    required = self.msg.getRequiredAttrs()
    self.failUnlessEqual([], required)
292 def test_parseExtensionArgs_extraType(self
):
294 'mode':'fetch_request',
295 'type.' + self
.alias_a
:self
.type_a
,
297 self
.failUnlessRaises(ValueError,
298 self
.msg
.parseExtensionArgs
, extension_args
)
300 def test_parseExtensionArgs(self
):
302 'mode':'fetch_request',
303 'type.' + self
.alias_a
:self
.type_a
,
304 'if_available':self
.alias_a
306 self
.msg
.parseExtensionArgs(extension_args
)
307 self
.failUnless(self
.type_a
in self
.msg
)
308 self
.failUnlessEqual([self
.type_a
], list(self
.msg
))
309 attr_info
= self
.msg
.requested_attributes
.get(self
.type_a
)
310 self
.failUnless(attr_info
)
311 self
.failIf(attr_info
.required
)
312 self
.failUnlessEqual(self
.type_a
, attr_info
.type_uri
)
313 self
.failUnlessEqual(self
.alias_a
, attr_info
.alias
)
314 self
.failUnlessEqual([attr_info
], list(self
.msg
.iterAttrs()))
316 def test_extensionArgs_idempotent(self
):
318 'mode':'fetch_request',
319 'type.' + self
.alias_a
:self
.type_a
,
320 'if_available':self
.alias_a
322 self
.msg
.parseExtensionArgs(extension_args
)
323 self
.failUnlessEqual(extension_args
, self
.msg
.getExtensionArgs())
324 self
.failIf(self
.msg
.requested_attributes
[self
.type_a
].required
)
326 def test_extensionArgs_idempotent_count_required(self
):
328 'mode':'fetch_request',
329 'type.' + self
.alias_a
:self
.type_a
,
330 'count.' + self
.alias_a
:'2',
331 'required':self
.alias_a
333 self
.msg
.parseExtensionArgs(extension_args
)
334 self
.failUnlessEqual(extension_args
, self
.msg
.getExtensionArgs())
335 self
.failUnless(self
.msg
.requested_attributes
[self
.type_a
].required
)
337 def test_extensionArgs_count1(self
):
339 'mode':'fetch_request',
340 'type.' + self
.alias_a
:self
.type_a
,
341 'count.' + self
.alias_a
:'1',
342 'if_available':self
.alias_a
,
344 extension_args_norm
= {
345 'mode':'fetch_request',
346 'type.' + self
.alias_a
:self
.type_a
,
347 'if_available':self
.alias_a
,
349 self
.msg
.parseExtensionArgs(extension_args
)
350 self
.failUnlessEqual(extension_args_norm
, self
.msg
.getExtensionArgs())
352 def test_openidNoRealm(self
):
353 openid_req_msg
= Message
.fromOpenIDArgs({
354 'mode': 'checkid_setup',
356 'ns.ax': ax
.AXMessage
.ns_uri
,
357 'ax.update_url': 'http://different.site/path',
358 'ax.mode': 'fetch_request',
360 self
.failUnlessRaises(ax
.AXError
,
361 ax
.FetchRequest
.fromOpenIDRequest
,
362 DummyRequest(openid_req_msg
))
364 def test_openidUpdateURLVerificationError(self
):
365 openid_req_msg
= Message
.fromOpenIDArgs({
366 'mode': 'checkid_setup',
368 'realm': 'http://example.com/realm',
369 'ns.ax': ax
.AXMessage
.ns_uri
,
370 'ax.update_url': 'http://different.site/path',
371 'ax.mode': 'fetch_request',
374 self
.failUnlessRaises(ax
.AXError
,
375 ax
.FetchRequest
.fromOpenIDRequest
,
376 DummyRequest(openid_req_msg
))
378 def test_openidUpdateURLVerificationSuccess(self
):
379 openid_req_msg
= Message
.fromOpenIDArgs({
380 'mode': 'checkid_setup',
382 'realm': 'http://example.com/realm',
383 'ns.ax': ax
.AXMessage
.ns_uri
,
384 'ax.update_url': 'http://example.com/realm/update_path',
385 'ax.mode': 'fetch_request',
388 fr
= ax
.FetchRequest
.fromOpenIDRequest(DummyRequest(openid_req_msg
))
390 def test_openidUpdateURLVerificationSuccessReturnTo(self
):
391 openid_req_msg
= Message
.fromOpenIDArgs({
392 'mode': 'checkid_setup',
394 'return_to': 'http://example.com/realm',
395 'ns.ax': ax
.AXMessage
.ns_uri
,
396 'ax.update_url': 'http://example.com/realm/update_path',
397 'ax.mode': 'fetch_request',
400 fr
= ax
.FetchRequest
.fromOpenIDRequest(DummyRequest(openid_req_msg
))
402 def test_fromOpenIDRequestWithoutExtension(self
):
403 """return None for an OpenIDRequest without AX paramaters."""
404 openid_req_msg
= Message
.fromOpenIDArgs({
405 'mode': 'checkid_setup',
408 oreq
= DummyRequest(openid_req_msg
)
409 r
= ax
.FetchRequest
.fromOpenIDRequest(oreq
)
410 self
.failUnless(r
is None, "%s is not None" % (r
,))
412 def test_fromOpenIDRequestWithoutData(self
):
413 """return something for SuccessResponse with AX paramaters,
414 even if it is the empty set."""
415 openid_req_msg
= Message
.fromOpenIDArgs({
416 'mode': 'checkid_setup',
417 'realm': 'http://example.com/realm',
419 'ns.ax': ax
.AXMessage
.ns_uri
,
420 'ax.mode': 'fetch_request',
422 oreq
= DummyRequest(openid_req_msg
)
423 r
= ax
.FetchRequest
.fromOpenIDRequest(oreq
)
424 self
.failUnless(r
is not None)
427 class FetchResponseTest(unittest
.TestCase
):
429 self
.msg
= ax
.FetchResponse()
430 self
.value_a
= 'monkeys'
431 self
.type_a
= 'http://phone.home/'
432 self
.alias_a
= 'robocop'
433 self
.request_update_url
= 'http://update.bogus/'
def test_construct(self):
    """A fresh FetchResponse has no update URL and empty data."""
    response = self.msg
    self.failUnless(response.update_url is None)
    self.failUnlessEqual({}, response.data)
439 def test_getExtensionArgs_empty(self
):
441 'mode':'fetch_response',
443 self
.failUnlessEqual(expected_args
, self
.msg
.getExtensionArgs())
445 def test_getExtensionArgs_empty_request(self
):
447 'mode':'fetch_response',
449 req
= ax
.FetchRequest()
450 msg
= ax
.FetchResponse(request
=req
)
451 self
.failUnlessEqual(expected_args
, msg
.getExtensionArgs())
453 def test_getExtensionArgs_empty_request_some(self
):
454 uri
= 'http://not.found/'
458 'mode':'fetch_response',
459 'type.%s' % (alias
,): uri
,
460 'count.%s' % (alias
,): '0'
462 req
= ax
.FetchRequest()
463 req
.add(ax
.AttrInfo(uri
))
464 msg
= ax
.FetchResponse(request
=req
)
465 self
.failUnlessEqual(expected_args
, msg
.getExtensionArgs())
467 def test_updateUrlInResponse(self
):
468 uri
= 'http://not.found/'
472 'mode':'fetch_response',
473 'update_url': self
.request_update_url
,
474 'type.%s' % (alias
,): uri
,
475 'count.%s' % (alias
,): '0'
477 req
= ax
.FetchRequest(update_url
=self
.request_update_url
)
478 req
.add(ax
.AttrInfo(uri
))
479 msg
= ax
.FetchResponse(request
=req
)
480 self
.failUnlessEqual(expected_args
, msg
.getExtensionArgs())
482 def test_getExtensionArgs_some_request(self
):
484 'mode':'fetch_response',
485 'type.' + self
.alias_a
:self
.type_a
,
486 'value.' + self
.alias_a
+ '.1':self
.value_a
,
487 'count.' + self
.alias_a
: '1'
489 req
= ax
.FetchRequest()
490 req
.add(ax
.AttrInfo(self
.type_a
, alias
=self
.alias_a
))
491 msg
= ax
.FetchResponse(request
=req
)
492 msg
.addValue(self
.type_a
, self
.value_a
)
493 self
.failUnlessEqual(expected_args
, msg
.getExtensionArgs())
def test_getExtensionArgs_some_not_request(self):
    """A value for a type the (empty) request never asked for raises KeyError."""
    response = ax.FetchResponse(request=ax.FetchRequest())
    response.addValue(self.type_a, self.value_a)
    self.failUnlessRaises(KeyError, response.getExtensionArgs)
def test_getSingle_success(self):
    """getSingle returns the sole value stored for a type URI."""
    # The original test also built an unused FetchRequest; getSingle needs
    # no request, so that dead local has been dropped.
    self.msg.addValue(self.type_a, self.value_a)
    self.failUnlessEqual(self.value_a, self.msg.getSingle(self.type_a))
def test_getSingle_none(self):
    """getSingle returns None for a type URI with no stored values."""
    single = self.msg.getSingle(self.type_a)
    self.failUnlessEqual(None, single)
def test_getSingle_extra(self):
    """getSingle raises AXError when more than one value is stored."""
    self.msg.setValues(self.type_a, ['x', 'y'])
    get_single = self.msg.getSingle
    self.failUnlessRaises(ax.AXError, get_single, self.type_a)
514 self
.failUnlessRaises(KeyError, self
.msg
.get
, self
.type_a
)
516 def test_fromSuccessResponseWithoutExtension(self
):
517 """return None for SuccessResponse with no AX paramaters."""
522 sf
= ['openid.' + i
for i
in args
.keys()]
523 msg
= Message
.fromOpenIDArgs(args
)
525 claimed_id
= 'http://invalid.'
527 oreq
= SuccessResponse(Endpoint(), msg
, signed_fields
=sf
)
528 r
= ax
.FetchResponse
.fromSuccessResponse(oreq
)
529 self
.failUnless(r
is None, "%s is not None" % (r
,))
531 def test_fromSuccessResponseWithoutData(self
):
532 """return something for SuccessResponse with AX paramaters,
533 even if it is the empty set."""
537 'ns.ax': ax
.AXMessage
.ns_uri
,
538 'ax.mode': 'fetch_response',
540 sf
= ['openid.' + i
for i
in args
.keys()]
541 msg
= Message
.fromOpenIDArgs(args
)
543 claimed_id
= 'http://invalid.'
545 oreq
= SuccessResponse(Endpoint(), msg
, signed_fields
=sf
)
546 r
= ax
.FetchResponse
.fromSuccessResponse(oreq
)
547 self
.failUnless(r
is not None)
549 def test_fromSuccessResponseWithData(self
):
552 uri
= "http://willy.wonka.name/"
556 'ns.ax': ax
.AXMessage
.ns_uri
,
557 'ax.update_url': 'http://example.com/realm/update_path',
558 'ax.mode': 'fetch_response',
559 'ax.type.'+name
: uri
,
560 'ax.count.'+name
: '1',
561 'ax.value.%s.1'%name
: value
,
563 sf
= ['openid.' + i
for i
in args
.keys()]
564 msg
= Message
.fromOpenIDArgs(args
)
566 claimed_id
= 'http://invalid.'
568 resp
= SuccessResponse(Endpoint(), msg
, signed_fields
=sf
)
569 ax_resp
= ax
.FetchResponse
.fromSuccessResponse(resp
)
570 values
= ax_resp
.get(uri
)
571 self
.failUnlessEqual([value
], values
)
574 class StoreRequestTest(unittest
.TestCase
):
576 self
.msg
= ax
.StoreRequest()
577 self
.type_a
= 'http://three.count/'
578 self
.alias_a
= 'juggling'
def test_construct(self):
    """A fresh StoreRequest holds no data."""
    stored = self.msg.data
    self.failUnlessEqual({}, stored)
583 def test_getExtensionArgs_empty(self
):
584 args
= self
.msg
.getExtensionArgs()
586 'mode':'store_request',
588 self
.failUnlessEqual(expected_args
, args
)
590 def test_getExtensionArgs_nonempty(self
):
591 aliases
= NamespaceMap()
592 aliases
.addAlias(self
.type_a
, self
.alias_a
)
593 msg
= ax
.StoreRequest(aliases
=aliases
)
594 msg
.setValues(self
.type_a
, ['foo', 'bar'])
595 args
= msg
.getExtensionArgs()
597 'mode':'store_request',
598 'type.' + self
.alias_a
: self
.type_a
,
599 'count.' + self
.alias_a
: '2',
600 'value.%s.1' % (self
.alias_a
,):'foo',
601 'value.%s.2' % (self
.alias_a
,):'bar',
603 self
.failUnlessEqual(expected_args
, args
)
class StoreResponseTest(unittest.TestCase):
    """StoreResponse success/failure state and its serialized extension args."""

    def test_success(self):
        # Default construction: success, no error message.
        response = ax.StoreResponse()
        self.failUnless(response.succeeded())
        self.failIf(response.error_message)
        expected = {'mode': 'store_response_success'}
        self.failUnlessEqual(expected, response.getExtensionArgs())

    def test_fail_nomsg(self):
        # Failure without a reason: the mode changes and no 'error' key appears.
        response = ax.StoreResponse(False)
        self.failIf(response.succeeded())
        self.failIf(response.error_message)
        expected = {'mode': 'store_response_failure'}
        self.failUnlessEqual(expected, response.getExtensionArgs())

    def test_fail_msg(self):
        # Failure with a reason: it is kept and serialized under 'error'.
        reason = 'no reason, really'
        response = ax.StoreResponse(False, reason)
        self.failIf(response.succeeded())
        self.failUnlessEqual(reason, response.error_message)
        expected = {'mode': 'store_response_failure', 'error': reason}
        self.failUnlessEqual(expected, response.getExtensionArgs())