1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include "content/browser/renderer_host/websocket_host.h"

#include "base/basictypes.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "url/origin.h"
26 typedef net::WebSocketEventInterface::ChannelState ChannelState
;
28 // Convert a content::WebSocketMessageType to a
29 // net::WebSocketFrameHeader::OpCode
30 net::WebSocketFrameHeader::OpCode
MessageTypeToOpCode(
31 WebSocketMessageType type
) {
32 DCHECK(type
== WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
||
33 type
== WEB_SOCKET_MESSAGE_TYPE_TEXT
||
34 type
== WEB_SOCKET_MESSAGE_TYPE_BINARY
);
35 typedef net::WebSocketFrameHeader::OpCode OpCode
;
36 // These compile asserts verify that the same underlying values are used for
37 // both types, so we can simply cast between them.
38 COMPILE_ASSERT(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
) ==
39 net::WebSocketFrameHeader::kOpCodeContinuation
,
40 enum_values_must_match_for_opcode_continuation
);
41 COMPILE_ASSERT(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_TEXT
) ==
42 net::WebSocketFrameHeader::kOpCodeText
,
43 enum_values_must_match_for_opcode_text
);
44 COMPILE_ASSERT(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_BINARY
) ==
45 net::WebSocketFrameHeader::kOpCodeBinary
,
46 enum_values_must_match_for_opcode_binary
);
47 return static_cast<OpCode
>(type
);
50 WebSocketMessageType
OpCodeToMessageType(
51 net::WebSocketFrameHeader::OpCode opCode
) {
52 DCHECK(opCode
== net::WebSocketFrameHeader::kOpCodeContinuation
||
53 opCode
== net::WebSocketFrameHeader::kOpCodeText
||
54 opCode
== net::WebSocketFrameHeader::kOpCodeBinary
);
55 // This cast is guaranteed valid by the COMPILE_ASSERT() statements above.
56 return static_cast<WebSocketMessageType
>(opCode
);
59 ChannelState
StateCast(WebSocketDispatcherHost::WebSocketHostState host_state
) {
60 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE
=
61 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE
;
62 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED
=
63 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED
;
65 DCHECK(host_state
== WEBSOCKET_HOST_ALIVE
||
66 host_state
== WEBSOCKET_HOST_DELETED
);
67 // These compile asserts verify that we can get away with using static_cast<>
68 // for the conversion.
69 COMPILE_ASSERT(static_cast<ChannelState
>(WEBSOCKET_HOST_ALIVE
) ==
70 net::WebSocketEventInterface::CHANNEL_ALIVE
,
71 enum_values_must_match_for_state_alive
);
72 COMPILE_ASSERT(static_cast<ChannelState
>(WEBSOCKET_HOST_DELETED
) ==
73 net::WebSocketEventInterface::CHANNEL_DELETED
,
74 enum_values_must_match_for_state_deleted
);
75 return static_cast<ChannelState
>(host_state
);
78 // Implementation of net::WebSocketEventInterface. Receives events from our
79 // WebSocketChannel object. Each event is translated to an IPC and sent to the
80 // renderer or child process via WebSocketDispatcherHost.
81 class WebSocketEventHandler
: public net::WebSocketEventInterface
{
83 WebSocketEventHandler(WebSocketDispatcherHost
* dispatcher
, int routing_id
);
84 virtual ~WebSocketEventHandler();
86 // net::WebSocketEventInterface implementation
88 virtual ChannelState
OnAddChannelResponse(
90 const std::string
& selected_subprotocol
,
91 const std::string
& extensions
) OVERRIDE
;
92 virtual ChannelState
OnDataFrame(bool fin
,
93 WebSocketMessageType type
,
94 const std::vector
<char>& data
) OVERRIDE
;
95 virtual ChannelState
OnClosingHandshake() OVERRIDE
;
96 virtual ChannelState
OnFlowControl(int64 quota
) OVERRIDE
;
97 virtual ChannelState
OnDropChannel(bool was_clean
,
99 const std::string
& reason
) OVERRIDE
;
100 virtual ChannelState
OnFailChannel(const std::string
& message
) OVERRIDE
;
101 virtual ChannelState
OnStartOpeningHandshake(
102 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) OVERRIDE
;
103 virtual ChannelState
OnFinishOpeningHandshake(
104 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) OVERRIDE
;
107 WebSocketDispatcherHost
* const dispatcher_
;
108 const int routing_id_
;
110 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler
);
113 WebSocketEventHandler::WebSocketEventHandler(
114 WebSocketDispatcherHost
* dispatcher
,
116 : dispatcher_(dispatcher
), routing_id_(routing_id
) {}
118 WebSocketEventHandler::~WebSocketEventHandler() {
119 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_
;
122 ChannelState
WebSocketEventHandler::OnAddChannelResponse(
124 const std::string
& selected_protocol
,
125 const std::string
& extensions
) {
126 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
127 << " routing_id=" << routing_id_
<< " fail=" << fail
128 << " selected_protocol=\"" << selected_protocol
<< "\""
129 << " extensions=\"" << extensions
<< "\"";
131 return StateCast(dispatcher_
->SendAddChannelResponse(
132 routing_id_
, fail
, selected_protocol
, extensions
));
135 ChannelState
WebSocketEventHandler::OnDataFrame(
137 net::WebSocketFrameHeader::OpCode type
,
138 const std::vector
<char>& data
) {
139 DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
140 << " routing_id=" << routing_id_
<< " fin=" << fin
141 << " type=" << type
<< " data is " << data
.size() << " bytes";
143 return StateCast(dispatcher_
->SendFrame(
144 routing_id_
, fin
, OpCodeToMessageType(type
), data
));
147 ChannelState
WebSocketEventHandler::OnClosingHandshake() {
148 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
149 << " routing_id=" << routing_id_
;
151 return StateCast(dispatcher_
->NotifyClosingHandshake(routing_id_
));
154 ChannelState
WebSocketEventHandler::OnFlowControl(int64 quota
) {
155 DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
156 << " routing_id=" << routing_id_
<< " quota=" << quota
;
158 return StateCast(dispatcher_
->SendFlowControl(routing_id_
, quota
));
161 ChannelState
WebSocketEventHandler::OnDropChannel(bool was_clean
,
163 const std::string
& reason
) {
164 DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
165 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
166 << " code=" << code
<< " reason=\"" << reason
<< "\"";
169 dispatcher_
->DoDropChannel(routing_id_
, was_clean
, code
, reason
));
172 ChannelState
WebSocketEventHandler::OnFailChannel(const std::string
& message
) {
173 DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
174 << " routing_id=" << routing_id_
175 << " message=\"" << message
<< "\"";
177 return StateCast(dispatcher_
->NotifyFailure(routing_id_
, message
));
180 ChannelState
WebSocketEventHandler::OnStartOpeningHandshake(
181 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) {
182 bool should_send
= dispatcher_
->CanReadRawCookies();
183 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
184 << "should_send=" << should_send
;
187 return WebSocketEventInterface::CHANNEL_ALIVE
;
189 WebSocketHandshakeRequest request_to_pass
;
190 request_to_pass
.url
.Swap(&request
->url
);
191 net::HttpRequestHeaders::Iterator
it(request
->headers
);
193 request_to_pass
.headers
.push_back(std::make_pair(it
.name(), it
.value()));
194 request_to_pass
.headers_text
=
195 base::StringPrintf("GET %s HTTP/1.1\r\n",
196 request_to_pass
.url
.spec().c_str()) +
197 request
->headers
.ToString();
198 request_to_pass
.request_time
= request
->request_time
;
200 return StateCast(dispatcher_
->NotifyStartOpeningHandshake(routing_id_
,
204 ChannelState
WebSocketEventHandler::OnFinishOpeningHandshake(
205 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) {
206 bool should_send
= dispatcher_
->CanReadRawCookies();
207 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
208 << "should_send=" << should_send
;
211 return WebSocketEventInterface::CHANNEL_ALIVE
;
213 WebSocketHandshakeResponse response_to_pass
;
214 response_to_pass
.url
.Swap(&response
->url
);
215 response_to_pass
.status_code
= response
->status_code
;
216 response_to_pass
.status_text
.swap(response
->status_text
);
218 std::string name
, value
;
219 while (response
->headers
->EnumerateHeaderLines(&iter
, &name
, &value
))
220 response_to_pass
.headers
.push_back(std::make_pair(name
, value
));
221 response_to_pass
.headers_text
=
222 net::HttpUtil::ConvertHeadersBackToHTTPResponse(
223 response
->headers
->raw_headers());
224 response_to_pass
.response_time
= response
->response_time
;
226 return StateCast(dispatcher_
->NotifyFinishOpeningHandshake(routing_id_
,
232 WebSocketHost::WebSocketHost(int routing_id
,
233 WebSocketDispatcherHost
* dispatcher
,
234 net::URLRequestContext
* url_request_context
)
235 : routing_id_(routing_id
) {
236 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id
;
238 scoped_ptr
<net::WebSocketEventInterface
> event_interface(
239 new WebSocketEventHandler(dispatcher
, routing_id
));
241 new net::WebSocketChannel(event_interface
.Pass(), url_request_context
));
244 WebSocketHost::~WebSocketHost() {}
246 bool WebSocketHost::OnMessageReceived(const IPC::Message
& message
,
247 bool* message_was_ok
) {
249 IPC_BEGIN_MESSAGE_MAP_EX(WebSocketHost
, message
, *message_was_ok
)
250 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest
, OnAddChannelRequest
)
251 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame
, OnSendFrame
)
252 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl
, OnFlowControl
)
253 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel
, OnDropChannel
)
254 IPC_MESSAGE_UNHANDLED(handled
= false)
255 IPC_END_MESSAGE_MAP_EX()
259 void WebSocketHost::OnAddChannelRequest(
260 const GURL
& socket_url
,
261 const std::vector
<std::string
>& requested_protocols
,
262 const url::Origin
& origin
) {
263 DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
264 << " routing_id=" << routing_id_
<< " socket_url=\"" << socket_url
265 << "\" requested_protocols=\""
266 << JoinString(requested_protocols
, ", ") << "\" origin=\""
267 << origin
.string() << "\"";
269 channel_
->SendAddChannelRequest(
270 socket_url
, requested_protocols
, origin
);
273 void WebSocketHost::OnSendFrame(bool fin
,
274 WebSocketMessageType type
,
275 const std::vector
<char>& data
) {
276 DVLOG(3) << "WebSocketHost::OnSendFrame"
277 << " routing_id=" << routing_id_
<< " fin=" << fin
278 << " type=" << type
<< " data is " << data
.size() << " bytes";
280 channel_
->SendFrame(fin
, MessageTypeToOpCode(type
), data
);
283 void WebSocketHost::OnFlowControl(int64 quota
) {
284 DVLOG(3) << "WebSocketHost::OnFlowControl"
285 << " routing_id=" << routing_id_
<< " quota=" << quota
;
287 channel_
->SendFlowControl(quota
);
290 void WebSocketHost::OnDropChannel(bool was_clean
,
292 const std::string
& reason
) {
293 DVLOG(3) << "WebSocketHost::OnDropChannel"
294 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
295 << " code=" << code
<< " reason=\"" << reason
<< "\"";
297 // TODO(yhirano): Handle |was_clean| appropriately.
298 channel_
->StartClosingHandshake(code
, reason
);
301 } // namespace content