1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "content/browser/renderer_host/websocket_dispatcher_host.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/weak_ptr.h"
14 #include "base/message_loop/message_loop.h"
15 #include "content/browser/renderer_host/websocket_host.h"
16 #include "content/common/websocket.h"
17 #include "content/common/websocket_messages.h"
18 #include "ipc/ipc_message.h"
19 #include "net/websockets/websocket_errors.h"
20 #include "testing/gtest/include/gtest/gtest.h"
22 #include "url/origin.h"
// Render process id passed to the dispatcher host constructed by the test
// fixture; the RenderProcessIdGetter test expects render_process_id() to
// return this same value.
27 // This number is unlikely to occur by chance.
28 static const int kMagicRenderProcessId
= 506116062;
// Forward declaration: MockWebSocketHost (below) refers to the test fixture
// by pointer and by base::WeakPtr before the fixture class itself is defined.
30 class WebSocketDispatcherHostTest
;
32 // A mock of WebSocketHost which records received messages.
//
// Each instance keeps a copy of every IPC message passed to
// OnMessageReceived(), the delay it was constructed with, and a weak pointer
// to the owning test fixture so that GoAway() can report back to it.
// NOTE(review): parts of this class (access specifiers, the rest of the
// switch statement, closing braces) are not visible in this excerpt.
33 class MockWebSocketHost
: public WebSocketHost
{
// Defined out of line after the fixture class. |owner| is the test fixture
// that GoAway() reports to.
35 MockWebSocketHost(int routing_id
,
36 WebSocketDispatcherHost
* dispatcher
,
37 net::URLRequestContext
* url_request_context
,
38 base::TimeDelta delay
,
39 WebSocketDispatcherHostTest
* owner
);
41 ~MockWebSocketHost() override
{}
// Appends |message| to received_messages_ before dispatching on its type.
// WebSocketMsg_DropChannel is forwarded to the real WebSocketHost
// implementation; the handling of other message types is not visible here.
43 bool OnMessageReceived(const IPC::Message
& message
) override
{
44 received_messages_
.push_back(message
);
45 switch (message
.type()) {
46 case WebSocketMsg_DropChannel::ID
:
47 // Needed for PerRendererThrottlingFailedHandshakes, because without
48 // calling WebSocketHost::OnMessageReceived() (and thus
49 // WebSocketHost::OnDropChannel()), the connection stays pending and
50 // we cannot test per-renderer throttling with failed connections.
51 return WebSocketHost::OnMessageReceived(message
);
// Defined out of line; forwards this host's routing id to the owner's
// GoAway().
58 void GoAway() override
;
// Every message handed to OnMessageReceived(), in arrival order.
60 std::vector
<IPC::Message
> received_messages_
;
// Weak reference to the fixture, obtained from owner->GetWeakPtr().
61 base::WeakPtr
<WebSocketDispatcherHostTest
> owner_
;
// Tests compare this against expected throttling delays.
// NOTE(review): its initialization is not visible in this excerpt;
// presumably the |delay| constructor argument - confirm.
62 base::TimeDelta delay_
;
// WebSocketDispatcherHost subclass used by the tests. It overrides Send() and
// re-exports the connection counters of the base class with using-declarations
// so that tests can call them on this type.
// NOTE(review): part of the constructor's parameter list / base initializer
// and the body of Send() are not visible in this excerpt.
65 class TestingWebSocketDispatcherHost
: public WebSocketDispatcherHost
{
// Forwards its arguments to the WebSocketDispatcherHost constructor.
67 TestingWebSocketDispatcherHost(
69 const GetRequestContextCallback
& get_context_callback
,
70 const WebSocketHostFactory
& websocket_host_factory
)
71 : WebSocketDispatcherHost(process_id
,
73 websocket_host_factory
) {}
75 // This is needed because BrowserMessageFilter::Send() tries to post the task to
76 // the IO thread, which doesn't exist in the context of these tests.
77 bool Send(IPC::Message
* message
) override
{
// Make the base class's connection counters callable through this type.
82 using WebSocketDispatcherHost::num_pending_connections
;
83 using WebSocketDispatcherHost::num_failed_connections
;
84 using WebSocketDispatcherHost::num_succeeded_connections
;
87 ~TestingWebSocketDispatcherHost() override
{}
// Test fixture. It owns a TestingWebSocketDispatcherHost whose host factory
// callback is CreateWebSocketHost(), so every host the dispatcher asks for is
// a MockWebSocketHost recorded in mock_hosts_. Routing ids reported through
// GoAway() are collected in gone_hosts_.
// NOTE(review): access specifiers, closing braces, return statements and some
// argument lists are not visible in this excerpt.
90 class WebSocketDispatcherHostTest
: public ::testing::Test
{
// Starts routing ids at 123 and builds the dispatcher host with callbacks
// bound to this fixture via base::Unretained(this).
92 WebSocketDispatcherHostTest()
93 : next_routing_id_(123),
94 weak_ptr_factory_(this) {
95 dispatcher_host_
= new TestingWebSocketDispatcherHost(
96 kMagicRenderProcessId
,
97 base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext
,
98 base::Unretained(this)),
99 base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost
,
100 base::Unretained(this)));
103 ~WebSocketDispatcherHostTest() override
{
104 // We need to invalidate the issued WeakPtrs at the beginning of the
105 // destructor in order not to access destructed member variables.
106 weak_ptr_factory_
.InvalidateWeakPtrs();
// Called by MockWebSocketHost::GoAway(); records which host went away.
109 void GoAway(int routing_id
) {
110 gone_hosts_
.push_back(routing_id
);
// Hands out a weak pointer to this fixture (stored by each mock host).
113 base::WeakPtr
<WebSocketDispatcherHostTest
> GetWeakPtr() {
114 return weak_ptr_factory_
.GetWeakPtr();
118 // Adds |number_of_channels| connections. Returns true if succeeded.
// Each iteration takes a fresh routing id from next_routing_id_ and sends an
// AddChannelRequest message to the dispatcher host.
// NOTE(review): the message arguments and the return statements are not
// visible in this excerpt.
119 bool AddMultipleChannels(int number_of_channels
) {
120 GURL
socket_url("ws://example.com/test");
121 std::vector
<std::string
> requested_protocols
;
122 url::Origin
origin(GURL("http://example.com"));
123 int render_frame_id
= -3;
125 for (int i
= 0; i
< number_of_channels
; ++i
) {
126 int routing_id
= next_routing_id_
++;
127 WebSocketHostMsg_AddChannelRequest
message(
133 if (!dispatcher_host_
->OnMessageReceived(message
))
140 // Adds and cancels |number_of_channels| connections. Returns true if succeeded.
// Each iteration sends an AddChannelRequest followed by a DropChannel
// (was_clean = false, abnormal-closure code, empty reason) for the same
// routing id.
// NOTE(review): the AddChannelRequest arguments and the return statements are
// not visible in this excerpt.
141 bool AddAndCancelMultipleChannels(int number_of_channels
) {
142 GURL
socket_url("ws://example.com/test");
143 std::vector
<std::string
> requested_protocols
;
144 url::Origin
origin(GURL("http://example.com"));
145 int render_frame_id
= -3;
147 for (int i
= 0; i
< number_of_channels
; ++i
) {
148 int routing_id
= next_routing_id_
++;
149 WebSocketHostMsg_AddChannelRequest
messageAddChannelRequest(
155 if (!dispatcher_host_
->OnMessageReceived(messageAddChannelRequest
))
158 WebSocketMsg_DropChannel
messageDropChannel(
159 routing_id
, false, net::kWebSocketErrorAbnormalClosure
, "");
160 if (!dispatcher_host_
->OnMessageReceived(messageDropChannel
))
// The object under test, held by reference count. The Destruct test resets
// this to NULL.
167 scoped_refptr
<TestingWebSocketDispatcherHost
> dispatcher_host_
;
169 // Stores allocated MockWebSocketHost instances. Doesn't take ownership of them.
171 std::vector
<MockWebSocketHost
*> mock_hosts_
;
// Routing ids passed to GoAway(), in call order.
172 std::vector
<int> gone_hosts_
;
// Request-context callback given to the dispatcher host.
// NOTE(review): the body is not visible in this excerpt.
175 net::URLRequestContext
* OnGetRequestContext() {
// Host factory callback given to the dispatcher host: allocates a
// MockWebSocketHost with a NULL URLRequestContext and remembers the raw
// pointer in mock_hosts_.
// NOTE(review): the return statement is not visible in this excerpt.
179 WebSocketHost
* CreateWebSocketHost(int routing_id
, base::TimeDelta delay
) {
180 MockWebSocketHost
* host
= new MockWebSocketHost(
181 routing_id
, dispatcher_host_
.get(), NULL
, delay
, this);
182 mock_hosts_
.push_back(host
);
186 base::MessageLoop message_loop_
;
// Next routing id used by the Add*MultipleChannels helpers; starts at 123.
188 int next_routing_id_
;
// Source of the weak pointers returned by GetWeakPtr(); invalidated first
// thing in the destructor.
190 base::WeakPtrFactory
<WebSocketDispatcherHostTest
> weak_ptr_factory_
;
// Out-of-line constructor, placed after the fixture definition because it
// calls owner->GetWeakPtr(). Forwards the routing id, dispatcher, request
// context and delay to the WebSocketHost base class and stores a weak pointer
// to |owner|.
// NOTE(review): the first parameter line and the remaining member
// initializers / body are not visible in this excerpt.
193 MockWebSocketHost::MockWebSocketHost(
195 WebSocketDispatcherHost
* dispatcher
,
196 net::URLRequestContext
* url_request_context
,
197 base::TimeDelta delay
,
198 WebSocketDispatcherHostTest
* owner
)
199 : WebSocketHost(routing_id
, dispatcher
, url_request_context
, delay
),
200 owner_(owner
->GetWeakPtr()),
// Reports this host's routing id to the owning fixture's GoAway().
// NOTE(review): a line between the signature and this call is not visible in
// this excerpt; it may guard against an invalidated owner_ - confirm.
203 void MockWebSocketHost::GoAway() {
205 owner_
->GoAway(routing_id());
// No test statements are visible for this case; running it still constructs
// and destroys the fixture.
208 TEST_F(WebSocketDispatcherHostTest
, Construct
) {
// A default-constructed IPC message is expected not to be handled:
// OnMessageReceived() must return false.
212 TEST_F(WebSocketDispatcherHostTest
, UnrelatedMessage
) {
213 IPC::Message message
;
214 EXPECT_FALSE(dispatcher_host_
->OnMessageReceived(message
));
// render_process_id() must return the id the fixture passed to the dispatcher
// host's constructor.
217 TEST_F(WebSocketDispatcherHostTest
, RenderProcessIdGetter
) {
218 EXPECT_EQ(kMagicRenderProcessId
, dispatcher_host_
->render_process_id());
// Sends one AddChannelRequest and expects the dispatcher to create exactly
// one host and forward that message to it with type and routing id intact.
221 TEST_F(WebSocketDispatcherHostTest
, AddChannelRequest
) {
222 int routing_id
= 123;
223 GURL
socket_url("ws://example.com/test");
224 std::vector
<std::string
> requested_protocols
;
225 requested_protocols
.push_back("hello");
226 url::Origin
origin(GURL("http://example.com"));
227 int render_frame_id
= -2;
228 WebSocketHostMsg_AddChannelRequest
message(
229 routing_id
, socket_url
, requested_protocols
, origin
, render_frame_id
);
231 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message
));
// Exactly one mock host must have been created by the factory callback.
233 ASSERT_EQ(1U, mock_hosts_
.size());
234 MockWebSocketHost
* host
= mock_hosts_
[0];
// The host must have received exactly the message sent above.
236 ASSERT_EQ(1U, host
->received_messages_
.size());
237 const IPC::Message
& forwarded_message
= host
->received_messages_
[0];
238 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID
, forwarded_message
.type());
239 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
// Sends a SendFrame message for a routing id that has no host. The dispatcher
// is expected to report the message as handled without creating a host.
242 TEST_F(WebSocketDispatcherHostTest
, SendFrameButNoHostYet
) {
243 int routing_id
= 123;
244 std::vector
<char> data
;
245 WebSocketMsg_SendFrame
message(
246 routing_id
, true, WEB_SOCKET_MESSAGE_TYPE_TEXT
, data
);
248 // Expected to be ignored.
249 EXPECT_TRUE(dispatcher_host_
->OnMessageReceived(message
));
251 EXPECT_EQ(0U, mock_hosts_
.size());
// Opens a channel and then sends a frame on the same routing id. The single
// host created must have received both messages, in that order.
254 TEST_F(WebSocketDispatcherHostTest
, SendFrame
) {
255 int routing_id
= 123;
257 GURL
socket_url("ws://example.com/test");
258 std::vector
<std::string
> requested_protocols
;
259 requested_protocols
.push_back("hello");
260 url::Origin
origin(GURL("http://example.com"));
261 int render_frame_id
= -2;
262 WebSocketHostMsg_AddChannelRequest
add_channel_message(
263 routing_id
, socket_url
, requested_protocols
, origin
, render_frame_id
);
265 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(add_channel_message
));
// Empty payload, final text frame.
267 std::vector
<char> data
;
268 WebSocketMsg_SendFrame
send_frame_message(
269 routing_id
, true, WEB_SOCKET_MESSAGE_TYPE_TEXT
, data
);
271 EXPECT_TRUE(dispatcher_host_
->OnMessageReceived(send_frame_message
));
273 ASSERT_EQ(1U, mock_hosts_
.size());
274 MockWebSocketHost
* host
= mock_hosts_
[0];
276 ASSERT_EQ(2U, host
->received_messages_
.size());
// First forwarded message: the AddChannelRequest.
// NOTE(review): |forwarded_message| is declared twice below; presumably each
// declaration sits in its own brace scope whose braces are not visible in
// this excerpt.
278 const IPC::Message
& forwarded_message
= host
->received_messages_
[0];
279 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID
, forwarded_message
.type());
280 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
// Second forwarded message: the SendFrame.
283 const IPC::Message
& forwarded_message
= host
->received_messages_
[1];
284 EXPECT_EQ(WebSocketMsg_SendFrame::ID
, forwarded_message
.type());
285 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
// Opens two channels (routing ids 123 and 456), then drops the fixture's
// reference to the dispatcher host and expects GoAway() to have been reported
// for both routing ids.
289 TEST_F(WebSocketDispatcherHostTest
, Destruct
) {
290 WebSocketHostMsg_AddChannelRequest
message1(
291 123, GURL("ws://example.com/test"), std::vector
<std::string
>(),
292 url::Origin(GURL("http://example.com")), -1);
293 WebSocketHostMsg_AddChannelRequest
message2(
294 456, GURL("ws://example.com/test2"), std::vector
<std::string
>(),
295 url::Origin(GURL("http://example.com")), -1);
297 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message1
));
298 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message2
));
300 ASSERT_EQ(2u, mock_hosts_
.size());
// Release the fixture's reference to the dispatcher host.
// NOTE(review): a line just before this one is not visible in this excerpt.
303 dispatcher_host_
= NULL
;
305 ASSERT_EQ(2u, gone_hosts_
.size());
306 // The gone_hosts_ ordering is not predictable because it depends on the
307 // hash_map ordering.
308 std::sort(gone_hosts_
.begin(), gone_hosts_
.end());
309 EXPECT_EQ(123, gone_hosts_
[0]);
310 EXPECT_EQ(456, gone_hosts_
[1]);
// With four pending connections and no failures, the fourth host must have
// been created with a zero delay.
313 TEST_F(WebSocketDispatcherHostTest
, DelayFor4thPendingConnectionIsZero
) {
314 ASSERT_TRUE(AddMultipleChannels(4));
316 EXPECT_EQ(4, dispatcher_host_
->num_pending_connections());
317 EXPECT_EQ(0, dispatcher_host_
->num_failed_connections());
318 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
320 ASSERT_EQ(4U, mock_hosts_
.size());
321 EXPECT_EQ(base::TimeDelta(), mock_hosts_
[3]->delay_
);
// With eight pending connections and no failures, the eighth host must have
// been created with a delay strictly greater than zero.
324 TEST_F(WebSocketDispatcherHostTest
, DelayFor8thPendingConnectionIsNonZero
) {
325 ASSERT_TRUE(AddMultipleChannels(8));
327 EXPECT_EQ(8, dispatcher_host_
->num_pending_connections());
328 EXPECT_EQ(0, dispatcher_host_
->num_failed_connections());
329 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
331 ASSERT_EQ(8U, mock_hosts_
.size());
332 EXPECT_LT(base::TimeDelta(), mock_hosts_
[7]->delay_
);
// With seventeen pending connections and no failures, the seventeenth host's
// delay must lie between 1000 ms and 5000 ms inclusive.
335 TEST_F(WebSocketDispatcherHostTest
, DelayFor17thPendingConnection
) {
336 ASSERT_TRUE(AddMultipleChannels(17));
338 EXPECT_EQ(17, dispatcher_host_
->num_pending_connections());
339 EXPECT_EQ(0, dispatcher_host_
->num_failed_connections());
340 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
342 ASSERT_EQ(17U, mock_hosts_
.size());
343 EXPECT_LE(base::TimeDelta::FromMilliseconds(1000), mock_hosts_
[16]->delay_
);
344 EXPECT_GE(base::TimeDelta::FromMilliseconds(5000), mock_hosts_
[16]->delay_
);
347 // The 256th connection is rejected by per-renderer WebSocket throttling.
348 // This is not counted as a failure.
// After 256 add requests, only 255 connections are pending and only 255 hosts
// were created, while the helper itself still reports success.
349 TEST_F(WebSocketDispatcherHostTest
, Rejects256thPendingConnection
) {
350 ASSERT_TRUE(AddMultipleChannels(256));
352 EXPECT_EQ(255, dispatcher_host_
->num_pending_connections());
353 EXPECT_EQ(0, dispatcher_host_
->num_failed_connections());
354 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
356 ASSERT_EQ(255U, mock_hosts_
.size());
// After three add-then-drop cycles are counted as failures, the next
// connection (fourth host overall) must still get a zero delay.
359 TEST_F(WebSocketDispatcherHostTest
, DelayIsZeroAfter3FailedConnections
) {
360 ASSERT_TRUE(AddAndCancelMultipleChannels(3));
362 EXPECT_EQ(0, dispatcher_host_
->num_pending_connections());
363 EXPECT_EQ(3, dispatcher_host_
->num_failed_connections());
364 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
366 ASSERT_TRUE(AddMultipleChannels(1));
368 ASSERT_EQ(4U, mock_hosts_
.size());
369 EXPECT_EQ(base::TimeDelta(), mock_hosts_
[3]->delay_
);
// After seven add-then-drop cycles are counted as failures, the next
// connection (eighth host overall) must get a delay greater than zero.
372 TEST_F(WebSocketDispatcherHostTest
, DelayIsNonZeroAfter7FailedConnections
) {
373 ASSERT_TRUE(AddAndCancelMultipleChannels(7));
375 EXPECT_EQ(0, dispatcher_host_
->num_pending_connections());
376 EXPECT_EQ(7, dispatcher_host_
->num_failed_connections());
377 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
379 ASSERT_TRUE(AddMultipleChannels(1));
381 ASSERT_EQ(8U, mock_hosts_
.size());
382 EXPECT_LT(base::TimeDelta(), mock_hosts_
[7]->delay_
);
// After sixteen add-then-drop cycles are counted as failures, the next
// connection (seventeenth host overall) must get a delay between 1000 ms and
// 5000 ms inclusive.
385 TEST_F(WebSocketDispatcherHostTest
, DelayAfter16FailedConnections
) {
386 ASSERT_TRUE(AddAndCancelMultipleChannels(16));
388 EXPECT_EQ(0, dispatcher_host_
->num_pending_connections());
389 EXPECT_EQ(16, dispatcher_host_
->num_failed_connections());
390 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
392 ASSERT_TRUE(AddMultipleChannels(1));
394 ASSERT_EQ(17U, mock_hosts_
.size());
395 EXPECT_LE(base::TimeDelta::FromMilliseconds(1000), mock_hosts_
[16]->delay_
);
396 EXPECT_GE(base::TimeDelta::FromMilliseconds(5000), mock_hosts_
[16]->delay_
);
// 255 failed connections must not cause the next connection to be rejected:
// after one more add request there is one pending connection and the failure
// count is unchanged.
399 TEST_F(WebSocketDispatcherHostTest
, NotRejectedAfter255FailedConnections
) {
400 ASSERT_TRUE(AddAndCancelMultipleChannels(255));
402 EXPECT_EQ(0, dispatcher_host_
->num_pending_connections());
403 EXPECT_EQ(255, dispatcher_host_
->num_failed_connections());
404 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
406 ASSERT_TRUE(AddMultipleChannels(1));
408 EXPECT_EQ(1, dispatcher_host_
->num_pending_connections());
409 EXPECT_EQ(255, dispatcher_host_
->num_failed_connections());
410 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
413 // This is a regression test for https://crrev.com/998173003/.
// Requests a channel whose socket URL uses the "http" scheme, then feeds the
// same message to the real WebSocketHost handler and expects the connection
// to end up counted as failed (0 pending, 1 failed, 0 succeeded).
414 TEST_F(WebSocketDispatcherHostTest
, InvalidScheme
) {
415 int routing_id
= 123;
416 GURL
socket_url("http://example.com/test");
417 std::vector
<std::string
> requested_protocols
;
418 requested_protocols
.push_back("hello");
419 url::Origin
origin(GURL("http://example.com"));
420 int render_frame_id
= -2;
421 WebSocketHostMsg_AddChannelRequest
message(
422 routing_id
, socket_url
, requested_protocols
, origin
, render_frame_id
);
424 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message
));
426 ASSERT_EQ(1U, mock_hosts_
.size());
427 MockWebSocketHost
* host
= mock_hosts_
[0];
429 // Tests that WebSocketHost::OnMessageReceived() doesn't cause a crash and
430 // the connection with an invalid scheme fails here.
431 // We call WebSocketHost::OnMessageReceived() here explicitly because
432 // MockWebSocketHost does not call WebSocketHost::OnMessageReceived() for
433 // WebSocketHostMsg_AddChannelRequest.
434 host
->WebSocketHost::OnMessageReceived(message
);
436 EXPECT_EQ(0, dispatcher_host_
->num_pending_connections());
437 EXPECT_EQ(1, dispatcher_host_
->num_failed_connections());
438 EXPECT_EQ(0, dispatcher_host_
->num_succeeded_connections());
442 } // namespace content