Merge Chromium + Blink git repositories
[chromium-blink-merge.git] / content / browser / renderer_host / websocket_host.cc
blob70c846e39810248f7ac1974cdeb04a1f5efb9a9d
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "content/browser/renderer_host/websocket_host.h"
7 #include "base/basictypes.h"
8 #include "base/location.h"
9 #include "base/memory/weak_ptr.h"
10 #include "base/single_thread_task_runner.h"
11 #include "base/strings/string_util.h"
12 #include "base/thread_task_runner_handle.h"
13 #include "content/browser/renderer_host/websocket_dispatcher_host.h"
14 #include "content/browser/ssl/ssl_error_handler.h"
15 #include "content/browser/ssl/ssl_manager.h"
16 #include "content/common/websocket_messages.h"
17 #include "ipc/ipc_message_macros.h"
18 #include "net/http/http_request_headers.h"
19 #include "net/http/http_response_headers.h"
20 #include "net/http/http_util.h"
21 #include "net/ssl/ssl_info.h"
22 #include "net/websockets/websocket_channel.h"
23 #include "net/websockets/websocket_errors.h"
24 #include "net/websockets/websocket_event_interface.h"
25 #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
26 #include "net/websockets/websocket_handshake_request_info.h"
27 #include "net/websockets/websocket_handshake_response_info.h"
28 #include "url/origin.h"
30 namespace content {
32 namespace {
34 typedef net::WebSocketEventInterface::ChannelState ChannelState;
36 // Convert a content::WebSocketMessageType to a
37 // net::WebSocketFrameHeader::OpCode
38 net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
39 WebSocketMessageType type) {
40 DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION ||
41 type == WEB_SOCKET_MESSAGE_TYPE_TEXT ||
42 type == WEB_SOCKET_MESSAGE_TYPE_BINARY);
43 typedef net::WebSocketFrameHeader::OpCode OpCode;
44 // These compile asserts verify that the same underlying values are used for
45 // both types, so we can simply cast between them.
46 static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) ==
47 net::WebSocketFrameHeader::kOpCodeContinuation,
48 "enum values must match for opcode continuation");
49 static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) ==
50 net::WebSocketFrameHeader::kOpCodeText,
51 "enum values must match for opcode text");
52 static_assert(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) ==
53 net::WebSocketFrameHeader::kOpCodeBinary,
54 "enum values must match for opcode binary");
55 return static_cast<OpCode>(type);
58 WebSocketMessageType OpCodeToMessageType(
59 net::WebSocketFrameHeader::OpCode opCode) {
60 DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation ||
61 opCode == net::WebSocketFrameHeader::kOpCodeText ||
62 opCode == net::WebSocketFrameHeader::kOpCodeBinary);
63 // This cast is guaranteed valid by the static_assert() statements above.
64 return static_cast<WebSocketMessageType>(opCode);
67 ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
68 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE =
69 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE;
70 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED =
71 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED;
73 DCHECK(host_state == WEBSOCKET_HOST_ALIVE ||
74 host_state == WEBSOCKET_HOST_DELETED);
75 // These compile asserts verify that we can get away with using static_cast<>
76 // for the conversion.
77 static_assert(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) ==
78 net::WebSocketEventInterface::CHANNEL_ALIVE,
79 "enum values must match for state_alive");
80 static_assert(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) ==
81 net::WebSocketEventInterface::CHANNEL_DELETED,
82 "enum values must match for state_deleted");
83 return static_cast<ChannelState>(host_state);
86 // Implementation of net::WebSocketEventInterface. Receives events from our
87 // WebSocketChannel object. Each event is translated to an IPC and sent to the
88 // renderer or child process via WebSocketDispatcherHost.
89 class WebSocketEventHandler : public net::WebSocketEventInterface {
90 public:
91 WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
92 int routing_id,
93 int render_frame_id);
94 ~WebSocketEventHandler() override;
96 // net::WebSocketEventInterface implementation
98 ChannelState OnAddChannelResponse(const std::string& selected_subprotocol,
99 const std::string& extensions) override;
100 ChannelState OnDataFrame(bool fin,
101 WebSocketMessageType type,
102 const std::vector<char>& data) override;
103 ChannelState OnClosingHandshake() override;
104 ChannelState OnFlowControl(int64 quota) override;
105 ChannelState OnDropChannel(bool was_clean,
106 uint16 code,
107 const std::string& reason) override;
108 ChannelState OnFailChannel(const std::string& message) override;
109 ChannelState OnStartOpeningHandshake(
110 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) override;
111 ChannelState OnFinishOpeningHandshake(
112 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) override;
113 ChannelState OnSSLCertificateError(
114 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
115 const GURL& url,
116 const net::SSLInfo& ssl_info,
117 bool fatal) override;
119 private:
120 class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
121 public:
122 SSLErrorHandlerDelegate(
123 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
124 ~SSLErrorHandlerDelegate() override;
126 base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
128 // SSLErrorHandler::Delegate methods
129 void CancelSSLRequest(int error, const net::SSLInfo* ssl_info) override;
130 void ContinueSSLRequest() override;
132 private:
133 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
134 base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
136 DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
139 WebSocketDispatcherHost* const dispatcher_;
140 const int routing_id_;
141 const int render_frame_id_;
142 scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
144 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
147 WebSocketEventHandler::WebSocketEventHandler(
148 WebSocketDispatcherHost* dispatcher,
149 int routing_id,
150 int render_frame_id)
151 : dispatcher_(dispatcher),
152 routing_id_(routing_id),
153 render_frame_id_(render_frame_id) {
156 WebSocketEventHandler::~WebSocketEventHandler() {
157 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
160 ChannelState WebSocketEventHandler::OnAddChannelResponse(
161 const std::string& selected_protocol,
162 const std::string& extensions) {
163 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
164 << " routing_id=" << routing_id_
165 << " selected_protocol=\"" << selected_protocol << "\""
166 << " extensions=\"" << extensions << "\"";
168 return StateCast(dispatcher_->SendAddChannelResponse(
169 routing_id_, selected_protocol, extensions));
172 ChannelState WebSocketEventHandler::OnDataFrame(
173 bool fin,
174 net::WebSocketFrameHeader::OpCode type,
175 const std::vector<char>& data) {
176 DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
177 << " routing_id=" << routing_id_ << " fin=" << fin
178 << " type=" << type << " data is " << data.size() << " bytes";
180 return StateCast(dispatcher_->SendFrame(
181 routing_id_, fin, OpCodeToMessageType(type), data));
184 ChannelState WebSocketEventHandler::OnClosingHandshake() {
185 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
186 << " routing_id=" << routing_id_;
188 return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_));
191 ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
192 DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
193 << " routing_id=" << routing_id_ << " quota=" << quota;
195 return StateCast(dispatcher_->SendFlowControl(routing_id_, quota));
198 ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
199 uint16 code,
200 const std::string& reason) {
201 DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
202 << " routing_id=" << routing_id_ << " was_clean=" << was_clean
203 << " code=" << code << " reason=\"" << reason << "\"";
205 return StateCast(
206 dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason));
209 ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
210 DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
211 << " routing_id=" << routing_id_
212 << " message=\"" << message << "\"";
214 return StateCast(dispatcher_->NotifyFailure(routing_id_, message));
217 ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
218 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
219 bool should_send = dispatcher_->CanReadRawCookies();
220 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
221 << "should_send=" << should_send;
223 if (!should_send)
224 return WebSocketEventInterface::CHANNEL_ALIVE;
226 WebSocketHandshakeRequest request_to_pass;
227 request_to_pass.url.Swap(&request->url);
228 net::HttpRequestHeaders::Iterator it(request->headers);
229 while (it.GetNext())
230 request_to_pass.headers.push_back(std::make_pair(it.name(), it.value()));
231 request_to_pass.headers_text =
232 base::StringPrintf("GET %s HTTP/1.1\r\n",
233 request_to_pass.url.spec().c_str()) +
234 request->headers.ToString();
235 request_to_pass.request_time = request->request_time;
237 return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_,
238 request_to_pass));
241 ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
242 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
243 bool should_send = dispatcher_->CanReadRawCookies();
244 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
245 << "should_send=" << should_send;
247 if (!should_send)
248 return WebSocketEventInterface::CHANNEL_ALIVE;
250 WebSocketHandshakeResponse response_to_pass;
251 response_to_pass.url.Swap(&response->url);
252 response_to_pass.status_code = response->status_code;
253 response_to_pass.status_text.swap(response->status_text);
254 void* iter = NULL;
255 std::string name, value;
256 while (response->headers->EnumerateHeaderLines(&iter, &name, &value))
257 response_to_pass.headers.push_back(std::make_pair(name, value));
258 response_to_pass.headers_text =
259 net::HttpUtil::ConvertHeadersBackToHTTPResponse(
260 response->headers->raw_headers());
261 response_to_pass.response_time = response->response_time;
263 return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_,
264 response_to_pass));
267 ChannelState WebSocketEventHandler::OnSSLCertificateError(
268 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
269 const GURL& url,
270 const net::SSLInfo& ssl_info,
271 bool fatal) {
272 DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
273 << " routing_id=" << routing_id_ << " url=" << url.spec()
274 << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
275 ssl_error_handler_delegate_.reset(
276 new SSLErrorHandlerDelegate(callbacks.Pass()));
277 SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
278 RESOURCE_TYPE_SUB_RESOURCE,
279 url,
280 dispatcher_->render_process_id(),
281 render_frame_id_,
282 ssl_info,
283 fatal);
284 // The above method is always asynchronous.
285 return WebSocketEventInterface::CHANNEL_ALIVE;
288 WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
289 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
290 : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}
292 WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
294 base::WeakPtr<SSLErrorHandler::Delegate>
295 WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
296 return weak_ptr_factory_.GetWeakPtr();
299 void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
300 int error,
301 const net::SSLInfo* ssl_info) {
302 DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
303 << " error=" << error
304 << " cert_status=" << (ssl_info ? ssl_info->cert_status
305 : static_cast<net::CertStatus>(-1));
306 callbacks_->CancelSSLRequest(error, ssl_info);
309 void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest() {
310 DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
311 callbacks_->ContinueSSLRequest();
314 } // namespace
316 WebSocketHost::WebSocketHost(int routing_id,
317 WebSocketDispatcherHost* dispatcher,
318 net::URLRequestContext* url_request_context,
319 base::TimeDelta delay)
320 : dispatcher_(dispatcher),
321 url_request_context_(url_request_context),
322 routing_id_(routing_id),
323 delay_(delay),
324 pending_flow_control_quota_(0),
325 handshake_succeeded_(false),
326 weak_ptr_factory_(this) {
327 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;
330 WebSocketHost::~WebSocketHost() {}
332 void WebSocketHost::GoAway() {
333 OnDropChannel(false, static_cast<uint16>(net::kWebSocketErrorGoingAway), "");
336 bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
337 bool handled = true;
338 IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message)
339 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest)
340 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame)
341 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl)
342 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel)
343 IPC_MESSAGE_UNHANDLED(handled = false)
344 IPC_END_MESSAGE_MAP()
345 return handled;
348 void WebSocketHost::OnAddChannelRequest(
349 const GURL& socket_url,
350 const std::vector<std::string>& requested_protocols,
351 const url::Origin& origin,
352 int render_frame_id) {
353 DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
354 << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
355 << "\" requested_protocols=\""
356 << base::JoinString(requested_protocols, ", ") << "\" origin=\""
357 << origin << "\"";
359 DCHECK(!channel_);
360 if (delay_ > base::TimeDelta()) {
361 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
362 FROM_HERE,
363 base::Bind(&WebSocketHost::AddChannel, weak_ptr_factory_.GetWeakPtr(),
364 socket_url, requested_protocols, origin, render_frame_id),
365 delay_);
366 } else {
367 AddChannel(socket_url, requested_protocols, origin, render_frame_id);
369 // |this| may have been deleted here.
372 void WebSocketHost::AddChannel(
373 const GURL& socket_url,
374 const std::vector<std::string>& requested_protocols,
375 const url::Origin& origin,
376 int render_frame_id) {
377 DVLOG(3) << "WebSocketHost::AddChannel"
378 << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
379 << "\" requested_protocols=\""
380 << base::JoinString(requested_protocols, ", ") << "\" origin=\""
381 << origin << "\"";
383 DCHECK(!channel_);
385 scoped_ptr<net::WebSocketEventInterface> event_interface(
386 new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
387 channel_.reset(
388 new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
390 if (pending_flow_control_quota_ > 0) {
391 // channel_->SendFlowControl(pending_flow_control_quota_) must be called
392 // after channel_->SendAddChannelRequest() below.
393 // We post OnFlowControl() here using |weak_ptr_factory_| instead of
394 // calling SendFlowControl directly, because |this| may have been deleted
395 // after channel_->SendAddChannelRequest().
396 base::ThreadTaskRunnerHandle::Get()->PostTask(
397 FROM_HERE, base::Bind(&WebSocketHost::OnFlowControl,
398 weak_ptr_factory_.GetWeakPtr(),
399 pending_flow_control_quota_));
400 pending_flow_control_quota_ = 0;
403 channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
404 // |this| may have been deleted here.
407 void WebSocketHost::OnSendFrame(bool fin,
408 WebSocketMessageType type,
409 const std::vector<char>& data) {
410 DVLOG(3) << "WebSocketHost::OnSendFrame"
411 << " routing_id=" << routing_id_ << " fin=" << fin
412 << " type=" << type << " data is " << data.size() << " bytes";
414 DCHECK(channel_);
415 channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
418 void WebSocketHost::OnFlowControl(int64 quota) {
419 DVLOG(3) << "WebSocketHost::OnFlowControl"
420 << " routing_id=" << routing_id_ << " quota=" << quota;
422 if (!channel_) {
423 // WebSocketChannel is not yet created due to the delay introduced by
424 // per-renderer WebSocket throttling.
425 // SendFlowControl() is called after WebSocketChannel is created.
426 pending_flow_control_quota_ += quota;
427 return;
430 channel_->SendFlowControl(quota);
433 void WebSocketHost::OnDropChannel(bool was_clean,
434 uint16 code,
435 const std::string& reason) {
436 DVLOG(3) << "WebSocketHost::OnDropChannel"
437 << " routing_id=" << routing_id_ << " was_clean=" << was_clean
438 << " code=" << code << " reason=\"" << reason << "\"";
440 if (!channel_) {
441 // WebSocketChannel is not yet created due to the delay introduced by
442 // per-renderer WebSocket throttling.
443 WebSocketDispatcherHost::WebSocketHostState result =
444 dispatcher_->DoDropChannel(routing_id_,
445 false,
446 net::kWebSocketErrorAbnormalClosure,
447 "");
448 DCHECK_EQ(WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED, result);
449 return;
452 // TODO(yhirano): Handle |was_clean| appropriately.
453 channel_->StartClosingHandshake(code, reason);
456 } // namespace content