1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "content/browser/renderer_host/websocket_host.h"
7 #include "base/basictypes.h"
8 #include "base/memory/weak_ptr.h"
9 #include "base/strings/string_util.h"
10 #include "content/browser/renderer_host/websocket_dispatcher_host.h"
11 #include "content/browser/ssl/ssl_error_handler.h"
12 #include "content/browser/ssl/ssl_manager.h"
13 #include "content/common/websocket_messages.h"
14 #include "ipc/ipc_message_macros.h"
15 #include "net/http/http_request_headers.h"
16 #include "net/http/http_response_headers.h"
17 #include "net/http/http_util.h"
18 #include "net/ssl/ssl_info.h"
19 #include "net/websockets/websocket_channel.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_event_interface.h"
22 #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
23 #include "net/websockets/websocket_handshake_request_info.h"
24 #include "net/websockets/websocket_handshake_response_info.h"
25 #include "url/origin.h"
31 typedef net::WebSocketEventInterface::ChannelState ChannelState
;
33 // Convert a content::WebSocketMessageType to a
34 // net::WebSocketFrameHeader::OpCode
35 net::WebSocketFrameHeader::OpCode
MessageTypeToOpCode(
36 WebSocketMessageType type
) {
37 DCHECK(type
== WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
||
38 type
== WEB_SOCKET_MESSAGE_TYPE_TEXT
||
39 type
== WEB_SOCKET_MESSAGE_TYPE_BINARY
);
40 typedef net::WebSocketFrameHeader::OpCode OpCode
;
41 // These compile asserts verify that the same underlying values are used for
42 // both types, so we can simply cast between them.
43 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION
) ==
44 net::WebSocketFrameHeader::kOpCodeContinuation
,
45 "enum values must match for opcode continuation");
46 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_TEXT
) ==
47 net::WebSocketFrameHeader::kOpCodeText
,
48 "enum values must match for opcode text");
49 static_assert(static_cast<OpCode
>(WEB_SOCKET_MESSAGE_TYPE_BINARY
) ==
50 net::WebSocketFrameHeader::kOpCodeBinary
,
51 "enum values must match for opcode binary");
52 return static_cast<OpCode
>(type
);
55 WebSocketMessageType
OpCodeToMessageType(
56 net::WebSocketFrameHeader::OpCode opCode
) {
57 DCHECK(opCode
== net::WebSocketFrameHeader::kOpCodeContinuation
||
58 opCode
== net::WebSocketFrameHeader::kOpCodeText
||
59 opCode
== net::WebSocketFrameHeader::kOpCodeBinary
);
60 // This cast is guaranteed valid by the static_assert() statements above.
61 return static_cast<WebSocketMessageType
>(opCode
);
64 ChannelState
StateCast(WebSocketDispatcherHost::WebSocketHostState host_state
) {
65 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE
=
66 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE
;
67 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED
=
68 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED
;
70 DCHECK(host_state
== WEBSOCKET_HOST_ALIVE
||
71 host_state
== WEBSOCKET_HOST_DELETED
);
72 // These compile asserts verify that we can get away with using static_cast<>
73 // for the conversion.
74 static_assert(static_cast<ChannelState
>(WEBSOCKET_HOST_ALIVE
) ==
75 net::WebSocketEventInterface::CHANNEL_ALIVE
,
76 "enum values must match for state_alive");
77 static_assert(static_cast<ChannelState
>(WEBSOCKET_HOST_DELETED
) ==
78 net::WebSocketEventInterface::CHANNEL_DELETED
,
79 "enum values must match for state_deleted");
80 return static_cast<ChannelState
>(host_state
);
83 // Implementation of net::WebSocketEventInterface. Receives events from our
84 // WebSocketChannel object. Each event is translated to an IPC and sent to the
85 // renderer or child process via WebSocketDispatcherHost.
86 class WebSocketEventHandler
: public net::WebSocketEventInterface
{
88 WebSocketEventHandler(WebSocketDispatcherHost
* dispatcher
,
91 ~WebSocketEventHandler() override
;
93 // net::WebSocketEventInterface implementation
95 ChannelState
OnAddChannelResponse(const std::string
& selected_subprotocol
,
96 const std::string
& extensions
) override
;
97 ChannelState
OnDataFrame(bool fin
,
98 WebSocketMessageType type
,
99 const std::vector
<char>& data
) override
;
100 ChannelState
OnClosingHandshake() override
;
101 ChannelState
OnFlowControl(int64 quota
) override
;
102 ChannelState
OnDropChannel(bool was_clean
,
104 const std::string
& reason
) override
;
105 ChannelState
OnFailChannel(const std::string
& message
) override
;
106 ChannelState
OnStartOpeningHandshake(
107 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) override
;
108 ChannelState
OnFinishOpeningHandshake(
109 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) override
;
110 ChannelState
OnSSLCertificateError(
111 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
113 const net::SSLInfo
& ssl_info
,
114 bool fatal
) override
;
117 class SSLErrorHandlerDelegate
: public SSLErrorHandler::Delegate
{
119 SSLErrorHandlerDelegate(
120 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
);
121 ~SSLErrorHandlerDelegate() override
;
123 base::WeakPtr
<SSLErrorHandler::Delegate
> GetWeakPtr();
125 // SSLErrorHandler::Delegate methods
126 void CancelSSLRequest(int error
, const net::SSLInfo
* ssl_info
) override
;
127 void ContinueSSLRequest() override
;
130 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks_
;
131 base::WeakPtrFactory
<SSLErrorHandlerDelegate
> weak_ptr_factory_
;
133 DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate
);
136 WebSocketDispatcherHost
* const dispatcher_
;
137 const int routing_id_
;
138 const int render_frame_id_
;
139 scoped_ptr
<SSLErrorHandlerDelegate
> ssl_error_handler_delegate_
;
141 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler
);
144 WebSocketEventHandler::WebSocketEventHandler(
145 WebSocketDispatcherHost
* dispatcher
,
148 : dispatcher_(dispatcher
),
149 routing_id_(routing_id
),
150 render_frame_id_(render_frame_id
) {
153 WebSocketEventHandler::~WebSocketEventHandler() {
154 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_
;
157 ChannelState
WebSocketEventHandler::OnAddChannelResponse(
158 const std::string
& selected_protocol
,
159 const std::string
& extensions
) {
160 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
161 << " routing_id=" << routing_id_
162 << " selected_protocol=\"" << selected_protocol
<< "\""
163 << " extensions=\"" << extensions
<< "\"";
165 return StateCast(dispatcher_
->SendAddChannelResponse(
166 routing_id_
, selected_protocol
, extensions
));
169 ChannelState
WebSocketEventHandler::OnDataFrame(
171 net::WebSocketFrameHeader::OpCode type
,
172 const std::vector
<char>& data
) {
173 DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
174 << " routing_id=" << routing_id_
<< " fin=" << fin
175 << " type=" << type
<< " data is " << data
.size() << " bytes";
177 return StateCast(dispatcher_
->SendFrame(
178 routing_id_
, fin
, OpCodeToMessageType(type
), data
));
181 ChannelState
WebSocketEventHandler::OnClosingHandshake() {
182 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
183 << " routing_id=" << routing_id_
;
185 return StateCast(dispatcher_
->NotifyClosingHandshake(routing_id_
));
188 ChannelState
WebSocketEventHandler::OnFlowControl(int64 quota
) {
189 DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
190 << " routing_id=" << routing_id_
<< " quota=" << quota
;
192 return StateCast(dispatcher_
->SendFlowControl(routing_id_
, quota
));
195 ChannelState
WebSocketEventHandler::OnDropChannel(bool was_clean
,
197 const std::string
& reason
) {
198 DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
199 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
200 << " code=" << code
<< " reason=\"" << reason
<< "\"";
203 dispatcher_
->DoDropChannel(routing_id_
, was_clean
, code
, reason
));
206 ChannelState
WebSocketEventHandler::OnFailChannel(const std::string
& message
) {
207 DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
208 << " routing_id=" << routing_id_
209 << " message=\"" << message
<< "\"";
211 return StateCast(dispatcher_
->NotifyFailure(routing_id_
, message
));
214 ChannelState
WebSocketEventHandler::OnStartOpeningHandshake(
215 scoped_ptr
<net::WebSocketHandshakeRequestInfo
> request
) {
216 bool should_send
= dispatcher_
->CanReadRawCookies();
217 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
218 << "should_send=" << should_send
;
221 return WebSocketEventInterface::CHANNEL_ALIVE
;
223 WebSocketHandshakeRequest request_to_pass
;
224 request_to_pass
.url
.Swap(&request
->url
);
225 net::HttpRequestHeaders::Iterator
it(request
->headers
);
227 request_to_pass
.headers
.push_back(std::make_pair(it
.name(), it
.value()));
228 request_to_pass
.headers_text
=
229 base::StringPrintf("GET %s HTTP/1.1\r\n",
230 request_to_pass
.url
.spec().c_str()) +
231 request
->headers
.ToString();
232 request_to_pass
.request_time
= request
->request_time
;
234 return StateCast(dispatcher_
->NotifyStartOpeningHandshake(routing_id_
,
238 ChannelState
WebSocketEventHandler::OnFinishOpeningHandshake(
239 scoped_ptr
<net::WebSocketHandshakeResponseInfo
> response
) {
240 bool should_send
= dispatcher_
->CanReadRawCookies();
241 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
242 << "should_send=" << should_send
;
245 return WebSocketEventInterface::CHANNEL_ALIVE
;
247 WebSocketHandshakeResponse response_to_pass
;
248 response_to_pass
.url
.Swap(&response
->url
);
249 response_to_pass
.status_code
= response
->status_code
;
250 response_to_pass
.status_text
.swap(response
->status_text
);
252 std::string name
, value
;
253 while (response
->headers
->EnumerateHeaderLines(&iter
, &name
, &value
))
254 response_to_pass
.headers
.push_back(std::make_pair(name
, value
));
255 response_to_pass
.headers_text
=
256 net::HttpUtil::ConvertHeadersBackToHTTPResponse(
257 response
->headers
->raw_headers());
258 response_to_pass
.response_time
= response
->response_time
;
260 return StateCast(dispatcher_
->NotifyFinishOpeningHandshake(routing_id_
,
264 ChannelState
WebSocketEventHandler::OnSSLCertificateError(
265 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
,
267 const net::SSLInfo
& ssl_info
,
269 DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
270 << " routing_id=" << routing_id_
<< " url=" << url
.spec()
271 << " cert_status=" << ssl_info
.cert_status
<< " fatal=" << fatal
;
272 ssl_error_handler_delegate_
.reset(
273 new SSLErrorHandlerDelegate(callbacks
.Pass()));
274 SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_
->GetWeakPtr(),
275 RESOURCE_TYPE_SUB_RESOURCE
,
277 dispatcher_
->render_process_id(),
281 // The above method is always asynchronous.
282 return WebSocketEventInterface::CHANNEL_ALIVE
;
285 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
286 scoped_ptr
<net::WebSocketEventInterface::SSLErrorCallbacks
> callbacks
)
287 : callbacks_(callbacks
.Pass()), weak_ptr_factory_(this) {}
289 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
291 base::WeakPtr
<SSLErrorHandler::Delegate
>
292 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
293 return weak_ptr_factory_
.GetWeakPtr();
296 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
298 const net::SSLInfo
* ssl_info
) {
299 DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
300 << " error=" << error
301 << " cert_status=" << (ssl_info
? ssl_info
->cert_status
302 : static_cast<net::CertStatus
>(-1));
303 callbacks_
->CancelSSLRequest(error
, ssl_info
);
306 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest() {
307 DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
308 callbacks_
->ContinueSSLRequest();
313 WebSocketHost::WebSocketHost(int routing_id
,
314 WebSocketDispatcherHost
* dispatcher
,
315 net::URLRequestContext
* url_request_context
,
316 base::TimeDelta delay
)
317 : dispatcher_(dispatcher
),
318 url_request_context_(url_request_context
),
319 routing_id_(routing_id
),
321 pending_flow_control_quota_(0),
322 handshake_succeeded_(false),
323 weak_ptr_factory_(this) {
324 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id
;
327 WebSocketHost::~WebSocketHost() {}
329 void WebSocketHost::GoAway() {
330 OnDropChannel(false, static_cast<uint16
>(net::kWebSocketErrorGoingAway
), "");
333 bool WebSocketHost::OnMessageReceived(const IPC::Message
& message
) {
335 IPC_BEGIN_MESSAGE_MAP(WebSocketHost
, message
)
336 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest
, OnAddChannelRequest
)
337 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame
, OnSendFrame
)
338 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl
, OnFlowControl
)
339 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel
, OnDropChannel
)
340 IPC_MESSAGE_UNHANDLED(handled
= false)
341 IPC_END_MESSAGE_MAP()
345 void WebSocketHost::OnAddChannelRequest(
346 const GURL
& socket_url
,
347 const std::vector
<std::string
>& requested_protocols
,
348 const url::Origin
& origin
,
349 int render_frame_id
) {
350 DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
351 << " routing_id=" << routing_id_
<< " socket_url=\"" << socket_url
352 << "\" requested_protocols=\""
353 << JoinString(requested_protocols
, ", ") << "\" origin=\""
354 << origin
.string() << "\"";
357 if (delay_
> base::TimeDelta()) {
358 base::MessageLoop::current()->PostDelayedTask(
360 base::Bind(&WebSocketHost::AddChannel
,
361 weak_ptr_factory_
.GetWeakPtr(),
368 AddChannel(socket_url
, requested_protocols
, origin
, render_frame_id
);
370 // |this| may have been deleted here.
373 void WebSocketHost::AddChannel(
374 const GURL
& socket_url
,
375 const std::vector
<std::string
>& requested_protocols
,
376 const url::Origin
& origin
,
377 int render_frame_id
) {
378 DVLOG(3) << "WebSocketHost::AddChannel"
379 << " routing_id=" << routing_id_
<< " socket_url=\"" << socket_url
380 << "\" requested_protocols=\""
381 << JoinString(requested_protocols
, ", ") << "\" origin=\""
382 << origin
.string() << "\"";
386 scoped_ptr
<net::WebSocketEventInterface
> event_interface(
387 new WebSocketEventHandler(dispatcher_
, routing_id_
, render_frame_id
));
389 new net::WebSocketChannel(event_interface
.Pass(), url_request_context_
));
391 if (pending_flow_control_quota_
> 0) {
392 // channel_->SendFlowControl(pending_flow_control_quota_) must be called
393 // after channel_->SendAddChannelRequest() below.
394 // We post OnFlowControl() here using |weak_ptr_factory_| instead of
395 // calling SendFlowControl directly, because |this| may have been deleted
396 // after channel_->SendAddChannelRequest().
397 base::MessageLoop::current()->PostTask(
399 base::Bind(&WebSocketHost::OnFlowControl
,
400 weak_ptr_factory_
.GetWeakPtr(),
401 pending_flow_control_quota_
));
402 pending_flow_control_quota_
= 0;
405 channel_
->SendAddChannelRequest(socket_url
, requested_protocols
, origin
);
406 // |this| may have been deleted here.
409 void WebSocketHost::OnSendFrame(bool fin
,
410 WebSocketMessageType type
,
411 const std::vector
<char>& data
) {
412 DVLOG(3) << "WebSocketHost::OnSendFrame"
413 << " routing_id=" << routing_id_
<< " fin=" << fin
414 << " type=" << type
<< " data is " << data
.size() << " bytes";
417 channel_
->SendFrame(fin
, MessageTypeToOpCode(type
), data
);
420 void WebSocketHost::OnFlowControl(int64 quota
) {
421 DVLOG(3) << "WebSocketHost::OnFlowControl"
422 << " routing_id=" << routing_id_
<< " quota=" << quota
;
425 // WebSocketChannel is not yet created due to the delay introduced by
426 // per-renderer WebSocket throttling.
427 // SendFlowControl() is called after WebSocketChannel is created.
428 pending_flow_control_quota_
+= quota
;
432 channel_
->SendFlowControl(quota
);
435 void WebSocketHost::OnDropChannel(bool was_clean
,
437 const std::string
& reason
) {
438 DVLOG(3) << "WebSocketHost::OnDropChannel"
439 << " routing_id=" << routing_id_
<< " was_clean=" << was_clean
440 << " code=" << code
<< " reason=\"" << reason
<< "\"";
443 // WebSocketChannel is not yet created due to the delay introduced by
444 // per-renderer WebSocket throttling.
445 WebSocketDispatcherHost::WebSocketHostState result
=
446 dispatcher_
->DoDropChannel(routing_id_
,
448 net::kWebSocketErrorAbnormalClosure
,
450 DCHECK_EQ(WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED
, result
);
454 // TODO(yhirano): Handle |was_clean| appropriately.
455 channel_
->StartClosingHandshake(code
, reason
);
458 } // namespace content