1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_basic_handshake_stream.h"
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/containers/hash_tables.h"
17 #include "base/logging.h"
18 #include "base/metrics/histogram.h"
19 #include "base/metrics/sparse_histogram.h"
20 #include "base/stl_util.h"
21 #include "base/strings/string_number_conversions.h"
22 #include "base/strings/string_piece.h"
23 #include "base/strings/string_util.h"
24 #include "base/strings/stringprintf.h"
25 #include "base/time/time.h"
26 #include "crypto/random.h"
27 #include "net/http/http_request_headers.h"
28 #include "net/http/http_request_info.h"
29 #include "net/http/http_response_body_drainer.h"
30 #include "net/http/http_response_headers.h"
31 #include "net/http/http_status_code.h"
32 #include "net/http/http_stream_parser.h"
33 #include "net/socket/client_socket_handle.h"
34 #include "net/socket/websocket_transport_client_socket_pool.h"
35 #include "net/websockets/websocket_basic_stream.h"
36 #include "net/websockets/websocket_deflate_predictor.h"
37 #include "net/websockets/websocket_deflate_predictor_impl.h"
38 #include "net/websockets/websocket_deflate_stream.h"
39 #include "net/websockets/websocket_deflater.h"
40 #include "net/websockets/websocket_extension_parser.h"
41 #include "net/websockets/websocket_handshake_constants.h"
42 #include "net/websockets/websocket_handshake_handler.h"
43 #include "net/websockets/websocket_handshake_request_info.h"
44 #include "net/websockets/websocket_handshake_response_info.h"
45 #include "net/websockets/websocket_stream.h"
49 // TODO(ricea): If more extensions are added, replace this with a more general
51 struct WebSocketExtensionParams
{
52 WebSocketExtensionParams()
53 : deflate_enabled(false),
54 client_window_bits(15),
55 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT
) {}
58 int client_window_bits
;
59 WebSocketDeflater::ContextTakeOverMode deflate_mode
;
64 enum GetHeaderResult
{
70 std::string
MissingHeaderMessage(const std::string
& header_name
) {
71 return std::string("'") + header_name
+ "' header is missing";
74 std::string
MultipleHeaderValuesMessage(const std::string
& header_name
) {
78 "' header must not appear more than once in a response";
81 std::string
GenerateHandshakeChallenge() {
82 std::string
raw_challenge(websockets::kRawChallengeLength
, '\0');
83 crypto::RandBytes(string_as_array(&raw_challenge
), raw_challenge
.length());
84 std::string encoded_challenge
;
85 base::Base64Encode(raw_challenge
, &encoded_challenge
);
86 return encoded_challenge
;
89 void AddVectorHeaderIfNonEmpty(const char* name
,
90 const std::vector
<std::string
>& value
,
91 HttpRequestHeaders
* headers
) {
94 headers
->SetHeader(name
, JoinString(value
, ", "));
97 GetHeaderResult
GetSingleHeaderValue(const HttpResponseHeaders
* headers
,
98 const base::StringPiece
& name
,
101 size_t num_values
= 0;
102 std::string temp_value
;
103 while (headers
->EnumerateHeader(&state
, name
, &temp_value
)) {
104 if (++num_values
> 1)
105 return GET_HEADER_MULTIPLE
;
108 return num_values
> 0 ? GET_HEADER_OK
: GET_HEADER_MISSING
;
111 bool ValidateHeaderHasSingleValue(GetHeaderResult result
,
112 const std::string
& header_name
,
113 std::string
* failure_message
) {
114 if (result
== GET_HEADER_MISSING
) {
115 *failure_message
= MissingHeaderMessage(header_name
);
118 if (result
== GET_HEADER_MULTIPLE
) {
119 *failure_message
= MultipleHeaderValuesMessage(header_name
);
122 DCHECK_EQ(result
, GET_HEADER_OK
);
126 bool ValidateUpgrade(const HttpResponseHeaders
* headers
,
127 std::string
* failure_message
) {
129 GetHeaderResult result
=
130 GetSingleHeaderValue(headers
, websockets::kUpgrade
, &value
);
131 if (!ValidateHeaderHasSingleValue(result
,
132 websockets::kUpgrade
,
137 if (!LowerCaseEqualsASCII(value
, websockets::kWebSocketLowercase
)) {
139 "'Upgrade' header value is not 'WebSocket': " + value
;
145 bool ValidateSecWebSocketAccept(const HttpResponseHeaders
* headers
,
146 const std::string
& expected
,
147 std::string
* failure_message
) {
149 GetHeaderResult result
=
150 GetSingleHeaderValue(headers
, websockets::kSecWebSocketAccept
, &actual
);
151 if (!ValidateHeaderHasSingleValue(result
,
152 websockets::kSecWebSocketAccept
,
157 if (expected
!= actual
) {
158 *failure_message
= "Incorrect 'Sec-WebSocket-Accept' header value";
164 bool ValidateConnection(const HttpResponseHeaders
* headers
,
165 std::string
* failure_message
) {
166 // Connection header is permitted to contain other tokens.
167 if (!headers
->HasHeader(HttpRequestHeaders::kConnection
)) {
168 *failure_message
= MissingHeaderMessage(HttpRequestHeaders::kConnection
);
171 if (!headers
->HasHeaderValue(HttpRequestHeaders::kConnection
,
172 websockets::kUpgrade
)) {
173 *failure_message
= "'Connection' header value must contain 'Upgrade'";
179 bool ValidateSubProtocol(
180 const HttpResponseHeaders
* headers
,
181 const std::vector
<std::string
>& requested_sub_protocols
,
182 std::string
* sub_protocol
,
183 std::string
* failure_message
) {
186 base::hash_set
<std::string
> requested_set(requested_sub_protocols
.begin(),
187 requested_sub_protocols
.end());
189 bool has_multiple_protocols
= false;
190 bool has_invalid_protocol
= false;
192 while (!has_invalid_protocol
|| !has_multiple_protocols
) {
193 std::string temp_value
;
194 if (!headers
->EnumerateHeader(
195 &state
, websockets::kSecWebSocketProtocol
, &temp_value
))
198 if (requested_set
.count(value
) == 0)
199 has_invalid_protocol
= true;
201 has_multiple_protocols
= true;
204 if (has_multiple_protocols
) {
206 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol
);
208 } else if (count
> 0 && requested_sub_protocols
.size() == 0) {
210 std::string("Response must not include 'Sec-WebSocket-Protocol' "
211 "header if not present in request: ")
214 } else if (has_invalid_protocol
) {
216 "'Sec-WebSocket-Protocol' header value '" +
218 "' in response does not match any of sent values";
220 } else if (requested_sub_protocols
.size() > 0 && count
== 0) {
222 "Sent non-empty 'Sec-WebSocket-Protocol' header "
223 "but no response was received";
226 *sub_protocol
= value
;
230 bool DeflateError(std::string
* message
, const base::StringPiece
& piece
) {
231 *message
= "Error in permessage-deflate: ";
232 piece
.AppendToString(message
);
236 bool ValidatePerMessageDeflateExtension(const WebSocketExtension
& extension
,
237 std::string
* failure_message
,
238 WebSocketExtensionParams
* params
) {
239 static const char kClientPrefix
[] = "client_";
240 static const char kServerPrefix
[] = "server_";
241 static const char kNoContextTakeover
[] = "no_context_takeover";
242 static const char kMaxWindowBits
[] = "max_window_bits";
243 const size_t kPrefixLen
= arraysize(kClientPrefix
) - 1;
244 COMPILE_ASSERT(kPrefixLen
== arraysize(kServerPrefix
) - 1,
245 the_strings_server_and_client_must_be_the_same_length
);
246 typedef std::vector
<WebSocketExtension::Parameter
> ParameterVector
;
248 DCHECK_EQ("permessage-deflate", extension
.name());
249 const ParameterVector
& parameters
= extension
.parameters();
250 std::set
<std::string
> seen_names
;
251 for (ParameterVector::const_iterator it
= parameters
.begin();
252 it
!= parameters
.end(); ++it
) {
253 const std::string
& name
= it
->name();
254 if (seen_names
.count(name
) != 0) {
257 "Received duplicate permessage-deflate extension parameter " + name
);
259 seen_names
.insert(name
);
260 const std::string
client_or_server(name
, 0, kPrefixLen
);
261 const bool is_client
= (client_or_server
== kClientPrefix
);
262 if (!is_client
&& client_or_server
!= kServerPrefix
) {
265 "Received an unexpected permessage-deflate extension parameter");
267 const std::string
rest(name
, kPrefixLen
);
268 if (rest
== kNoContextTakeover
) {
269 if (it
->HasValue()) {
270 return DeflateError(failure_message
,
271 "Received invalid " + name
+ " parameter");
274 params
->deflate_mode
= WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT
;
275 } else if (rest
== kMaxWindowBits
) {
277 return DeflateError(failure_message
, name
+ " must have value");
279 if (!base::StringToInt(it
->value(), &bits
) || bits
< 8 || bits
> 15 ||
280 it
->value()[0] == '0' ||
281 it
->value().find_first_not_of("0123456789") != std::string::npos
) {
282 return DeflateError(failure_message
,
283 "Received invalid " + name
+ " parameter");
286 params
->client_window_bits
= bits
;
290 "Received an unexpected permessage-deflate extension parameter");
293 params
->deflate_enabled
= true;
297 bool ValidateExtensions(const HttpResponseHeaders
* headers
,
298 const std::vector
<std::string
>& requested_extensions
,
299 std::string
* extensions
,
300 std::string
* failure_message
,
301 WebSocketExtensionParams
* params
) {
304 std::vector
<std::string
> accepted_extensions
;
305 // TODO(ricea): If adding support for additional extensions, generalise this
307 bool seen_permessage_deflate
= false;
308 while (headers
->EnumerateHeader(
309 &state
, websockets::kSecWebSocketExtensions
, &value
)) {
310 WebSocketExtensionParser parser
;
312 if (parser
.has_error()) {
313 // TODO(yhirano) Set appropriate failure message.
315 "'Sec-WebSocket-Extensions' header value is "
316 "rejected by the parser: " +
320 if (parser
.extension().name() == "permessage-deflate") {
321 if (seen_permessage_deflate
) {
322 *failure_message
= "Received duplicate permessage-deflate response";
325 seen_permessage_deflate
= true;
326 if (!ValidatePerMessageDeflateExtension(
327 parser
.extension(), failure_message
, params
))
331 "Found an unsupported extension '" +
332 parser
.extension().name() +
333 "' in 'Sec-WebSocket-Extensions' header";
336 accepted_extensions
.push_back(value
);
338 *extensions
= JoinString(accepted_extensions
, ", ");
344 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
345 scoped_ptr
<ClientSocketHandle
> connection
,
346 WebSocketStream::ConnectDelegate
* connect_delegate
,
348 std::vector
<std::string
> requested_sub_protocols
,
349 std::vector
<std::string
> requested_extensions
,
350 std::string
* failure_message
)
351 : state_(connection
.release(), using_proxy
),
352 connect_delegate_(connect_delegate
),
353 http_response_info_(NULL
),
354 requested_sub_protocols_(requested_sub_protocols
),
355 requested_extensions_(requested_extensions
),
356 failure_message_(failure_message
) {
357 DCHECK(connect_delegate
);
358 DCHECK(failure_message
);
361 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
363 int WebSocketBasicHandshakeStream::InitializeStream(
364 const HttpRequestInfo
* request_info
,
365 RequestPriority priority
,
366 const BoundNetLog
& net_log
,
367 const CompletionCallback
& callback
) {
368 url_
= request_info
->url
;
369 state_
.Initialize(request_info
, priority
, net_log
, callback
);
373 int WebSocketBasicHandshakeStream::SendRequest(
374 const HttpRequestHeaders
& headers
,
375 HttpResponseInfo
* response
,
376 const CompletionCallback
& callback
) {
377 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketKey
));
378 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketProtocol
));
379 DCHECK(!headers
.HasHeader(websockets::kSecWebSocketExtensions
));
380 DCHECK(headers
.HasHeader(HttpRequestHeaders::kOrigin
));
381 DCHECK(headers
.HasHeader(websockets::kUpgrade
));
382 DCHECK(headers
.HasHeader(HttpRequestHeaders::kConnection
));
383 DCHECK(headers
.HasHeader(websockets::kSecWebSocketVersion
));
386 http_response_info_
= response
;
388 // Create a copy of the headers object, so that we can add the
389 // Sec-WebSockey-Key header.
390 HttpRequestHeaders enriched_headers
;
391 enriched_headers
.CopyFrom(headers
);
392 std::string handshake_challenge
;
393 if (handshake_challenge_for_testing_
) {
394 handshake_challenge
= *handshake_challenge_for_testing_
;
395 handshake_challenge_for_testing_
.reset();
397 handshake_challenge
= GenerateHandshakeChallenge();
399 enriched_headers
.SetHeader(websockets::kSecWebSocketKey
, handshake_challenge
);
401 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions
,
402 requested_extensions_
,
404 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol
,
405 requested_sub_protocols_
,
408 ComputeSecWebSocketAccept(handshake_challenge
,
409 &handshake_challenge_response_
);
411 DCHECK(connect_delegate_
);
412 scoped_ptr
<WebSocketHandshakeRequestInfo
> request(
413 new WebSocketHandshakeRequestInfo(url_
, base::Time::Now()));
414 request
->headers
.CopyFrom(enriched_headers
);
415 connect_delegate_
->OnStartOpeningHandshake(request
.Pass());
417 return parser()->SendRequest(
418 state_
.GenerateRequestLine(), enriched_headers
, response
, callback
);
421 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
422 const CompletionCallback
& callback
) {
423 // HttpStreamParser uses a weak pointer when reading from the
424 // socket, so it won't be called back after being destroyed. The
425 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
426 // so this use of base::Unretained() is safe.
427 int rv
= parser()->ReadResponseHeaders(
428 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback
,
429 base::Unretained(this),
431 if (rv
== ERR_IO_PENDING
)
433 return ValidateResponse(rv
);
436 int WebSocketBasicHandshakeStream::ReadResponseBody(
439 const CompletionCallback
& callback
) {
440 return parser()->ReadResponseBody(buf
, buf_len
, callback
);
443 void WebSocketBasicHandshakeStream::Close(bool not_reusable
) {
444 // This class ignores the value of |not_reusable| and never lets the socket be
447 parser()->Close(true);
450 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
451 return parser()->IsResponseBodyComplete();
454 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
455 return parser() && parser()->CanFindEndOfResponse();
458 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
459 return parser()->IsConnectionReused();
462 void WebSocketBasicHandshakeStream::SetConnectionReused() {
463 parser()->SetConnectionReused();
466 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
470 int64
WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
474 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
475 LoadTimingInfo
* load_timing_info
) const {
476 return state_
.connection()->GetLoadTimingInfo(IsConnectionReused(),
480 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo
* ssl_info
) {
481 parser()->GetSSLInfo(ssl_info
);
484 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
485 SSLCertRequestInfo
* cert_request_info
) {
486 parser()->GetSSLCertRequestInfo(cert_request_info
);
489 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
491 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession
* session
) {
492 HttpResponseBodyDrainer
* drainer
= new HttpResponseBodyDrainer(this);
493 drainer
->Start(session
);
494 // |drainer| will delete itself.
497 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority
) {
498 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
499 // gone, then copy whatever has happened there over here.
502 scoped_ptr
<WebSocketStream
> WebSocketBasicHandshakeStream::Upgrade() {
503 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
504 // sure it does not touch it again before it is destroyed.
505 state_
.DeleteParser();
506 WebSocketTransportClientSocketPool::UnlockEndpoint(state_
.connection());
507 scoped_ptr
<WebSocketStream
> basic_stream(
508 new WebSocketBasicStream(state_
.ReleaseConnection(),
512 DCHECK(extension_params_
.get());
513 if (extension_params_
->deflate_enabled
) {
514 UMA_HISTOGRAM_ENUMERATION(
515 "Net.WebSocket.DeflateMode",
516 extension_params_
->deflate_mode
,
517 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES
);
519 return scoped_ptr
<WebSocketStream
>(
520 new WebSocketDeflateStream(basic_stream
.Pass(),
521 extension_params_
->deflate_mode
,
522 extension_params_
->client_window_bits
,
523 scoped_ptr
<WebSocketDeflatePredictor
>(
524 new WebSocketDeflatePredictorImpl
)));
526 return basic_stream
.Pass();
530 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
531 const std::string
& key
) {
532 handshake_challenge_for_testing_
.reset(new std::string(key
));
535 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
536 const CompletionCallback
& callback
,
538 callback
.Run(ValidateResponse(result
));
541 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
542 DCHECK(http_response_info_
);
543 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_
,
545 http_response_info_
->headers
,
546 http_response_info_
->response_time
);
549 int WebSocketBasicHandshakeStream::ValidateResponse(int rv
) {
550 DCHECK(http_response_info_
);
551 // Most net errors happen during connection, so they are not seen by this
552 // method. The histogram for error codes is created in
553 // Delegate::OnResponseStarted in websocket_stream.cc instead.
555 const HttpResponseHeaders
* headers
= http_response_info_
->headers
.get();
556 const int response_code
= headers
->response_code();
557 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code
);
558 switch (response_code
) {
559 case HTTP_SWITCHING_PROTOCOLS
:
560 OnFinishOpeningHandshake();
561 return ValidateUpgradeResponse(headers
);
563 // We need to pass these through for authentication to work.
564 case HTTP_UNAUTHORIZED
:
565 case HTTP_PROXY_AUTHENTICATION_REQUIRED
:
568 // Other status codes are potentially risky (see the warnings in the
569 // WHATWG WebSocket API spec) and so are dropped by default.
571 // A WebSocket server cannot be using HTTP/0.9, so if we see version
572 // 0.9, it means the response was garbage.
573 // Reporting "Unexpected response code: 200" in this case is not
574 // helpful, so use a different error message.
575 if (headers
->GetHttpVersion() == HttpVersion(0, 9)) {
577 "Error during WebSocket handshake: Invalid status line");
579 set_failure_message(base::StringPrintf(
580 "Error during WebSocket handshake: Unexpected response code: %d",
581 headers
->response_code()));
583 OnFinishOpeningHandshake();
584 return ERR_INVALID_RESPONSE
;
587 if (rv
== ERR_EMPTY_RESPONSE
) {
589 "Connection closed before receiving a handshake response");
592 set_failure_message(std::string("Error during WebSocket handshake: ") +
594 OnFinishOpeningHandshake();
599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
600 const HttpResponseHeaders
* headers
) {
601 extension_params_
.reset(new WebSocketExtensionParams
);
602 std::string failure_message
;
603 if (ValidateUpgrade(headers
, &failure_message
) &&
604 ValidateSecWebSocketAccept(
605 headers
, handshake_challenge_response_
, &failure_message
) &&
606 ValidateConnection(headers
, &failure_message
) &&
607 ValidateSubProtocol(headers
,
608 requested_sub_protocols_
,
611 ValidateExtensions(headers
,
612 requested_extensions_
,
615 extension_params_
.get())) {
618 set_failure_message("Error during WebSocket handshake: " + failure_message
);
619 return ERR_INVALID_RESPONSE
;
622 void WebSocketBasicHandshakeStream::set_failure_message(
623 const std::string
& failure_message
) {
624 *failure_message_
= failure_message
;