1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_basic_handshake_stream.h"
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/compiler_specific.h"
17 #include "base/containers/hash_tables.h"
18 #include "base/logging.h"
19 #include "base/metrics/histogram.h"
20 #include "base/metrics/sparse_histogram.h"
21 #include "base/stl_util.h"
22 #include "base/strings/string_number_conversions.h"
23 #include "base/strings/string_piece.h"
24 #include "base/strings/string_util.h"
25 #include "base/strings/stringprintf.h"
26 #include "base/time/time.h"
27 #include "crypto/random.h"
28 #include "net/base/io_buffer.h"
29 #include "net/http/http_request_headers.h"
30 #include "net/http/http_request_info.h"
31 #include "net/http/http_response_body_drainer.h"
32 #include "net/http/http_response_headers.h"
33 #include "net/http/http_status_code.h"
34 #include "net/http/http_stream_parser.h"
35 #include "net/socket/client_socket_handle.h"
36 #include "net/socket/websocket_transport_client_socket_pool.h"
37 #include "net/websockets/websocket_basic_stream.h"
38 #include "net/websockets/websocket_deflate_predictor.h"
39 #include "net/websockets/websocket_deflate_predictor_impl.h"
40 #include "net/websockets/websocket_deflate_stream.h"
41 #include "net/websockets/websocket_deflater.h"
42 #include "net/websockets/websocket_extension_parser.h"
43 #include "net/websockets/websocket_handshake_challenge.h"
44 #include "net/websockets/websocket_handshake_constants.h"
45 #include "net/websockets/websocket_handshake_request_info.h"
46 #include "net/websockets/websocket_handshake_response_info.h"
47 #include "net/websockets/websocket_stream.h"
53 // TODO(yhirano): Remove these functions once http://crbug.com/399535 is fixed.
54 NOINLINE
void RunCallbackWithOk(const CompletionCallback
& callback
,
56 DCHECK_EQ(result
, OK
);
60 NOINLINE
void RunCallbackWithInvalidResponseCausedByRedirect(
61 const CompletionCallback
& callback
,
63 DCHECK_EQ(result
, ERR_INVALID_RESPONSE
);
64 callback
.Run(ERR_INVALID_RESPONSE
);
67 NOINLINE
void RunCallbackWithInvalidResponse(
68 const CompletionCallback
& callback
,
70 DCHECK_EQ(result
, ERR_INVALID_RESPONSE
);
71 callback
.Run(ERR_INVALID_RESPONSE
);
74 NOINLINE
void RunCallback(const CompletionCallback
& callback
, int result
) {
80 // TODO(ricea): If more extensions are added, replace this with a more general
82 struct WebSocketExtensionParams
{
83 WebSocketExtensionParams()
84 : deflate_enabled(false),
85 client_window_bits(15),
86 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT
) {}
89 int client_window_bits
;
90 WebSocketDeflater::ContextTakeOverMode deflate_mode
;
95 enum GetHeaderResult
{
101 std::string
MissingHeaderMessage(const std::string
& header_name
) {
102 return std::string("'") + header_name
+ "' header is missing";
105 std::string
MultipleHeaderValuesMessage(const std::string
& header_name
) {
109 "' header must not appear more than once in a response";
112 std::string
GenerateHandshakeChallenge() {
113 std::string
raw_challenge(websockets::kRawChallengeLength
, '\0');
114 crypto::RandBytes(string_as_array(&raw_challenge
), raw_challenge
.length());
115 std::string encoded_challenge
;
116 base::Base64Encode(raw_challenge
, &encoded_challenge
);
117 return encoded_challenge
;
120 void AddVectorHeaderIfNonEmpty(const char* name
,
121 const std::vector
<std::string
>& value
,
122 HttpRequestHeaders
* headers
) {
125 headers
->SetHeader(name
, JoinString(value
, ", "));
128 GetHeaderResult
GetSingleHeaderValue(const HttpResponseHeaders
* headers
,
129 const base::StringPiece
& name
,
130 std::string
* value
) {
131 void* state
= nullptr;
132 size_t num_values
= 0;
133 std::string temp_value
;
134 while (headers
->EnumerateHeader(&state
, name
, &temp_value
)) {
135 if (++num_values
> 1)
136 return GET_HEADER_MULTIPLE
;
139 return num_values
> 0 ? GET_HEADER_OK
: GET_HEADER_MISSING
;
142 bool ValidateHeaderHasSingleValue(GetHeaderResult result
,
143 const std::string
& header_name
,
144 std::string
* failure_message
) {
145 if (result
== GET_HEADER_MISSING
) {
146 *failure_message
= MissingHeaderMessage(header_name
);
149 if (result
== GET_HEADER_MULTIPLE
) {
150 *failure_message
= MultipleHeaderValuesMessage(header_name
);
153 DCHECK_EQ(result
, GET_HEADER_OK
);
157 bool ValidateUpgrade(const HttpResponseHeaders
* headers
,
158 std::string
* failure_message
) {
160 GetHeaderResult result
=
161 GetSingleHeaderValue(headers
, websockets::kUpgrade
, &value
);
162 if (!ValidateHeaderHasSingleValue(result
,
163 websockets::kUpgrade
,
168 if (!LowerCaseEqualsASCII(value
, websockets::kWebSocketLowercase
)) {
170 "'Upgrade' header value is not 'WebSocket': " + value
;
176 bool ValidateSecWebSocketAccept(const HttpResponseHeaders
* headers
,
177 const std::string
& expected
,
178 std::string
* failure_message
) {
180 GetHeaderResult result
=
181 GetSingleHeaderValue(headers
, websockets::kSecWebSocketAccept
, &actual
);
182 if (!ValidateHeaderHasSingleValue(result
,
183 websockets::kSecWebSocketAccept
,
188 if (expected
!= actual
) {
189 *failure_message
= "Incorrect 'Sec-WebSocket-Accept' header value";
195 bool ValidateConnection(const HttpResponseHeaders
* headers
,
196 std::string
* failure_message
) {
197 // Connection header is permitted to contain other tokens.
198 if (!headers
->HasHeader(HttpRequestHeaders::kConnection
)) {
199 *failure_message
= MissingHeaderMessage(HttpRequestHeaders::kConnection
);
202 if (!headers
->HasHeaderValue(HttpRequestHeaders::kConnection
,
203 websockets::kUpgrade
)) {
204 *failure_message
= "'Connection' header value must contain 'Upgrade'";
210 bool ValidateSubProtocol(
211 const HttpResponseHeaders
* headers
,
212 const std::vector
<std::string
>& requested_sub_protocols
,
213 std::string
* sub_protocol
,
214 std::string
* failure_message
) {
215 void* state
= nullptr;
217 base::hash_set
<std::string
> requested_set(requested_sub_protocols
.begin(),
218 requested_sub_protocols
.end());
220 bool has_multiple_protocols
= false;
221 bool has_invalid_protocol
= false;
223 while (!has_invalid_protocol
|| !has_multiple_protocols
) {
224 std::string temp_value
;
225 if (!headers
->EnumerateHeader(
226 &state
, websockets::kSecWebSocketProtocol
, &temp_value
))
229 if (requested_set
.count(value
) == 0)
230 has_invalid_protocol
= true;
232 has_multiple_protocols
= true;
235 if (has_multiple_protocols
) {
237 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol
);
239 } else if (count
> 0 && requested_sub_protocols
.size() == 0) {
241 std::string("Response must not include 'Sec-WebSocket-Protocol' "
242 "header if not present in request: ")
245 } else if (has_invalid_protocol
) {
247 "'Sec-WebSocket-Protocol' header value '" +
249 "' in response does not match any of sent values";
251 } else if (requested_sub_protocols
.size() > 0 && count
== 0) {
253 "Sent non-empty 'Sec-WebSocket-Protocol' header "
254 "but no response was received";
257 *sub_protocol
= value
;
261 bool DeflateError(std::string
* message
, const base::StringPiece
& piece
) {
262 *message
= "Error in permessage-deflate: ";
263 piece
.AppendToString(message
);
267 bool ValidatePerMessageDeflateExtension(const WebSocketExtension
& extension
,
268 std::string
* failure_message
,
269 WebSocketExtensionParams
* params
) {
270 static const char kClientPrefix
[] = "client_";
271 static const char kServerPrefix
[] = "server_";
272 static const char kNoContextTakeover
[] = "no_context_takeover";
273 static const char kMaxWindowBits
[] = "max_window_bits";
274 const size_t kPrefixLen
= arraysize(kClientPrefix
) - 1;
275 static_assert(kPrefixLen
== arraysize(kServerPrefix
) - 1,
276 "the strings server and client must be the same length");
277 typedef std::vector
<WebSocketExtension::Parameter
> ParameterVector
;
279 DCHECK_EQ("permessage-deflate", extension
.name());
280 const ParameterVector
& parameters
= extension
.parameters();
281 std::set
<std::string
> seen_names
;
282 for (ParameterVector::const_iterator it
= parameters
.begin();
283 it
!= parameters
.end(); ++it
) {
284 const std::string
& name
= it
->name();
285 if (seen_names
.count(name
) != 0) {
288 "Received duplicate permessage-deflate extension parameter " + name
);
290 seen_names
.insert(name
);
291 const std::string
client_or_server(name
, 0, kPrefixLen
);
292 const bool is_client
= (client_or_server
== kClientPrefix
);
293 if (!is_client
&& client_or_server
!= kServerPrefix
) {
296 "Received an unexpected permessage-deflate extension parameter");
298 const std::string
rest(name
, kPrefixLen
);
299 if (rest
== kNoContextTakeover
) {
300 if (it
->HasValue()) {
301 return DeflateError(failure_message
,
302 "Received invalid " + name
+ " parameter");
305 params
->deflate_mode
= WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT
;
306 } else if (rest
== kMaxWindowBits
) {
308 return DeflateError(failure_message
, name
+ " must have value");
310 if (!base::StringToInt(it
->value(), &bits
) || bits
< 8 || bits
> 15 ||
311 it
->value()[0] == '0' ||
312 it
->value().find_first_not_of("0123456789") != std::string::npos
) {
313 return DeflateError(failure_message
,
314 "Received invalid " + name
+ " parameter");
317 params
->client_window_bits
= bits
;
321 "Received an unexpected permessage-deflate extension parameter");
324 params
->deflate_enabled
= true;
328 bool ValidateExtensions(const HttpResponseHeaders
* headers
,
329 const std::vector
<std::string
>& requested_extensions
,
330 std::string
* extensions
,
331 std::string
* failure_message
,
332 WebSocketExtensionParams
* params
) {
333 void* state
= nullptr;
335 std::vector
<std::string
> accepted_extensions
;
336 // TODO(ricea): If adding support for additional extensions, generalise this
338 bool seen_permessage_deflate
= false;
339 while (headers
->EnumerateHeader(
340 &state
, websockets::kSecWebSocketExtensions
, &value
)) {
341 WebSocketExtensionParser parser
;
343 if (parser
.has_error()) {
344 // TODO(yhirano) Set appropriate failure message.
346 "'Sec-WebSocket-Extensions' header value is "
347 "rejected by the parser: " +
351 if (parser
.extension().name() == "permessage-deflate") {
352 if (seen_permessage_deflate
) {
353 *failure_message
= "Received duplicate permessage-deflate response";
356 seen_permessage_deflate
= true;
357 if (!ValidatePerMessageDeflateExtension(
358 parser
.extension(), failure_message
, params
))
362 "Found an unsupported extension '" +
363 parser
.extension().name() +
364 "' in 'Sec-WebSocket-Extensions' header";
367 accepted_extensions
.push_back(value
);
369 *extensions
= JoinString(accepted_extensions
, ", ");
375 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
376 scoped_ptr
<ClientSocketHandle
> connection
,
377 WebSocketStream::ConnectDelegate
* connect_delegate
,
379 std::vector
<std::string
> requested_sub_protocols
,
380 std::vector
<std::string
> requested_extensions
,
381 std::string
* failure_message
)
382 : state_(connection
.release(), using_proxy
),
383 connect_delegate_(connect_delegate
),
384 http_response_info_(nullptr),
385 requested_sub_protocols_(requested_sub_protocols
),
386 requested_extensions_(requested_extensions
),
387 failure_message_(failure_message
) {
388 DCHECK(connect_delegate
);
389 DCHECK(failure_message
);
392 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
394 int WebSocketBasicHandshakeStream::InitializeStream(
395 const HttpRequestInfo
* request_info
,
396 RequestPriority priority
,
397 const BoundNetLog
& net_log
,
398 const CompletionCallback
& callback
) {
399 url_
= request_info
->url
;
400 state_
.Initialize(request_info
, priority
, net_log
, callback
);
404 int WebSocketBasicHandshakeStream::SendRequest(
405 const HttpRequestHeaders
& headers
,
406 HttpResponseInfo
* response
,
407 const CompletionCallback
& callback
) {
408 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketKey
));
409 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketProtocol
));
410 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketExtensions
));
411 DCHECK(headers
.HasHeader(HttpRequestHeaders::kOrigin
));
412 DCHECK(headers
.HasHeader(websockets::kUpgrade
));
413 DCHECK(headers
.HasHeader(HttpRequestHeaders::kConnection
));
414 DCHECK(headers
.HasHeader(websockets::kSecWebSocketVersion
));
417 http_response_info_
= response
;
419 // Create a copy of the headers object, so that we can add the
420 // Sec-WebSockey-Key header.
421 HttpRequestHeaders enriched_headers
;
422 enriched_headers
.CopyFrom(headers
);
423 std::string handshake_challenge
;
424 if (handshake_challenge_for_testing_
) {
425 handshake_challenge
= *handshake_challenge_for_testing_
;
426 handshake_challenge_for_testing_
.reset();
428 handshake_challenge
= GenerateHandshakeChallenge();
430 enriched_headers
.SetHeader(websockets::kSecWebSocketKey
, handshake_challenge
);
432 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions
,
433 requested_extensions_
,
435 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol
,
436 requested_sub_protocols_
,
439 handshake_challenge_response_
=
440 ComputeSecWebSocketAccept(handshake_challenge
);
442 DCHECK(connect_delegate_
);
443 scoped_ptr
<WebSocketHandshakeRequestInfo
> request(
444 new WebSocketHandshakeRequestInfo(url_
, base::Time::Now()));
445 request
->headers
.CopyFrom(enriched_headers
);
446 connect_delegate_
->OnStartOpeningHandshake(request
.Pass());
448 return parser()->SendRequest(
449 state_
.GenerateRequestLine(), enriched_headers
, response
, callback
);
452 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
453 const CompletionCallback
& callback
) {
454 // HttpStreamParser uses a weak pointer when reading from the
455 // socket, so it won't be called back after being destroyed. The
456 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
457 // so this use of base::Unretained() is safe.
458 int rv
= parser()->ReadResponseHeaders(
459 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback
,
460 base::Unretained(this),
462 if (rv
== ERR_IO_PENDING
)
464 bool is_redirect
= false;
465 return ValidateResponse(rv
, &is_redirect
);
468 int WebSocketBasicHandshakeStream::ReadResponseBody(
471 const CompletionCallback
& callback
) {
472 return parser()->ReadResponseBody(buf
, buf_len
, callback
);
475 void WebSocketBasicHandshakeStream::Close(bool not_reusable
) {
476 // This class ignores the value of |not_reusable| and never lets the socket be
479 parser()->Close(true);
482 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
483 return parser()->IsResponseBodyComplete();
486 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
487 return parser() && parser()->CanFindEndOfResponse();
490 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
491 return parser()->IsConnectionReused();
494 void WebSocketBasicHandshakeStream::SetConnectionReused() {
495 parser()->SetConnectionReused();
498 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
502 int64
WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
506 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
507 LoadTimingInfo
* load_timing_info
) const {
508 return state_
.connection()->GetLoadTimingInfo(IsConnectionReused(),
512 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo
* ssl_info
) {
513 parser()->GetSSLInfo(ssl_info
);
516 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
517 SSLCertRequestInfo
* cert_request_info
) {
518 parser()->GetSSLCertRequestInfo(cert_request_info
);
521 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
523 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession
* session
) {
524 HttpResponseBodyDrainer
* drainer
= new HttpResponseBodyDrainer(this);
525 drainer
->Start(session
);
526 // |drainer| will delete itself.
529 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority
) {
530 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
531 // gone, then copy whatever has happened there over here.
534 UploadProgress
WebSocketBasicHandshakeStream::GetUploadProgress() const {
535 return UploadProgress();
538 HttpStream
* WebSocketBasicHandshakeStream::RenewStreamForAuth() {
539 // Return null because we don't support renewing the stream.
543 scoped_ptr
<WebSocketStream
> WebSocketBasicHandshakeStream::Upgrade() {
544 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
545 // sure it does not touch it again before it is destroyed.
546 state_
.DeleteParser();
547 WebSocketTransportClientSocketPool::UnlockEndpoint(state_
.connection());
548 scoped_ptr
<WebSocketStream
> basic_stream(
549 new WebSocketBasicStream(state_
.ReleaseConnection(),
553 DCHECK(extension_params_
.get());
554 if (extension_params_
->deflate_enabled
) {
555 UMA_HISTOGRAM_ENUMERATION(
556 "Net.WebSocket.DeflateMode",
557 extension_params_
->deflate_mode
,
558 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES
);
560 return scoped_ptr
<WebSocketStream
>(
561 new WebSocketDeflateStream(basic_stream
.Pass(),
562 extension_params_
->deflate_mode
,
563 extension_params_
->client_window_bits
,
564 scoped_ptr
<WebSocketDeflatePredictor
>(
565 new WebSocketDeflatePredictorImpl
)));
567 return basic_stream
.Pass();
571 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
572 const std::string
& key
) {
573 handshake_challenge_for_testing_
.reset(new std::string(key
));
576 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
577 const CompletionCallback
& callback
,
579 bool is_redirect
= false;
580 int rv
= ValidateResponse(result
, &is_redirect
);
582 // TODO(yhirano): Simplify this statement once http://crbug.com/399535 is
586 RunCallbackWithOk(callback
, rv
);
588 case ERR_INVALID_RESPONSE
:
590 RunCallbackWithInvalidResponseCausedByRedirect(callback
, rv
);
592 RunCallbackWithInvalidResponse(callback
, rv
);
595 RunCallback(callback
, rv
);
600 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
601 DCHECK(http_response_info_
);
602 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_
,
604 http_response_info_
->headers
,
605 http_response_info_
->response_time
);
608 int WebSocketBasicHandshakeStream::ValidateResponse(int rv
,
610 DCHECK(http_response_info_
);
611 *is_redirect
= false;
612 // Most net errors happen during connection, so they are not seen by this
613 // method. The histogram for error codes is created in
614 // Delegate::OnResponseStarted in websocket_stream.cc instead.
616 const HttpResponseHeaders
* headers
= http_response_info_
->headers
.get();
617 const int response_code
= headers
->response_code();
618 *is_redirect
= HttpResponseHeaders::IsRedirectResponseCode(response_code
);
619 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code
);
620 switch (response_code
) {
621 case HTTP_SWITCHING_PROTOCOLS
:
622 OnFinishOpeningHandshake();
623 return ValidateUpgradeResponse(headers
);
625 // We need to pass these through for authentication to work.
626 case HTTP_UNAUTHORIZED
:
627 case HTTP_PROXY_AUTHENTICATION_REQUIRED
:
630 // Other status codes are potentially risky (see the warnings in the
631 // WHATWG WebSocket API spec) and so are dropped by default.
633 // A WebSocket server cannot be using HTTP/0.9, so if we see version
634 // 0.9, it means the response was garbage.
635 // Reporting "Unexpected response code: 200" in this case is not
636 // helpful, so use a different error message.
637 if (headers
->GetHttpVersion() == HttpVersion(0, 9)) {
639 "Error during WebSocket handshake: Invalid status line");
641 set_failure_message(base::StringPrintf(
642 "Error during WebSocket handshake: Unexpected response code: %d",
643 headers
->response_code()));
645 OnFinishOpeningHandshake();
646 return ERR_INVALID_RESPONSE
;
649 if (rv
== ERR_EMPTY_RESPONSE
) {
651 "Connection closed before receiving a handshake response");
654 set_failure_message(std::string("Error during WebSocket handshake: ") +
656 OnFinishOpeningHandshake();
661 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
662 const HttpResponseHeaders
* headers
) {
663 extension_params_
.reset(new WebSocketExtensionParams
);
664 std::string failure_message
;
665 if (ValidateUpgrade(headers
, &failure_message
) &&
666 ValidateSecWebSocketAccept(
667 headers
, handshake_challenge_response_
, &failure_message
) &&
668 ValidateConnection(headers
, &failure_message
) &&
669 ValidateSubProtocol(headers
,
670 requested_sub_protocols_
,
673 ValidateExtensions(headers
,
674 requested_extensions_
,
677 extension_params_
.get())) {
680 set_failure_message("Error during WebSocket handshake: " + failure_message
);
681 return ERR_INVALID_RESPONSE
;
684 void WebSocketBasicHandshakeStream::set_failure_message(
685 const std::string
& failure_message
) {
686 *failure_message_
= failure_message
;