# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Tests exercising the various classes in xmppserver.py."""
class XmlUtilsTest(unittest.TestCase):
  """Tests for the ParseXml() and CloneXml() helpers in xmppserver."""

  def testParseXml(self):
    # Empty xmlns attributes must survive a parse/serialize round trip.
    xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
    xml = xmppserver.ParseXml(xml_text)
    self.assertEqual(xml.toxml(), xml_text)

  def testCloneXml(self):
    xml = xmppserver.ParseXml('<foo/>')
    xml_clone = xmppserver.CloneXml(xml)
    xml_clone.setAttribute('bar', 'baz')
    self.assertEqual(xml, xml)
    self.assertEqual(xml_clone, xml_clone)
    self.assertNotEqual(xml, xml_clone)
    # The equality checks above only compare object identity; also verify
    # that mutating the clone left the original untouched.
    self.assertEqual(xml.toxml(), '<foo/>')
    self.assertEqual(xml_clone.toxml(), '<foo bar="baz"/>')

  def testCloneXmlUnlink(self):
    # A clone must stay attached and serializable after the element it was
    # cloned from is unlinked.
    xml_text = '<foo/>'
    xml = xmppserver.ParseXml(xml_text)
    xml_clone = xmppserver.CloneXml(xml)
    xml.unlink()
    self.assertIsNone(xml.parentNode)
    self.assertIsNotNone(xml_clone.parentNode)
    self.assertEqual(xml_clone.toxml(), xml_text)
class StanzaParserTest(unittest.TestCase):
  """Tests for xmppserver.StanzaParser.

  The test case itself is handed to the parser as its delegate, so every
  complete stanza the parser produces arrives through FeedStanza().
  """

  def setUp(self):
    self.stanzas = []

  def FeedStanza(self, stanza):
    # We can't append stanza directly because it is unlinked after
    # this callback, so record its serialized form instead.
    self.stanzas.append(stanza.toxml())

  def testBasic(self):
    parser = xmppserver.StanzaParser(self)
    # Nothing is emitted until a stanza is complete.
    parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    # One chunk may complete one stanza and contain a whole second one.
    parser.FeedString('/><bar></bar>')
    self.assertEqual(self.stanzas, ['<foo/>', '<bar/>'])

  def testStream(self):
    parser = xmppserver.StanzaParser(self)
    parser.FeedString('<stream')
    self.assertEqual(len(self.stanzas), 0)
    # The stream header is emitted as soon as its start tag is complete.
    parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
    self.assertEqual(self.stanzas,
                     ['<stream:stream foo="bar" xmlns:stream="baz"/>'])

  def testNested(self):
    parser = xmppserver.StanzaParser(self)
    parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    parser.FeedString(' bar="baz"')
    parser.FeedString('><baz/><blah>meh</blah></foo>')
    # Child elements belong to their parent stanza, not separate stanzas.
    self.assertEqual(self.stanzas,
                     ['<foo bar="baz"><baz/><blah>meh</blah></foo>'])
class JidTest(unittest.TestCase):
  """Tests for the string forms of xmppserver.Jid."""

  def testBasic(self):
    jid = xmppserver.Jid('foo', 'bar.com')
    self.assertEqual(str(jid), 'foo@bar.com')

  def testResource(self):
    # A resource is appended after a slash.
    jid = xmppserver.Jid('foo', 'bar.com', 'resource')
    self.assertEqual(str(jid), 'foo@bar.com/resource')

  def testGetBareJid(self):
    # The bare JID drops the resource.
    jid = xmppserver.Jid('foo', 'bar.com', 'resource')
    self.assertEqual(str(jid.GetBareJid()), 'foo@bar.com')
class IdGeneratorTest(unittest.TestCase):
  """Tests for xmppserver.IdGenerator."""

  def testBasic(self):
    # IDs have the form '<prefix>.<n>', with n counting up from 0.
    id_generator = xmppserver.IdGenerator('foo')
    for i in xrange(100):
      self.assertEqual('foo.%d' % i, id_generator.GetNextId())
class HandshakeTaskTest(unittest.TestCase):
  """Tests for xmppserver.HandshakeTask.

  The test case is passed to the task as its connection. SendData() and
  SendStanza() count the replies the task sends. HandshakeDone() records
  that the handshake finished and with which JID.
  """

  def setUp(self):
    self.Reset()

  def Reset(self):
    """Clears everything recorded by the delegate methods below."""
    self.data_received = 0
    self.handshake_done = False
    self.jid = None

  # HandshakeTask delegate methods.
  def SendData(self, _):
    self.data_received += 1

  def SendStanza(self, _, unused=True):
    self.data_received += 1

  def HandshakeDone(self, jid):
    self.handshake_done = True
    self.jid = jid

  # Stanza builders shared by the handshake drivers.
  def _MakeStreamXml(self, to_domain):
    """Returns a <stream:stream> element whose 'to' is |to_domain|."""
    stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    stream_xml.setAttribute('to', to_domain)
    return stream_xml

  def _MakeAuthXml(self, username_domain):
    """Returns an <auth> element for |username_domain|.

    The element text is the base64 encoding of a NUL byte,
    |username_domain|, another NUL byte and the fixed string 'bar'.
    """
    auth_string = base64.b64encode('\0%s\0bar' % username_domain)
    return xmppserver.ParseXml('<auth>%s</auth>' % auth_string)

  def DoHandshake(self, resource_prefix, resource, username,
                  initial_stream_domain, auth_domain, auth_stream_domain):
    """Drives an authenticated handshake step by step, checking each reply.

    Args:
      resource_prefix: Prefix the task puts before the requested resource
          (the bound resource is '<resource_prefix>.<resource>').
      resource: Resource requested in the bind stanza.
      username: User name sent in the auth stanza.
      initial_stream_domain: 'to' domain of the first stream header.
      auth_domain: Domain sent with |username| as 'username@auth_domain';
          if empty, the bare user name is sent.
      auth_stream_domain: 'to' domain of the stream header sent after auth.
    """
    self.Reset()
    handshake_task = xmppserver.HandshakeTask(self, resource_prefix, True)
    self.assertEqual(self.data_received, 0)

    # Initial stream header; the task replies with two sends.
    handshake_task.FeedStanza(self._MakeStreamXml(initial_stream_domain))
    self.assertEqual(self.data_received, 2)

    # Authentication.
    if auth_domain:
      username_domain = '%s@%s' % (username, auth_domain)
    else:
      username_domain = username
    handshake_task.FeedStanza(self._MakeAuthXml(username_domain))
    self.assertEqual(self.data_received, 3)

    # Second stream header, sent after authentication.
    handshake_task.FeedStanza(self._MakeStreamXml(auth_stream_domain))
    self.assertEqual(self.data_received, 5)

    # Resource binding.
    bind_xml = xmppserver.ParseXml(
        '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    handshake_task.FeedStanza(bind_xml)
    self.assertEqual(self.data_received, 6)
    self.assertFalse(self.handshake_done)

    # The session request completes the handshake.
    session_xml = xmppserver.ParseXml(
        '<iq type="set"><session></session></iq>')
    handshake_task.FeedStanza(session_xml)
    self.assertEqual(self.data_received, 7)
    self.assertTrue(self.handshake_done)

    self.assertEqual(self.jid.username, username)
    # The domain is the first non-empty of: the post-auth stream domain,
    # the auth domain, the initial stream domain.
    self.assertEqual(self.jid.domain,
                     auth_stream_domain or auth_domain or
                     initial_stream_domain)
    self.assertEqual(self.jid.resource,
                     '%s.%s' % (resource_prefix, resource))

    # Once the handshake is done, further input produces no replies.
    handshake_task.FeedStanza('<ignored/>')
    self.assertEqual(self.data_received, 7)

  def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
                                 initial_stream_domain):
    """Drives a handshake with authentication disabled.

    The handshake finishes right after the auth stanza and reports no JID.
    |resource| is accepted for symmetry with DoHandshake() but is not used.
    """
    self.Reset()
    handshake_task = xmppserver.HandshakeTask(self, resource_prefix, False)
    self.assertEqual(self.data_received, 0)

    handshake_task.FeedStanza(self._MakeStreamXml(initial_stream_domain))
    self.assertEqual(self.data_received, 2)
    self.assertFalse(self.handshake_done)

    handshake_task.FeedStanza(self._MakeAuthXml(username))
    self.assertEqual(self.data_received, 3)
    self.assertTrue(self.handshake_done)
    self.assertIsNone(self.jid)

    handshake_task.FeedStanza('<ignored/>')
    self.assertEqual(self.data_received, 3)

  def testBasic(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')

  def testDomainBehavior(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', '')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', '', '')
    # NOTE(review): confirm the arguments of this all-empty case upstream.
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', '', '', '')

  def testBasicUnauthenticated(self):
    self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
                                    'foo', 'bar.com')
class FakeSocket(object):
  """A fake socket object used for testing.

  Everything passed to send() is recorded and can be inspected through
  GetSentData(). The remaining socket-like methods are inert stubs.
  """

  def __init__(self):
    self._sent_data = []

  def GetSentData(self):
    """Returns the list of chunks passed to send(), oldest first.

    The list is returned by reference, so later send() calls show up in it.
    """
    return self._sent_data

  # socket-like methods.
  def fileno(self):
    return 0

  def setblocking(self, int):
    # The flag is ignored. The parameter name shadows the builtin but is
    # kept so the signature does not change for existing callers.
    pass

  def getpeername(self):
    return ('', 0)

  def send(self, data):
    """Records |data| and reports it as fully sent.

    Returns:
      len(data), like socket.send() returning the number of bytes sent.
      Returning None would leave a caller unable to tell how much was sent.
    """
    self._sent_data.append(data)
    return len(data)

  def close(self):
    pass
class XmppConnectionTest(unittest.TestCase):
  """Tests for xmppserver.XmppConnection.

  The test case doubles as the connection's delegate. |self.connections|
  holds the connections whose handshake completed and that have not yet
  been reported closed.
  """

  def setUp(self):
    self.connections = set()
    self.fake_socket = FakeSocket()

  # XmppConnection delegate methods.
  def OnXmppHandshakeDone(self, xmpp_connection):
    self.connections.add(xmpp_connection)

  def OnXmppConnectionClosed(self, xmpp_connection):
    # discard() so that closing a connection that was never registered
    # here is harmless.
    self.connections.discard(xmpp_connection)

  def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    # Fan the notification out to every registered connection.
    for connection in self.connections:
      connection.ForwardNotification(notification_stanza)

  def testBasic(self):
    socket_map = {}
    xmpp_connection = xmppserver.XmppConnection(
        self.fake_socket, socket_map, self, ('', 0), True)
    self.assertEqual(len(socket_map), 1)
    self.assertEqual(len(self.connections), 0)
    xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    self.assertEqual(len(socket_map), 1)
    self.assertEqual(len(self.connections), 1)

    # FakeSocket returns its list by reference, so this reflects later sends.
    sent_data = self.fake_socket.GetSentData()

    # Test subscription request.
    self.assertEqual(len(sent_data), 0)
    xmpp_connection.collect_incoming_data(
        '<iq><subscribe xmlns="google:push"></subscribe></iq>')
    self.assertEqual(len(sent_data), 1)

    # Test acks: a result iq produces no reply.
    xmpp_connection.collect_incoming_data('<iq type="result"/>')
    self.assertEqual(len(sent_data), 1)

    # Test notification.
    xmpp_connection.collect_incoming_data(
        '<message><push xmlns="google:push"/></message>')
    self.assertEqual(len(sent_data), 2)

    # Test unexpected stanza.
    with self.assertRaises(xmppserver.UnexpectedXml):
      xmpp_connection.collect_incoming_data('<foo/>')

    # Test unexpected notifier command.
    with self.assertRaises(xmppserver.UnexpectedXml):
      xmpp_connection.collect_incoming_data(
          '<iq><foo xmlns="google:notifier"/></iq>')

    # Test close.
    xmpp_connection.close()
    self.assertEqual(len(socket_map), 0)
    self.assertEqual(len(self.connections), 0)

  def testBasicUnauthenticated(self):
    socket_map = {}
    xmpp_connection = xmppserver.XmppConnection(
        self.fake_socket, socket_map, self, ('', 0), False)
    self.assertEqual(len(socket_map), 1)
    self.assertEqual(len(self.connections), 0)
    # An unauthenticated handshake ends without a JID and the connection is
    # dropped from the socket map.
    xmpp_connection.HandshakeDone(None)
    self.assertEqual(len(socket_map), 0)
    self.assertEqual(len(self.connections), 0)

    # Test unexpected stanza.
    with self.assertRaises(xmppserver.UnexpectedXml):
      xmpp_connection.collect_incoming_data('<foo/>')

    # Test redundant close.
    xmpp_connection.close()
    self.assertEqual(len(socket_map), 0)
    self.assertEqual(len(self.connections), 0)
class FakeXmppServer(xmppserver.XmppServer):
  """A fake XMPP server object used for testing.

  Accepted connections are backed by FakeSocket objects. These are tracked
  so tests can inspect what was sent on each connection.
  """

  def __init__(self):
    # |_socket_map| is created before the base constructor because it is
    # passed to it.
    self._socket_map = {}
    self._fake_sockets = set()
    self._next_jid_suffix = 1
    xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0))

  def GetSocketMap(self):
    return self._socket_map

  def GetFakeSockets(self):
    return self._fake_sockets

  def AddHandshakeCompletedConnection(self):
    """Creates a new XMPP connection and completes its handshake.

    Each connection gets a distinct JID: user1@domain.com, user2@domain.com,
    and so on.

    Returns:
      The new connection.
    """
    xmpp_connection = self.handle_accept()
    jid = xmppserver.Jid('user%s' % self._next_jid_suffix, 'domain.com')
    self._next_jid_suffix += 1
    xmpp_connection.HandshakeDone(jid)
    return xmpp_connection

  # XmppServer overrides.
  def accept(self):
    """Returns a new tracked FakeSocket and a dummy address."""
    fake_socket = FakeSocket()
    self._fake_sockets.add(fake_socket)
    return (fake_socket, ('', 0))

  def close(self):
    self._fake_sockets.clear()
    xmppserver.XmppServer.close(self)
class XmppServerTest(unittest.TestCase):
  """Tests for xmppserver.XmppServer, driven through FakeXmppServer."""

  def setUp(self):
    self.xmpp_server = FakeXmppServer()

  def tearDown(self):
    # close() empties the socket map (see testBasic). A non-empty map
    # therefore means the test left the server open, so close it here.
    if self.xmpp_server.GetSocketMap():
      self.xmpp_server.close()

  def AssertSentDataLength(self, expected_length):
    """Asserts every fake socket has had exactly |expected_length| sends."""
    for fake_socket in self.xmpp_server.GetFakeSockets():
      self.assertEqual(len(fake_socket.GetSentData()), expected_length)

  def testBasic(self):
    socket_map = self.xmpp_server.GetSocketMap()
    self.assertEqual(len(socket_map), 1)
    self.xmpp_server.AddHandshakeCompletedConnection()
    self.assertEqual(len(socket_map), 2)
    self.xmpp_server.close()
    self.assertEqual(len(socket_map), 0)

  def testMakeNotification(self):
    notification = self.xmpp_server.MakeNotification('channel', 'data')
    # NOTE(review): the whitespace inside these literals must match the
    # template used by xmppserver's MakeNotification exactly; confirm it.
    expected_xml = (
        '<message>'
        '  <push channel="channel" xmlns="google:push">'
        '    <data>%s</data>'
        '  </push>'
        '</message>' % base64.b64encode('data'))
    self.assertEqual(notification.toxml(), expected_xml)

  def testSendNotification(self):
    # Add a few connections.
    for _ in xrange(7):
      self.xmpp_server.AddHandshakeCompletedConnection()
    self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7)

    self.AssertSentDataLength(0)
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

  def testEnableDisableNotifications(self):
    # Add a few connections.
    for _ in xrange(5):
      self.xmpp_server.AddHandshakeCompletedConnection()
    self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5)

    # Notifications are delivered by default.
    self.AssertSentDataLength(0)
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

    # Enabling while already enabled keeps delivering.
    self.xmpp_server.EnableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    # Disabling suppresses delivery, and disabling twice is harmless.
    self.xmpp_server.DisableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    self.xmpp_server.DisableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    # Re-enabling resumes delivery.
    self.xmpp_server.EnableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(3)
if __name__ == '__main__':
  unittest.main()