Disable signin-to-Chrome when using Guest profile.
[chromium-blink-merge.git] / net / websockets / websocket_basic_handshake_stream.cc
blobd4ed6f8530a645e94fde7f9523d95f61331394df
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_basic_handshake_stream.h"
7 #include <algorithm>
8 #include <iterator>
9 #include <set>
10 #include <string>
11 #include <vector>
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/containers/hash_tables.h"
17 #include "base/stl_util.h"
18 #include "base/strings/string_number_conversions.h"
19 #include "base/strings/string_piece.h"
20 #include "base/strings/string_util.h"
21 #include "base/strings/stringprintf.h"
22 #include "base/time/time.h"
23 #include "crypto/random.h"
24 #include "net/http/http_request_headers.h"
25 #include "net/http/http_request_info.h"
26 #include "net/http/http_response_body_drainer.h"
27 #include "net/http/http_response_headers.h"
28 #include "net/http/http_status_code.h"
29 #include "net/http/http_stream_parser.h"
30 #include "net/socket/client_socket_handle.h"
31 #include "net/websockets/websocket_basic_stream.h"
32 #include "net/websockets/websocket_deflate_predictor.h"
33 #include "net/websockets/websocket_deflate_predictor_impl.h"
34 #include "net/websockets/websocket_deflate_stream.h"
35 #include "net/websockets/websocket_deflater.h"
36 #include "net/websockets/websocket_extension_parser.h"
37 #include "net/websockets/websocket_handshake_constants.h"
38 #include "net/websockets/websocket_handshake_handler.h"
39 #include "net/websockets/websocket_handshake_request_info.h"
40 #include "net/websockets/websocket_handshake_response_info.h"
41 #include "net/websockets/websocket_stream.h"
43 namespace net {
45 // TODO(ricea): If more extensions are added, replace this with a more general
46 // mechanism.
47 struct WebSocketExtensionParams {
48 WebSocketExtensionParams()
49 : deflate_enabled(false),
50 client_window_bits(15),
51 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
53 bool deflate_enabled;
54 int client_window_bits;
55 WebSocketDeflater::ContextTakeOverMode deflate_mode;
58 namespace {
// Outcome of looking up a header that should occur exactly once in the
// response (see GetSingleHeaderValue()).
enum GetHeaderResult {
  GET_HEADER_OK = 0,        // Exactly one value was found.
  GET_HEADER_MISSING = 1,   // The header was absent.
  GET_HEADER_MULTIPLE = 2,  // More than one value was found.
};
// Returns the failure text used when |header_name| is absent from a response.
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header is missing";
  return message;
}
// Returns the failure text used when |header_name| occurs more than once in
// a response.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header must not appear more than once in a response";
  return message;
}
77 std::string GenerateHandshakeChallenge() {
78 std::string raw_challenge(websockets::kRawChallengeLength, '\0');
79 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
80 std::string encoded_challenge;
81 base::Base64Encode(raw_challenge, &encoded_challenge);
82 return encoded_challenge;
85 void AddVectorHeaderIfNonEmpty(const char* name,
86 const std::vector<std::string>& value,
87 HttpRequestHeaders* headers) {
88 if (value.empty())
89 return;
90 headers->SetHeader(name, JoinString(value, ", "));
93 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
94 const base::StringPiece& name,
95 std::string* value) {
96 void* state = NULL;
97 size_t num_values = 0;
98 std::string temp_value;
99 while (headers->EnumerateHeader(&state, name, &temp_value)) {
100 if (++num_values > 1)
101 return GET_HEADER_MULTIPLE;
102 *value = temp_value;
104 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
107 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
108 const std::string& header_name,
109 std::string* failure_message) {
110 if (result == GET_HEADER_MISSING) {
111 *failure_message = MissingHeaderMessage(header_name);
112 return false;
114 if (result == GET_HEADER_MULTIPLE) {
115 *failure_message = MultipleHeaderValuesMessage(header_name);
116 return false;
118 DCHECK_EQ(result, GET_HEADER_OK);
119 return true;
122 bool ValidateUpgrade(const HttpResponseHeaders* headers,
123 std::string* failure_message) {
124 std::string value;
125 GetHeaderResult result =
126 GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
127 if (!ValidateHeaderHasSingleValue(result,
128 websockets::kUpgrade,
129 failure_message)) {
130 return false;
133 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
134 *failure_message =
135 "'Upgrade' header value is not 'WebSocket': " + value;
136 return false;
138 return true;
141 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
142 const std::string& expected,
143 std::string* failure_message) {
144 std::string actual;
145 GetHeaderResult result =
146 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
147 if (!ValidateHeaderHasSingleValue(result,
148 websockets::kSecWebSocketAccept,
149 failure_message)) {
150 return false;
153 if (expected != actual) {
154 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
155 return false;
157 return true;
160 bool ValidateConnection(const HttpResponseHeaders* headers,
161 std::string* failure_message) {
162 // Connection header is permitted to contain other tokens.
163 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
164 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
165 return false;
167 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
168 websockets::kUpgrade)) {
169 *failure_message = "'Connection' header value must contain 'Upgrade'";
170 return false;
172 return true;
175 bool ValidateSubProtocol(
176 const HttpResponseHeaders* headers,
177 const std::vector<std::string>& requested_sub_protocols,
178 std::string* sub_protocol,
179 std::string* failure_message) {
180 void* state = NULL;
181 std::string value;
182 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
183 requested_sub_protocols.end());
184 int count = 0;
185 bool has_multiple_protocols = false;
186 bool has_invalid_protocol = false;
188 while (!has_invalid_protocol || !has_multiple_protocols) {
189 std::string temp_value;
190 if (!headers->EnumerateHeader(
191 &state, websockets::kSecWebSocketProtocol, &temp_value))
192 break;
193 value = temp_value;
194 if (requested_set.count(value) == 0)
195 has_invalid_protocol = true;
196 if (++count > 1)
197 has_multiple_protocols = true;
200 if (has_multiple_protocols) {
201 *failure_message =
202 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
203 return false;
204 } else if (count > 0 && requested_sub_protocols.size() == 0) {
205 *failure_message =
206 std::string("Response must not include 'Sec-WebSocket-Protocol' "
207 "header if not present in request: ")
208 + value;
209 return false;
210 } else if (has_invalid_protocol) {
211 *failure_message =
212 "'Sec-WebSocket-Protocol' header value '" +
213 value +
214 "' in response does not match any of sent values";
215 return false;
216 } else if (requested_sub_protocols.size() > 0 && count == 0) {
217 *failure_message =
218 "Sent non-empty 'Sec-WebSocket-Protocol' header "
219 "but no response was received";
220 return false;
222 *sub_protocol = value;
223 return true;
226 bool DeflateError(std::string* message, const base::StringPiece& piece) {
227 *message = "Error in permessage-deflate: ";
228 piece.AppendToString(message);
229 return false;
232 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
233 std::string* failure_message,
234 WebSocketExtensionParams* params) {
235 static const char kClientPrefix[] = "client_";
236 static const char kServerPrefix[] = "server_";
237 static const char kNoContextTakeover[] = "no_context_takeover";
238 static const char kMaxWindowBits[] = "max_window_bits";
239 const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
240 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
241 the_strings_server_and_client_must_be_the_same_length);
242 typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
244 DCHECK_EQ("permessage-deflate", extension.name());
245 const ParameterVector& parameters = extension.parameters();
246 std::set<std::string> seen_names;
247 for (ParameterVector::const_iterator it = parameters.begin();
248 it != parameters.end(); ++it) {
249 const std::string& name = it->name();
250 if (seen_names.count(name) != 0) {
251 return DeflateError(
252 failure_message,
253 "Received duplicate permessage-deflate extension parameter " + name);
255 seen_names.insert(name);
256 const std::string client_or_server(name, 0, kPrefixLen);
257 const bool is_client = (client_or_server == kClientPrefix);
258 if (!is_client && client_or_server != kServerPrefix) {
259 return DeflateError(
260 failure_message,
261 "Received an unexpected permessage-deflate extension parameter");
263 const std::string rest(name, kPrefixLen);
264 if (rest == kNoContextTakeover) {
265 if (it->HasValue()) {
266 return DeflateError(failure_message,
267 "Received invalid " + name + " parameter");
269 if (is_client)
270 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
271 } else if (rest == kMaxWindowBits) {
272 if (!it->HasValue())
273 return DeflateError(failure_message, name + " must have value");
274 int bits = 0;
275 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
276 it->value()[0] == '0' ||
277 it->value().find_first_not_of("0123456789") != std::string::npos) {
278 return DeflateError(failure_message,
279 "Received invalid " + name + " parameter");
281 if (is_client)
282 params->client_window_bits = bits;
283 } else {
284 return DeflateError(
285 failure_message,
286 "Received an unexpected permessage-deflate extension parameter");
289 params->deflate_enabled = true;
290 return true;
293 bool ValidateExtensions(const HttpResponseHeaders* headers,
294 const std::vector<std::string>& requested_extensions,
295 std::string* extensions,
296 std::string* failure_message,
297 WebSocketExtensionParams* params) {
298 void* state = NULL;
299 std::string value;
300 std::vector<std::string> accepted_extensions;
301 // TODO(ricea): If adding support for additional extensions, generalise this
302 // code.
303 bool seen_permessage_deflate = false;
304 while (headers->EnumerateHeader(
305 &state, websockets::kSecWebSocketExtensions, &value)) {
306 WebSocketExtensionParser parser;
307 parser.Parse(value);
308 if (parser.has_error()) {
309 // TODO(yhirano) Set appropriate failure message.
310 *failure_message =
311 "'Sec-WebSocket-Extensions' header value is "
312 "rejected by the parser: " +
313 value;
314 return false;
316 if (parser.extension().name() == "permessage-deflate") {
317 if (seen_permessage_deflate) {
318 *failure_message = "Received duplicate permessage-deflate response";
319 return false;
321 seen_permessage_deflate = true;
322 if (!ValidatePerMessageDeflateExtension(
323 parser.extension(), failure_message, params))
324 return false;
325 } else {
326 *failure_message =
327 "Found an unsupported extension '" +
328 parser.extension().name() +
329 "' in 'Sec-WebSocket-Extensions' header";
330 return false;
332 accepted_extensions.push_back(value);
334 *extensions = JoinString(accepted_extensions, ", ");
335 return true;
338 } // namespace
340 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
341 scoped_ptr<ClientSocketHandle> connection,
342 WebSocketStream::ConnectDelegate* connect_delegate,
343 bool using_proxy,
344 std::vector<std::string> requested_sub_protocols,
345 std::vector<std::string> requested_extensions)
346 : state_(connection.release(), using_proxy),
347 connect_delegate_(connect_delegate),
348 http_response_info_(NULL),
349 requested_sub_protocols_(requested_sub_protocols),
350 requested_extensions_(requested_extensions) {}
352 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
354 int WebSocketBasicHandshakeStream::InitializeStream(
355 const HttpRequestInfo* request_info,
356 RequestPriority priority,
357 const BoundNetLog& net_log,
358 const CompletionCallback& callback) {
359 url_ = request_info->url;
360 state_.Initialize(request_info, priority, net_log, callback);
361 return OK;
364 int WebSocketBasicHandshakeStream::SendRequest(
365 const HttpRequestHeaders& headers,
366 HttpResponseInfo* response,
367 const CompletionCallback& callback) {
368 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
369 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
370 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
371 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
372 DCHECK(headers.HasHeader(websockets::kUpgrade));
373 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
374 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
375 DCHECK(parser());
377 http_response_info_ = response;
379 // Create a copy of the headers object, so that we can add the
380 // Sec-WebSockey-Key header.
381 HttpRequestHeaders enriched_headers;
382 enriched_headers.CopyFrom(headers);
383 std::string handshake_challenge;
384 if (handshake_challenge_for_testing_) {
385 handshake_challenge = *handshake_challenge_for_testing_;
386 handshake_challenge_for_testing_.reset();
387 } else {
388 handshake_challenge = GenerateHandshakeChallenge();
390 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
392 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
393 requested_extensions_,
394 &enriched_headers);
395 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
396 requested_sub_protocols_,
397 &enriched_headers);
399 ComputeSecWebSocketAccept(handshake_challenge,
400 &handshake_challenge_response_);
402 DCHECK(connect_delegate_);
403 scoped_ptr<WebSocketHandshakeRequestInfo> request(
404 new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
405 request->headers.CopyFrom(enriched_headers);
406 connect_delegate_->OnStartOpeningHandshake(request.Pass());
408 return parser()->SendRequest(
409 state_.GenerateRequestLine(), enriched_headers, response, callback);
412 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
413 const CompletionCallback& callback) {
414 // HttpStreamParser uses a weak pointer when reading from the
415 // socket, so it won't be called back after being destroyed. The
416 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
417 // so this use of base::Unretained() is safe.
418 int rv = parser()->ReadResponseHeaders(
419 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
420 base::Unretained(this),
421 callback));
422 if (rv == ERR_IO_PENDING)
423 return rv;
424 return ValidateResponse(rv);
427 const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const {
428 return parser()->GetResponseInfo();
431 int WebSocketBasicHandshakeStream::ReadResponseBody(
432 IOBuffer* buf,
433 int buf_len,
434 const CompletionCallback& callback) {
435 return parser()->ReadResponseBody(buf, buf_len, callback);
438 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
439 // This class ignores the value of |not_reusable| and never lets the socket be
440 // re-used.
441 if (parser())
442 parser()->Close(true);
445 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
446 return parser()->IsResponseBodyComplete();
449 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
450 return parser() && parser()->CanFindEndOfResponse();
453 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
454 return parser()->IsConnectionReused();
457 void WebSocketBasicHandshakeStream::SetConnectionReused() {
458 parser()->SetConnectionReused();
461 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
462 return false;
465 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
466 return 0;
469 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
470 LoadTimingInfo* load_timing_info) const {
471 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
472 load_timing_info);
475 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
476 parser()->GetSSLInfo(ssl_info);
479 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
480 SSLCertRequestInfo* cert_request_info) {
481 parser()->GetSSLCertRequestInfo(cert_request_info);
484 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
486 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
487 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
488 drainer->Start(session);
489 // |drainer| will delete itself.
492 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
493 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
494 // gone, then copy whatever has happened there over here.
497 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
498 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
499 // sure it does not touch it again before it is destroyed.
500 state_.DeleteParser();
501 scoped_ptr<WebSocketStream> basic_stream(
502 new WebSocketBasicStream(state_.ReleaseConnection(),
503 state_.read_buf(),
504 sub_protocol_,
505 extensions_));
506 DCHECK(extension_params_.get());
507 if (extension_params_->deflate_enabled) {
508 return scoped_ptr<WebSocketStream>(
509 new WebSocketDeflateStream(basic_stream.Pass(),
510 extension_params_->deflate_mode,
511 extension_params_->client_window_bits,
512 scoped_ptr<WebSocketDeflatePredictor>(
513 new WebSocketDeflatePredictorImpl)));
514 } else {
515 return basic_stream.Pass();
519 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
520 const std::string& key) {
521 handshake_challenge_for_testing_.reset(new std::string(key));
524 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
525 return failure_message_;
528 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
529 const CompletionCallback& callback,
530 int result) {
531 callback.Run(ValidateResponse(result));
534 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
535 DCHECK(connect_delegate_);
536 DCHECK(http_response_info_);
537 scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers;
538 // If the headers are too large, HttpStreamParser will just not parse them at
539 // all.
540 if (headers) {
541 scoped_ptr<WebSocketHandshakeResponseInfo> response(
542 new WebSocketHandshakeResponseInfo(url_,
543 headers->response_code(),
544 headers->GetStatusText(),
545 headers,
546 http_response_info_->response_time));
547 connect_delegate_->OnFinishOpeningHandshake(response.Pass());
551 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
552 DCHECK(http_response_info_);
553 const HttpResponseHeaders* headers = http_response_info_->headers.get();
554 if (rv >= 0) {
555 switch (headers->response_code()) {
556 case HTTP_SWITCHING_PROTOCOLS:
557 OnFinishOpeningHandshake();
558 return ValidateUpgradeResponse(headers);
560 // We need to pass these through for authentication to work.
561 case HTTP_UNAUTHORIZED:
562 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
563 return OK;
565 // Other status codes are potentially risky (see the warnings in the
566 // WHATWG WebSocket API spec) and so are dropped by default.
567 default:
568 // A WebSocket server cannot be using HTTP/0.9, so if we see version
569 // 0.9, it means the response was garbage.
570 // Reporting "Unexpected response code: 200" in this case is not
571 // helpful, so use a different error message.
572 if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
573 failure_message_ =
574 "Error during WebSocket handshake: Invalid status line";
575 } else {
576 failure_message_ = base::StringPrintf(
577 "Error during WebSocket handshake: Unexpected response code: %d",
578 headers->response_code());
580 OnFinishOpeningHandshake();
581 return ERR_INVALID_RESPONSE;
583 } else {
584 if (rv == ERR_EMPTY_RESPONSE) {
585 failure_message_ =
586 "Connection closed before receiving a handshake response";
587 return rv;
589 failure_message_ =
590 std::string("Error during WebSocket handshake: ") + ErrorToString(rv);
591 OnFinishOpeningHandshake();
592 return rv;
596 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
597 const HttpResponseHeaders* headers) {
598 extension_params_.reset(new WebSocketExtensionParams);
599 if (ValidateUpgrade(headers, &failure_message_) &&
600 ValidateSecWebSocketAccept(headers,
601 handshake_challenge_response_,
602 &failure_message_) &&
603 ValidateConnection(headers, &failure_message_) &&
604 ValidateSubProtocol(headers,
605 requested_sub_protocols_,
606 &sub_protocol_,
607 &failure_message_) &&
608 ValidateExtensions(headers,
609 requested_extensions_,
610 &extensions_,
611 &failure_message_,
612 extension_params_.get())) {
613 return OK;
615 failure_message_ = "Error during WebSocket handshake: " + failure_message_;
616 return ERR_INVALID_RESPONSE;
619 } // namespace net