1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "content/browser/renderer_host/websocket_dispatcher_host.h"
10 #include "base/bind.h"
11 #include "base/bind_helpers.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/weak_ptr.h"
14 #include "content/browser/renderer_host/websocket_host.h"
15 #include "content/common/websocket.h"
16 #include "content/common/websocket_messages.h"
17 #include "ipc/ipc_message.h"
18 #include "testing/gtest/include/gtest/gtest.h"
20 #include "url/origin.h"
// Render process id handed to the dispatcher under test. The value is
// deliberately unusual so that it is unlikely to occur by chance.
static const int kMagicRenderProcessId = 506116062;

// Declared early because MockWebSocketHost refers to the fixture type before
// the fixture itself is defined.
class WebSocketDispatcherHostTest;
30 // A mock of WebsocketHost which records received messages.
31 class MockWebSocketHost
: public WebSocketHost
{
33 MockWebSocketHost(int routing_id
,
34 WebSocketDispatcherHost
* dispatcher
,
35 net::URLRequestContext
* url_request_context
,
36 WebSocketDispatcherHostTest
* owner
);
38 virtual ~MockWebSocketHost() {}
40 virtual bool OnMessageReceived(const IPC::Message
& message
) OVERRIDE
{
41 received_messages_
.push_back(message
);
45 virtual void GoAway() OVERRIDE
;
47 std::vector
<IPC::Message
> received_messages_
;
48 base::WeakPtr
<WebSocketDispatcherHostTest
> owner_
;
51 class WebSocketDispatcherHostTest
: public ::testing::Test
{
53 WebSocketDispatcherHostTest()
54 : weak_ptr_factory_(this) {
55 dispatcher_host_
= new WebSocketDispatcherHost(
56 kMagicRenderProcessId
,
57 base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext
,
58 base::Unretained(this)),
59 base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost
,
60 base::Unretained(this)));
63 virtual ~WebSocketDispatcherHostTest() {
64 // We need to invalidate the issued WeakPtrs at the beginning of the
65 // destructor in order not to access destructed member variables.
66 weak_ptr_factory_
.InvalidateWeakPtrs();
69 void GoAway(int routing_id
) {
70 gone_hosts_
.push_back(routing_id
);
73 base::WeakPtr
<WebSocketDispatcherHostTest
> GetWeakPtr() {
74 return weak_ptr_factory_
.GetWeakPtr();
78 scoped_refptr
<WebSocketDispatcherHost
> dispatcher_host_
;
80 // Stores allocated MockWebSocketHost instances. Doesn't take ownership of
82 std::vector
<MockWebSocketHost
*> mock_hosts_
;
83 std::vector
<int> gone_hosts_
;
85 base::WeakPtrFactory
<WebSocketDispatcherHostTest
> weak_ptr_factory_
;
88 net::URLRequestContext
* OnGetRequestContext() {
92 WebSocketHost
* CreateWebSocketHost(int routing_id
) {
93 MockWebSocketHost
* host
=
94 new MockWebSocketHost(routing_id
, dispatcher_host_
.get(), NULL
, this);
95 mock_hosts_
.push_back(host
);
100 MockWebSocketHost::MockWebSocketHost(
102 WebSocketDispatcherHost
* dispatcher
,
103 net::URLRequestContext
* url_request_context
,
104 WebSocketDispatcherHostTest
* owner
)
105 : WebSocketHost(routing_id
, dispatcher
, url_request_context
),
106 owner_(owner
->GetWeakPtr()) {}
108 void MockWebSocketHost::GoAway() {
110 owner_
->GoAway(routing_id());
113 TEST_F(WebSocketDispatcherHostTest
, Construct
) {
117 TEST_F(WebSocketDispatcherHostTest
, UnrelatedMessage
) {
118 IPC::Message message
;
119 EXPECT_FALSE(dispatcher_host_
->OnMessageReceived(message
));
122 TEST_F(WebSocketDispatcherHostTest
, RenderProcessIdGetter
) {
123 EXPECT_EQ(kMagicRenderProcessId
, dispatcher_host_
->render_process_id());
126 TEST_F(WebSocketDispatcherHostTest
, AddChannelRequest
) {
127 int routing_id
= 123;
128 GURL
socket_url("ws://example.com/test");
129 std::vector
<std::string
> requested_protocols
;
130 requested_protocols
.push_back("hello");
131 url::Origin
origin("http://example.com/test");
132 int render_frame_id
= -2;
133 WebSocketHostMsg_AddChannelRequest
message(
134 routing_id
, socket_url
, requested_protocols
, origin
, render_frame_id
);
136 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message
));
138 ASSERT_EQ(1U, mock_hosts_
.size());
139 MockWebSocketHost
* host
= mock_hosts_
[0];
141 ASSERT_EQ(1U, host
->received_messages_
.size());
142 const IPC::Message
& forwarded_message
= host
->received_messages_
[0];
143 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID
, forwarded_message
.type());
144 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
147 TEST_F(WebSocketDispatcherHostTest
, SendFrameButNoHostYet
) {
148 int routing_id
= 123;
149 std::vector
<char> data
;
150 WebSocketMsg_SendFrame
message(
151 routing_id
, true, WEB_SOCKET_MESSAGE_TYPE_TEXT
, data
);
153 // Expected to be ignored.
154 EXPECT_TRUE(dispatcher_host_
->OnMessageReceived(message
));
156 EXPECT_EQ(0U, mock_hosts_
.size());
159 TEST_F(WebSocketDispatcherHostTest
, SendFrame
) {
160 int routing_id
= 123;
162 GURL
socket_url("ws://example.com/test");
163 std::vector
<std::string
> requested_protocols
;
164 requested_protocols
.push_back("hello");
165 url::Origin
origin("http://example.com/test");
166 int render_frame_id
= -2;
167 WebSocketHostMsg_AddChannelRequest
add_channel_message(
168 routing_id
, socket_url
, requested_protocols
, origin
, render_frame_id
);
170 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(add_channel_message
));
172 std::vector
<char> data
;
173 WebSocketMsg_SendFrame
send_frame_message(
174 routing_id
, true, WEB_SOCKET_MESSAGE_TYPE_TEXT
, data
);
176 EXPECT_TRUE(dispatcher_host_
->OnMessageReceived(send_frame_message
));
178 ASSERT_EQ(1U, mock_hosts_
.size());
179 MockWebSocketHost
* host
= mock_hosts_
[0];
181 ASSERT_EQ(2U, host
->received_messages_
.size());
183 const IPC::Message
& forwarded_message
= host
->received_messages_
[0];
184 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID
, forwarded_message
.type());
185 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
188 const IPC::Message
& forwarded_message
= host
->received_messages_
[1];
189 EXPECT_EQ(WebSocketMsg_SendFrame::ID
, forwarded_message
.type());
190 EXPECT_EQ(routing_id
, forwarded_message
.routing_id());
194 TEST_F(WebSocketDispatcherHostTest
, Destruct
) {
195 WebSocketHostMsg_AddChannelRequest
message1(
196 123, GURL("ws://example.com/test"), std::vector
<std::string
>(),
197 url::Origin("http://example.com"), -1);
198 WebSocketHostMsg_AddChannelRequest
message2(
199 456, GURL("ws://example.com/test2"), std::vector
<std::string
>(),
200 url::Origin("http://example.com"), -1);
202 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message1
));
203 ASSERT_TRUE(dispatcher_host_
->OnMessageReceived(message2
));
205 ASSERT_EQ(2u, mock_hosts_
.size());
208 dispatcher_host_
= NULL
;
210 ASSERT_EQ(2u, gone_hosts_
.size());
211 // The gone_hosts_ ordering is not predictable because it depends on the
212 // hash_map ordering.
213 std::sort(gone_hosts_
.begin(), gone_hosts_
.end());
214 EXPECT_EQ(123, gone_hosts_
[0]);
215 EXPECT_EQ(456, gone_hosts_
[1]);
219 } // namespace content