1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include "content/browser/renderer_host/websocket_host.h"

#include <string>
#include <utility>
#include <vector>

#include "base/basictypes.h"
#include "base/bind.h"
#include "base/location.h"
#include "base/memory/weak_ptr.h"
#include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/thread_task_runner_handle.h"
#include "content/browser/renderer_host/websocket_dispatcher_host.h"
#include "content/browser/ssl/ssl_error_handler.h"
#include "content/browser/ssl/ssl_manager.h"
#include "content/common/websocket_messages.h"
#include "ipc/ipc_message_macros.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/ssl/ssl_info.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "url/origin.h"
34 typedef net::WebSocketEventInterface::ChannelState ChannelState
;
36 // Convert a content::WebSocketMessageType to a
37 // net::WebSocketFrameHeader::OpCode
38 net::WebSocketFrameHeader::OpCode
MessageTypeToOpCode(
39 WebSocketMessageType type
) {
40 DCHECK(type
== WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
||
41 type
== WEB_SOCKET_MESSAGE_TYPE_TEXT
||
42 type
== WEB_SOCKET_MESSAGE_TYPE_BINARY
);
43 typedef net::WebSocketFrameHeader::OpCode OpCode
;
44 // These compile asserts verify that the same underlying values are used for
45 // both types, so we can simply cast between them.
46 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
) ==
47 net::WebSocketFrameHeader::kOpCodeContinuation
,
48 "enum values must match for opcode continuation");
49 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_TEXT
) ==
50 net::WebSocketFrameHeader::kOpCodeText
,
51 "enum values must match for opcode text");
52 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_BINARY
) ==
53 net::WebSocketFrameHeader::kOpCodeBinary
,
54 "enum values must match for opcode binary");
55 return static_cast<OpCode
>(type
);
58 WebSocketMessageType
OpCodeToMessageType(
59 net::WebSocketFrameHeader::OpCode opCode
) {
60 DCHECK(opCode
== net::WebSocketFrameHeader::kOpCodeContinuation
||
61 opCode
== net::WebSocketFrameHeader::kOpCodeText
||
62 opCode
== net::WebSocketFrameHeader::kOpCodeBinary
);
63 // This cast is guaranteed valid by the static_assert() statements above.
64 return static_cast<WebSocketMessageType
>(opCode
);
67 ChannelState
StateCast(WebSocketDispatcherHost::WebSocketHostState host_state
) {
68 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE
=
69 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE
;
70 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED
=
71 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED
;
73 DCHECK(host_state
== WEBSOCKET_HOST_ALIVE
||
74 host_state
== WEBSOCKET_HOST_DELETED
);
75 // These compile asserts verify that we can get away with using static_cast<>
76 // for the conversion.
77 static_assert(static_cast<ChannelState
>(WEBSOCKET_HOST_ALIVE
) ==
78 net::WebSocketEventInterface::CHANNEL_ALIVE
,
79 "enum values must match for state_alive");
80 static_assert(static_cast<ChannelState
>(WEBSOCKET_HOST_DELETED
) ==
81 net::WebSocketEventInterface::CHANNEL_DELETED
,
82 "enum values must match for state_deleted");
83 return static_cast<ChannelState
>(host_state
);
86 // Implementation of net::WebSocketEventInterface. Receives events from our
87 // WebSocketChannel object. Each event is translated to an IPC and sent to the
88 // renderer or child process via WebSocketDispatcherHost.
89 class WebSocketEventHandler
: public net::WebSocketEventInterface
{
91 WebSocketEventHandler(WebSocketDispatcherHost
* dispatcher
,
94 ~WebSocketEventHandler() override
;
96 // net::WebSocketEventInterface implementation
98 ChannelState
OnAddChannelResponse(const std::string
& selected_subprotocol
,
99 const std::string
& extensions
) override
;
100 ChannelState
OnDataFrame(bool fin
,
101 WebSocketMessageType type
,
102 const std::vector
<char>& data
) override
;
103 ChannelState
OnClosingHandshake() override
;
104 ChannelState
OnFlowControl(int64 quota
) override
;
105 ChannelState
OnDropChannel(bool was_clean
,
107 const std::string
& reason
) override
;
108 ChannelState
OnFailChannel(const std::string
& message
) override
;
109 ChannelState
OnStartOpeningHandshake(
110 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) override
;
111 ChannelState
OnFinishOpeningHandshake(
112 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) override
;
113 ChannelState
OnSSLCertificateError(
114 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
116 const net::SSLInfo
& ssl_info
,
117 bool fatal
) override
;
120 class SSLErrorHandlerDelegate
: public SSLErrorHandler::Delegate
{
122 SSLErrorHandlerDelegate(
123 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
);
124 ~SSLErrorHandlerDelegate() override
;
126 base::WeakPtr
<SSLErrorHandler::Delegate
> GetWeakPtr();
128 // SSLErrorHandler::Delegate methods
129 void CancelSSLRequest(int error
, const net::SSLInfo
* ssl_info
) override
;
130 void ContinueSSLRequest() override
;
133 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks_
;
134 base::WeakPtrFactory
<SSLErrorHandlerDelegate
> weak_ptr_factory_
;
136 DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate
);
139 WebSocketDispatcherHost
* const dispatcher_
;
140 const int routing_id_
;
141 const int render_frame_id_
;
142 scoped_ptr
<SSLErrorHandlerDelegate
> ssl_error_handler_delegate_
;
144 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler
);
147 WebSocketEventHandler::WebSocketEventHandler(
148 WebSocketDispatcherHost
* dispatcher
,
151 : dispatcher_(dispatcher
),
152 routing_id_(routing_id
),
153 render_frame_id_(render_frame_id
) {
156 WebSocketEventHandler::~WebSocketEventHandler() {
157 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_
;
160 ChannelState
WebSocketEventHandler::OnAddChannelResponse(
161 const std::string
& selected_protocol
,
162 const std::string
& extensions
) {
163 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
164 << " routing_id=" << routing_id_
165 << " selected_protocol=\"" << selected_protocol
<< "\""
166 << " extensions=\"" << extensions
<< "\"";
168 return StateCast(dispatcher_
->SendAddChannelResponse(
169 routing_id_
, selected_protocol
, extensions
));
172 ChannelState
WebSocketEventHandler::OnDataFrame(
174 net::WebSocketFrameHeader::OpCode type
,
175 const std::vector
<char>& data
) {
176 DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
177 << " routing_id=" << routing_id_
<< " fin=" << fin
178 << " type=" << type
<< " data is " << data
.size() << " bytes";
180 return StateCast(dispatcher_
->SendFrame(
181 routing_id_
, fin
, OpCodeToMessageType(type
), data
));
184 ChannelState
WebSocketEventHandler::OnClosingHandshake() {
185 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
186 << " routing_id=" << routing_id_
;
188 return StateCast(dispatcher_
->NotifyClosingHandshake(routing_id_
));
191 ChannelState
WebSocketEventHandler::OnFlowControl(int64 quota
) {
192 DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
193 << " routing_id=" << routing_id_
<< " quota=" << quota
;
195 return StateCast(dispatcher_
->SendFlowControl(routing_id_
, quota
));
198 ChannelState
WebSocketEventHandler::OnDropChannel(bool was_clean
,
200 const std::string
& reason
) {
201 DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
202 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
203 << " code=" << code
<< " reason=\"" << reason
<< "\"";
206 dispatcher_
->DoDropChannel(routing_id_
, was_clean
, code
, reason
));
209 ChannelState
WebSocketEventHandler::OnFailChannel(const std::string
& message
) {
210 DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
211 << " routing_id=" << routing_id_
212 << " message=\"" << message
<< "\"";
214 return StateCast(dispatcher_
->NotifyFailure(routing_id_
, message
));
217 ChannelState
WebSocketEventHandler::OnStartOpeningHandshake(
218 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) {
219 bool should_send
= dispatcher_
->CanReadRawCookies();
220 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
221 << "should_send=" << should_send
;
224 return WebSocketEventInterface::CHANNEL_ALIVE
;
226 WebSocketHandshakeRequest request_to_pass
;
227 request_to_pass
.url
.Swap(&request
->url
);
228 net::HttpRequestHeaders::Iterator
it(request
->headers
);
230 request_to_pass
.headers
.push_back(std::make_pair(it
.name(), it
.value()));
231 request_to_pass
.headers_text
=
232 base::StringPrintf("GET %s HTTP/1.1\r\n",
233 request_to_pass
.url
.spec().c_str()) +
234 request
->headers
.ToString();
235 request_to_pass
.request_time
= request
->request_time
;
237 return StateCast(dispatcher_
->NotifyStartOpeningHandshake(routing_id_
,
241 ChannelState
WebSocketEventHandler::OnFinishOpeningHandshake(
242 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) {
243 bool should_send
= dispatcher_
->CanReadRawCookies();
244 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
245 << "should_send=" << should_send
;
248 return WebSocketEventInterface::CHANNEL_ALIVE
;
250 WebSocketHandshakeResponse response_to_pass
;
251 response_to_pass
.url
.Swap(&response
->url
);
252 response_to_pass
.status_code
= response
->status_code
;
253 response_to_pass
.status_text
.swap(response
->status_text
);
255 std::string name
, value
;
256 while (response
->headers
->EnumerateHeaderLines(&iter
, &name
, &value
))
257 response_to_pass
.headers
.push_back(std::make_pair(name
, value
));
258 response_to_pass
.headers_text
=
259 net::HttpUtil::ConvertHeadersBackToHTTPResponse(
260 response
->headers
->raw_headers());
261 response_to_pass
.response_time
= response
->response_time
;
263 return StateCast(dispatcher_
->NotifyFinishOpeningHandshake(routing_id_
,
267 ChannelState
WebSocketEventHandler::OnSSLCertificateError(
268 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
270 const net::SSLInfo
& ssl_info
,
272 DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
273 << " routing_id=" << routing_id_
<< " url=" << url
.spec()
274 << " cert_status=" << ssl_info
.cert_status
<< " fatal=" << fatal
;
275 ssl_error_handler_delegate_
.reset(
276 new SSLErrorHandlerDelegate(callbacks
.Pass()));
277 SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_
->GetWeakPtr(),
278 RESOURCE_TYPE_SUB_RESOURCE
,
280 dispatcher_
->render_process_id(),
284 // The above method is always asynchronous.
285 return WebSocketEventInterface::CHANNEL_ALIVE
;
288 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
289 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
)
290 : callbacks_(callbacks
.Pass()), weak_ptr_factory_(this) {}
292 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
294 base::WeakPtr
<SSLErrorHandler::Delegate
>
295 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
296 return weak_ptr_factory_
.GetWeakPtr();
299 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
301 const net::SSLInfo
* ssl_info
) {
302 DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
303 << " error=" << error
304 << " cert_status=" << (ssl_info
? ssl_info
->cert_status
305 : static_cast<net::CertStatus
>(-1));
306 callbacks_
->CancelSSLRequest(error
, ssl_info
);
309 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest() {
310 DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
311 callbacks_
->ContinueSSLRequest();
316 WebSocketHost::WebSocketHost(int routing_id
,
317 WebSocketDispatcherHost
* dispatcher
,
318 net::URLRequestContext
* url_request_context
,
319 base::TimeDelta delay
)
320 : dispatcher_(dispatcher
),
321 url_request_context_(url_request_context
),
322 routing_id_(routing_id
),
324 pending_flow_control_quota_(0),
325 handshake_succeeded_(false),
326 weak_ptr_factory_(this) {
327 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id
;
330 WebSocketHost::~WebSocketHost() {}
332 void WebSocketHost::GoAway() {
333 OnDropChannel(false, static_cast<uint16
>(net::kWebSocketErrorGoingAway
), "");
336 bool WebSocketHost::OnMessageReceived(const IPC::Message
& message
) {
338 IPC_BEGIN_MESSAGE_MAP(WebSocketHost
, message
)
339 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest
, OnAddChannelRequest
)
340 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame
, OnSendFrame
)
341 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl
, OnFlowControl
)
342 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel
, OnDropChannel
)
343 IPC_MESSAGE_UNHANDLED(handled
= false)
344 IPC_END_MESSAGE_MAP()
348 void WebSocketHost::OnAddChannelRequest(
349 const GURL
& socket_url
,
350 const std::vector
<std::string
>& requested_protocols
,
351 const url::Origin
& origin
,
352 int render_frame_id
) {
353 DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
354 << " routing_id=" << routing_id_
<< " socket_url=\"" << socket_url
355 << "\" requested_protocols=\""
356 << JoinString(requested_protocols
, ", ") << "\" origin=\""
357 << origin
.string() << "\"";
360 if (delay_
> base::TimeDelta()) {
361 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
363 base::Bind(&WebSocketHost::AddChannel
, weak_ptr_factory_
.GetWeakPtr(),
364 socket_url
, requested_protocols
, origin
, render_frame_id
),
367 AddChannel(socket_url
, requested_protocols
, origin
, render_frame_id
);
369 // |this| may have been deleted here.
372 void WebSocketHost::AddChannel(
373 const GURL
& socket_url
,
374 const std::vector
<std::string
>& requested_protocols
,
375 const url::Origin
& origin
,
376 int render_frame_id
) {
377 DVLOG(3) << "WebSocketHost::AddChannel"
378 << " routing_id=" << routing_id_
<< " socket_url=\"" << socket_url
379 << "\" requested_protocols=\""
380 << JoinString(requested_protocols
, ", ") << "\" origin=\""
381 << origin
.string() << "\"";
385 scoped_ptr
<net::WebSocketEventInterface
> event_interface(
386 new WebSocketEventHandler(dispatcher_
, routing_id_
, render_frame_id
));
388 new net::WebSocketChannel(event_interface
.Pass(), url_request_context_
));
390 if (pending_flow_control_quota_
> 0) {
391 // channel_->SendFlowControl(pending_flow_control_quota_) must be called
392 // after channel_->SendAddChannelRequest() below.
393 // We post OnFlowControl() here using |weak_ptr_factory_| instead of
394 // calling SendFlowControl directly, because |this| may have been deleted
395 // after channel_->SendAddChannelRequest().
396 base::ThreadTaskRunnerHandle::Get()->PostTask(
397 FROM_HERE
, base::Bind(&WebSocketHost::OnFlowControl
,
398 weak_ptr_factory_
.GetWeakPtr(),
399 pending_flow_control_quota_
));
400 pending_flow_control_quota_
= 0;
403 channel_
->SendAddChannelRequest(socket_url
, requested_protocols
, origin
);
404 // |this| may have been deleted here.
407 void WebSocketHost::OnSendFrame(bool fin
,
408 WebSocketMessageType type
,
409 const std::vector
<char>& data
) {
410 DVLOG(3) << "WebSocketHost::OnSendFrame"
411 << " routing_id=" << routing_id_
<< " fin=" << fin
412 << " type=" << type
<< " data is " << data
.size() << " bytes";
415 channel_
->SendFrame(fin
, MessageTypeToOpCode(type
), data
);
418 void WebSocketHost::OnFlowControl(int64 quota
) {
419 DVLOG(3) << "WebSocketHost::OnFlowControl"
420 << " routing_id=" << routing_id_
<< " quota=" << quota
;
423 // WebSocketChannel is not yet created due to the delay introduced by
424 // per-renderer WebSocket throttling.
425 // SendFlowControl() is called after WebSocketChannel is created.
426 pending_flow_control_quota_
+= quota
;
430 channel_
->SendFlowControl(quota
);
433 void WebSocketHost::OnDropChannel(bool was_clean
,
435 const std::string
& reason
) {
436 DVLOG(3) << "WebSocketHost::OnDropChannel"
437 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
438 << " code=" << code
<< " reason=\"" << reason
<< "\"";
441 // WebSocketChannel is not yet created due to the delay introduced by
442 // per-renderer WebSocket throttling.
443 WebSocketDispatcherHost::WebSocketHostState result
=
444 dispatcher_
->DoDropChannel(routing_id_
,
446 net::kWebSocketErrorAbnormalClosure
,
448 DCHECK_EQ(WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED
, result
);
452 // TODO(yhirano): Handle |was_clean| appropriately.
453 channel_
->StartClosingHandshake(code
, reason
);
456 } // namespace content