// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
5 #include "net/websockets/websocket_stream.h"
7 #include "base/logging.h"
8 #include "base/memory/scoped_ptr.h"
9 #include "base/metrics/histogram.h"
10 #include "base/metrics/sparse_histogram.h"
11 #include "base/profiler/scoped_tracker.h"
12 #include "base/strings/stringprintf.h"
13 #include "base/time/time.h"
14 #include "base/timer/timer.h"
15 #include "net/base/load_flags.h"
16 #include "net/http/http_request_headers.h"
17 #include "net/http/http_response_headers.h"
18 #include "net/http/http_status_code.h"
19 #include "net/url_request/redirect_info.h"
20 #include "net/url_request/url_request.h"
21 #include "net/url_request/url_request_context.h"
22 #include "net/websockets/websocket_errors.h"
23 #include "net/websockets/websocket_event_interface.h"
24 #include "net/websockets/websocket_handshake_constants.h"
25 #include "net/websockets/websocket_handshake_stream_base.h"
26 #include "net/websockets/websocket_handshake_stream_create_helper.h"
28 #include "url/origin.h"
// Maximum time, in seconds, that a WebSocket opening handshake may take.
// The value deliberately equals the TCP connection timeout used in
// net/socket/websocket_transport_client_socket_pool.cc, so that JavaScript
// programs cannot easily tell which of the two timeouts fired.
constexpr int kHandshakeTimeoutIntervalInSeconds = 4 * 60;

// Forward declaration: Delegate (below) only holds a pointer to this class.
class StreamRequestImpl;
41 class Delegate
: public URLRequest::Delegate
{
43 enum HandshakeResult
{
47 NUM_HANDSHAKE_RESULT_TYPES
,
50 explicit Delegate(StreamRequestImpl
* owner
)
51 : owner_(owner
), result_(INCOMPLETE
) {}
52 ~Delegate() override
{
53 UMA_HISTOGRAM_ENUMERATION(
54 "Net.WebSocket.HandshakeResult", result_
, NUM_HANDSHAKE_RESULT_TYPES
);
57 // Implementation of URLRequest::Delegate methods.
58 void OnReceivedRedirect(URLRequest
* request
,
59 const RedirectInfo
& redirect_info
,
60 bool* defer_redirect
) override
;
62 void OnResponseStarted(URLRequest
* request
) override
;
64 void OnAuthRequired(URLRequest
* request
,
65 AuthChallengeInfo
* auth_info
) override
;
67 void OnCertificateRequested(URLRequest
* request
,
68 SSLCertRequestInfo
* cert_request_info
) override
;
70 void OnSSLCertificateError(URLRequest
* request
,
71 const SSLInfo
& ssl_info
,
74 void OnReadCompleted(URLRequest
* request
, int bytes_read
) override
;
77 StreamRequestImpl
* owner_
;
78 HandshakeResult result_
;
81 class StreamRequestImpl
: public WebSocketStreamRequest
{
85 const URLRequestContext
* context
,
86 const url::Origin
& origin
,
87 scoped_ptr
<WebSocketStream::ConnectDelegate
> connect_delegate
,
88 scoped_ptr
<WebSocketHandshakeStreamCreateHelper
> create_helper
)
89 : delegate_(new Delegate(this)),
91 context
->CreateRequest(url
, DEFAULT_PRIORITY
, delegate_
.get())),
92 connect_delegate_(connect_delegate
.Pass()),
93 create_helper_(create_helper
.release()) {
94 create_helper_
->set_failure_message(&failure_message_
);
95 HttpRequestHeaders headers
;
96 headers
.SetHeader(websockets::kUpgrade
, websockets::kWebSocketLowercase
);
97 headers
.SetHeader(HttpRequestHeaders::kConnection
, websockets::kUpgrade
);
98 headers
.SetHeader(HttpRequestHeaders::kOrigin
, origin
.string());
99 headers
.SetHeader(websockets::kSecWebSocketVersion
,
100 websockets::kSupportedVersion
);
101 url_request_
->SetExtraRequestHeaders(headers
);
103 // This passes the ownership of |create_helper_| to |url_request_|.
104 url_request_
->SetUserData(
105 WebSocketHandshakeStreamBase::CreateHelper::DataKey(),
107 url_request_
->SetLoadFlags(LOAD_DISABLE_CACHE
| LOAD_BYPASS_CACHE
);
110 // Destroying this object destroys the URLRequest, which cancels the request
111 // and so terminates the handshake if it is incomplete.
112 ~StreamRequestImpl() override
{}
114 void Start(scoped_ptr
<base::Timer
> timer
) {
116 base::TimeDelta
timeout(base::TimeDelta::FromSeconds(
117 kHandshakeTimeoutIntervalInSeconds
));
118 timer_
= timer
.Pass();
119 timer_
->Start(FROM_HERE
, timeout
,
120 base::Bind(&StreamRequestImpl::OnTimeout
,
121 base::Unretained(this)));
122 url_request_
->Start();
125 void PerformUpgrade() {
128 connect_delegate_
->OnSuccess(create_helper_
->Upgrade());
131 std::string
FailureMessageFromNetError() {
132 int error
= url_request_
->status().error();
133 if (error
== ERR_TUNNEL_CONNECTION_FAILED
) {
134 // This error is common and confusing, so special-case it.
135 // TODO(ricea): Include the HostPortPair of the selected proxy server in
136 // the error message. This is not currently possible because it isn't set
137 // in HttpResponseInfo when a ERR_TUNNEL_CONNECTION_FAILED error happens.
138 return "Establishing a tunnel via proxy server failed.";
140 return std::string("Error in connection establishment: ") +
141 ErrorToString(url_request_
->status().error());
145 void ReportFailure() {
148 if (failure_message_
.empty()) {
149 switch (url_request_
->status().status()) {
150 case URLRequestStatus::SUCCESS
:
151 case URLRequestStatus::IO_PENDING
:
153 case URLRequestStatus::CANCELED
:
154 if (url_request_
->status().error() == ERR_TIMED_OUT
)
155 failure_message_
= "WebSocket opening handshake timed out";
157 failure_message_
= "WebSocket opening handshake was canceled";
159 case URLRequestStatus::FAILED
:
160 failure_message_
= FailureMessageFromNetError();
164 ReportFailureWithMessage(failure_message_
);
167 void ReportFailureWithMessage(const std::string
& failure_message
) {
168 connect_delegate_
->OnFailure(failure_message
);
171 void OnFinishOpeningHandshake() {
172 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate(),
174 url_request_
->response_headers(),
175 url_request_
->response_time());
178 WebSocketStream::ConnectDelegate
* connect_delegate() const {
179 return connect_delegate_
.get();
183 url_request_
->CancelWithError(ERR_TIMED_OUT
);
187 // |delegate_| needs to be declared before |url_request_| so that it gets
188 // initialised first.
189 scoped_ptr
<Delegate
> delegate_
;
191 // Deleting the StreamRequestImpl object deletes this URLRequest object,
192 // cancelling the whole connection.
193 scoped_ptr
<URLRequest
> url_request_
;
195 scoped_ptr
<WebSocketStream::ConnectDelegate
> connect_delegate_
;
197 // Owned by the URLRequest.
198 WebSocketHandshakeStreamCreateHelper
* create_helper_
;
200 // The failure message supplied by WebSocketBasicHandshakeStream, if any.
201 std::string failure_message_
;
203 // A timer for handshake timeout.
204 scoped_ptr
<base::Timer
> timer_
;
207 class SSLErrorCallbacks
: public WebSocketEventInterface::SSLErrorCallbacks
{
209 explicit SSLErrorCallbacks(URLRequest
* url_request
)
210 : url_request_(url_request
) {}
212 void CancelSSLRequest(int error
, const SSLInfo
* ssl_info
) override
{
214 url_request_
->CancelWithSSLError(error
, *ssl_info
);
216 url_request_
->CancelWithError(error
);
220 void ContinueSSLRequest() override
{
221 url_request_
->ContinueDespiteLastError();
225 URLRequest
* url_request_
;
228 void Delegate::OnReceivedRedirect(URLRequest
* request
,
229 const RedirectInfo
& redirect_info
,
230 bool* defer_redirect
) {
231 // This code should never be reached for externally generated redirects,
232 // as WebSocketBasicHandshakeStream is responsible for filtering out
233 // all response codes besides 101, 401, and 407. As such, the URLRequest
234 // should never see a redirect sent over the network. However, internal
235 // redirects also result in this method being called, such as those
237 // Because it's security critical to prevent externally-generated
238 // redirects in WebSockets, perform additional checks to ensure this
240 GURL::Replacements replacements
;
241 replacements
.SetSchemeStr("wss");
242 GURL expected_url
= request
->original_url().ReplaceComponents(replacements
);
243 if (redirect_info
.new_method
!= "GET" ||
244 redirect_info
.new_url
!= expected_url
) {
245 // This should not happen.
246 DLOG(FATAL
) << "Unauthorized WebSocket redirect to "
247 << redirect_info
.new_method
<< " "
248 << redirect_info
.new_url
.spec();
253 void Delegate::OnResponseStarted(URLRequest
* request
) {
254 // TODO(vadimt): Remove ScopedTracker below once crbug.com/423948 is fixed.
255 tracked_objects::ScopedTracker
tracking_profile(
256 FROM_HERE_WITH_EXPLICIT_FUNCTION("423948 Delegate::OnResponseStarted"));
258 // All error codes, including OK and ABORTED, as with
259 // Net.ErrorCodesForMainFrame3
260 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ErrorCodes",
261 -request
->status().error());
262 if (!request
->status().is_success()) {
263 DVLOG(3) << "OnResponseStarted (request failed)";
264 owner_
->ReportFailure();
267 const int response_code
= request
->GetResponseCode();
268 DVLOG(3) << "OnResponseStarted (response code " << response_code
<< ")";
269 switch (response_code
) {
270 case HTTP_SWITCHING_PROTOCOLS
:
272 owner_
->PerformUpgrade();
275 case HTTP_UNAUTHORIZED
:
277 owner_
->OnFinishOpeningHandshake();
278 owner_
->ReportFailureWithMessage(
279 "HTTP Authentication failed; no valid credentials available");
282 case HTTP_PROXY_AUTHENTICATION_REQUIRED
:
284 owner_
->OnFinishOpeningHandshake();
285 owner_
->ReportFailureWithMessage("Proxy authentication failed");
290 owner_
->ReportFailure();
294 void Delegate::OnAuthRequired(URLRequest
* request
,
295 AuthChallengeInfo
* auth_info
) {
296 // This should only be called if credentials are not already stored.
297 request
->CancelAuth();
300 void Delegate::OnCertificateRequested(URLRequest
* request
,
301 SSLCertRequestInfo
* cert_request_info
) {
302 // This method is called when a client certificate is requested, and the
303 // request context does not already contain a client certificate selection for
304 // the endpoint. In this case, a main frame resource request would pop-up UI
305 // to permit selection of a client certificate, but since WebSockets are
306 // sub-resources they should not pop-up UI and so there is nothing more we can
311 void Delegate::OnSSLCertificateError(URLRequest
* request
,
312 const SSLInfo
& ssl_info
,
314 owner_
->connect_delegate()->OnSSLCertificateError(
315 scoped_ptr
<WebSocketEventInterface::SSLErrorCallbacks
>(
316 new SSLErrorCallbacks(request
)),
321 void Delegate::OnReadCompleted(URLRequest
* request
, int bytes_read
) {
327 WebSocketStreamRequest::~WebSocketStreamRequest() {}
329 WebSocketStream::WebSocketStream() {}
330 WebSocketStream::~WebSocketStream() {}
332 WebSocketStream::ConnectDelegate::~ConnectDelegate() {}
334 scoped_ptr
<WebSocketStreamRequest
> WebSocketStream::CreateAndConnectStream(
335 const GURL
& socket_url
,
336 const std::vector
<std::string
>& requested_subprotocols
,
337 const url::Origin
& origin
,
338 URLRequestContext
* url_request_context
,
339 const BoundNetLog
& net_log
,
340 scoped_ptr
<ConnectDelegate
> connect_delegate
) {
341 scoped_ptr
<WebSocketHandshakeStreamCreateHelper
> create_helper(
342 new WebSocketHandshakeStreamCreateHelper(connect_delegate
.get(),
343 requested_subprotocols
));
344 scoped_ptr
<StreamRequestImpl
> request(
345 new StreamRequestImpl(socket_url
,
348 connect_delegate
.Pass(),
349 create_helper
.Pass()));
350 request
->Start(scoped_ptr
<base::Timer
>(new base::Timer(false, false)));
351 return request
.Pass();
354 // This is declared in websocket_test_util.h.
355 scoped_ptr
<WebSocketStreamRequest
> CreateAndConnectStreamForTesting(
356 const GURL
& socket_url
,
357 scoped_ptr
<WebSocketHandshakeStreamCreateHelper
> create_helper
,
358 const url::Origin
& origin
,
359 URLRequestContext
* url_request_context
,
360 const BoundNetLog
& net_log
,
361 scoped_ptr
<WebSocketStream::ConnectDelegate
> connect_delegate
,
362 scoped_ptr
<base::Timer
> timer
) {
363 scoped_ptr
<StreamRequestImpl
> request(
364 new StreamRequestImpl(socket_url
,
367 connect_delegate
.Pass(),
368 create_helper
.Pass()));
369 request
->Start(timer
.Pass());
370 return request
.Pass();
373 void WebSocketDispatchOnFinishOpeningHandshake(
374 WebSocketStream::ConnectDelegate
* connect_delegate
,
376 const scoped_refptr
<HttpResponseHeaders
>& headers
,
377 base::Time response_time
) {
378 DCHECK(connect_delegate
);
380 connect_delegate
->OnFinishOpeningHandshake(make_scoped_ptr(
381 new WebSocketHandshakeResponseInfo(url
,
382 headers
->response_code(),
383 headers
->GetStatusText(),