bookmarks: Move bookmark_test_helpers.h into 'bookmarks' namespace.
[chromium-blink-merge.git] / net / websockets / websocket_basic_handshake_stream.cc
blob2eee942478e92be827f9163c76c96a0d2ee9f947
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_basic_handshake_stream.h"
7 #include <algorithm>
8 #include <iterator>
9 #include <set>
10 #include <string>
11 #include <vector>
13 #include "base/base64.h"
14 #include "base/basictypes.h"
15 #include "base/bind.h"
16 #include "base/containers/hash_tables.h"
17 #include "base/logging.h"
18 #include "base/metrics/histogram.h"
19 #include "base/metrics/sparse_histogram.h"
20 #include "base/stl_util.h"
21 #include "base/strings/string_number_conversions.h"
22 #include "base/strings/string_piece.h"
23 #include "base/strings/string_util.h"
24 #include "base/strings/stringprintf.h"
25 #include "base/time/time.h"
26 #include "crypto/random.h"
27 #include "net/http/http_request_headers.h"
28 #include "net/http/http_request_info.h"
29 #include "net/http/http_response_body_drainer.h"
30 #include "net/http/http_response_headers.h"
31 #include "net/http/http_status_code.h"
32 #include "net/http/http_stream_parser.h"
33 #include "net/socket/client_socket_handle.h"
34 #include "net/socket/websocket_transport_client_socket_pool.h"
35 #include "net/websockets/websocket_basic_stream.h"
36 #include "net/websockets/websocket_deflate_predictor.h"
37 #include "net/websockets/websocket_deflate_predictor_impl.h"
38 #include "net/websockets/websocket_deflate_stream.h"
39 #include "net/websockets/websocket_deflater.h"
40 #include "net/websockets/websocket_extension_parser.h"
41 #include "net/websockets/websocket_handshake_constants.h"
42 #include "net/websockets/websocket_handshake_handler.h"
43 #include "net/websockets/websocket_handshake_request_info.h"
44 #include "net/websockets/websocket_handshake_response_info.h"
45 #include "net/websockets/websocket_stream.h"
47 namespace net {
49 // TODO(ricea): If more extensions are added, replace this with a more general
50 // mechanism.
51 struct WebSocketExtensionParams {
52 WebSocketExtensionParams()
53 : deflate_enabled(false),
54 client_window_bits(15),
55 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
57 bool deflate_enabled;
58 int client_window_bits;
59 WebSocketDeflater::ContextTakeOverMode deflate_mode;
62 namespace {
// Outcome of looking up a response header that may appear at most once.
enum GetHeaderResult {
  GET_HEADER_OK,        // Exactly one value was found.
  GET_HEADER_MISSING,   // The header was absent.
  GET_HEADER_MULTIPLE,  // The header was present more than once.
};
// Builds the failure text reported when the required header |header_name| is
// absent from the handshake response.
std::string MissingHeaderMessage(const std::string& header_name) {
  std::string message("'");
  message += header_name;
  message += "' header is missing";
  return message;
}
// Builds the failure text reported when |header_name|, which may occur only
// once, is repeated in the handshake response.
std::string MultipleHeaderValuesMessage(const std::string& header_name) {
  std::string message = "'" + header_name;
  message.append("' header must not appear more than once in a response");
  return message;
}
81 std::string GenerateHandshakeChallenge() {
82 std::string raw_challenge(websockets::kRawChallengeLength, '\0');
83 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
84 std::string encoded_challenge;
85 base::Base64Encode(raw_challenge, &encoded_challenge);
86 return encoded_challenge;
89 void AddVectorHeaderIfNonEmpty(const char* name,
90 const std::vector<std::string>& value,
91 HttpRequestHeaders* headers) {
92 if (value.empty())
93 return;
94 headers->SetHeader(name, JoinString(value, ", "));
97 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
98 const base::StringPiece& name,
99 std::string* value) {
100 void* state = NULL;
101 size_t num_values = 0;
102 std::string temp_value;
103 while (headers->EnumerateHeader(&state, name, &temp_value)) {
104 if (++num_values > 1)
105 return GET_HEADER_MULTIPLE;
106 *value = temp_value;
108 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
111 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
112 const std::string& header_name,
113 std::string* failure_message) {
114 if (result == GET_HEADER_MISSING) {
115 *failure_message = MissingHeaderMessage(header_name);
116 return false;
118 if (result == GET_HEADER_MULTIPLE) {
119 *failure_message = MultipleHeaderValuesMessage(header_name);
120 return false;
122 DCHECK_EQ(result, GET_HEADER_OK);
123 return true;
126 bool ValidateUpgrade(const HttpResponseHeaders* headers,
127 std::string* failure_message) {
128 std::string value;
129 GetHeaderResult result =
130 GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
131 if (!ValidateHeaderHasSingleValue(result,
132 websockets::kUpgrade,
133 failure_message)) {
134 return false;
137 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
138 *failure_message =
139 "'Upgrade' header value is not 'WebSocket': " + value;
140 return false;
142 return true;
145 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
146 const std::string& expected,
147 std::string* failure_message) {
148 std::string actual;
149 GetHeaderResult result =
150 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
151 if (!ValidateHeaderHasSingleValue(result,
152 websockets::kSecWebSocketAccept,
153 failure_message)) {
154 return false;
157 if (expected != actual) {
158 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
159 return false;
161 return true;
164 bool ValidateConnection(const HttpResponseHeaders* headers,
165 std::string* failure_message) {
166 // Connection header is permitted to contain other tokens.
167 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
168 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
169 return false;
171 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
172 websockets::kUpgrade)) {
173 *failure_message = "'Connection' header value must contain 'Upgrade'";
174 return false;
176 return true;
179 bool ValidateSubProtocol(
180 const HttpResponseHeaders* headers,
181 const std::vector<std::string>& requested_sub_protocols,
182 std::string* sub_protocol,
183 std::string* failure_message) {
184 void* state = NULL;
185 std::string value;
186 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
187 requested_sub_protocols.end());
188 int count = 0;
189 bool has_multiple_protocols = false;
190 bool has_invalid_protocol = false;
192 while (!has_invalid_protocol || !has_multiple_protocols) {
193 std::string temp_value;
194 if (!headers->EnumerateHeader(
195 &state, websockets::kSecWebSocketProtocol, &temp_value))
196 break;
197 value = temp_value;
198 if (requested_set.count(value) == 0)
199 has_invalid_protocol = true;
200 if (++count > 1)
201 has_multiple_protocols = true;
204 if (has_multiple_protocols) {
205 *failure_message =
206 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
207 return false;
208 } else if (count > 0 && requested_sub_protocols.size() == 0) {
209 *failure_message =
210 std::string("Response must not include 'Sec-WebSocket-Protocol' "
211 "header if not present in request: ")
212 + value;
213 return false;
214 } else if (has_invalid_protocol) {
215 *failure_message =
216 "'Sec-WebSocket-Protocol' header value '" +
217 value +
218 "' in response does not match any of sent values";
219 return false;
220 } else if (requested_sub_protocols.size() > 0 && count == 0) {
221 *failure_message =
222 "Sent non-empty 'Sec-WebSocket-Protocol' header "
223 "but no response was received";
224 return false;
226 *sub_protocol = value;
227 return true;
230 bool DeflateError(std::string* message, const base::StringPiece& piece) {
231 *message = "Error in permessage-deflate: ";
232 piece.AppendToString(message);
233 return false;
236 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
237 std::string* failure_message,
238 WebSocketExtensionParams* params) {
239 static const char kClientPrefix[] = "client_";
240 static const char kServerPrefix[] = "server_";
241 static const char kNoContextTakeover[] = "no_context_takeover";
242 static const char kMaxWindowBits[] = "max_window_bits";
243 const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
244 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
245 the_strings_server_and_client_must_be_the_same_length);
246 typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
248 DCHECK_EQ("permessage-deflate", extension.name());
249 const ParameterVector& parameters = extension.parameters();
250 std::set<std::string> seen_names;
251 for (ParameterVector::const_iterator it = parameters.begin();
252 it != parameters.end(); ++it) {
253 const std::string& name = it->name();
254 if (seen_names.count(name) != 0) {
255 return DeflateError(
256 failure_message,
257 "Received duplicate permessage-deflate extension parameter " + name);
259 seen_names.insert(name);
260 const std::string client_or_server(name, 0, kPrefixLen);
261 const bool is_client = (client_or_server == kClientPrefix);
262 if (!is_client && client_or_server != kServerPrefix) {
263 return DeflateError(
264 failure_message,
265 "Received an unexpected permessage-deflate extension parameter");
267 const std::string rest(name, kPrefixLen);
268 if (rest == kNoContextTakeover) {
269 if (it->HasValue()) {
270 return DeflateError(failure_message,
271 "Received invalid " + name + " parameter");
273 if (is_client)
274 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
275 } else if (rest == kMaxWindowBits) {
276 if (!it->HasValue())
277 return DeflateError(failure_message, name + " must have value");
278 int bits = 0;
279 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
280 it->value()[0] == '0' ||
281 it->value().find_first_not_of("0123456789") != std::string::npos) {
282 return DeflateError(failure_message,
283 "Received invalid " + name + " parameter");
285 if (is_client)
286 params->client_window_bits = bits;
287 } else {
288 return DeflateError(
289 failure_message,
290 "Received an unexpected permessage-deflate extension parameter");
293 params->deflate_enabled = true;
294 return true;
297 bool ValidateExtensions(const HttpResponseHeaders* headers,
298 const std::vector<std::string>& requested_extensions,
299 std::string* extensions,
300 std::string* failure_message,
301 WebSocketExtensionParams* params) {
302 void* state = NULL;
303 std::string value;
304 std::vector<std::string> accepted_extensions;
305 // TODO(ricea): If adding support for additional extensions, generalise this
306 // code.
307 bool seen_permessage_deflate = false;
308 while (headers->EnumerateHeader(
309 &state, websockets::kSecWebSocketExtensions, &value)) {
310 WebSocketExtensionParser parser;
311 parser.Parse(value);
312 if (parser.has_error()) {
313 // TODO(yhirano) Set appropriate failure message.
314 *failure_message =
315 "'Sec-WebSocket-Extensions' header value is "
316 "rejected by the parser: " +
317 value;
318 return false;
320 if (parser.extension().name() == "permessage-deflate") {
321 if (seen_permessage_deflate) {
322 *failure_message = "Received duplicate permessage-deflate response";
323 return false;
325 seen_permessage_deflate = true;
326 if (!ValidatePerMessageDeflateExtension(
327 parser.extension(), failure_message, params))
328 return false;
329 } else {
330 *failure_message =
331 "Found an unsupported extension '" +
332 parser.extension().name() +
333 "' in 'Sec-WebSocket-Extensions' header";
334 return false;
336 accepted_extensions.push_back(value);
338 *extensions = JoinString(accepted_extensions, ", ");
339 return true;
342 } // namespace
344 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
345 scoped_ptr<ClientSocketHandle> connection,
346 WebSocketStream::ConnectDelegate* connect_delegate,
347 bool using_proxy,
348 std::vector<std::string> requested_sub_protocols,
349 std::vector<std::string> requested_extensions,
350 std::string* failure_message)
351 : state_(connection.release(), using_proxy),
352 connect_delegate_(connect_delegate),
353 http_response_info_(NULL),
354 requested_sub_protocols_(requested_sub_protocols),
355 requested_extensions_(requested_extensions),
356 failure_message_(failure_message) {
357 DCHECK(connect_delegate);
358 DCHECK(failure_message);
361 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {}
363 int WebSocketBasicHandshakeStream::InitializeStream(
364 const HttpRequestInfo* request_info,
365 RequestPriority priority,
366 const BoundNetLog& net_log,
367 const CompletionCallback& callback) {
368 url_ = request_info->url;
369 state_.Initialize(request_info, priority, net_log, callback);
370 return OK;
373 int WebSocketBasicHandshakeStream::SendRequest(
374 const HttpRequestHeaders& headers,
375 HttpResponseInfo* response,
376 const CompletionCallback& callback) {
377 DCHECK(!headers.HasHeader(websockets::kSecWebSocketKey));
378 DCHECK(!headers.HasHeader(websockets::kSecWebSocketProtocol));
379 DCHECK(!headers.HasHeader(websockets::kSecWebSocketExtensions));
380 DCHECK(headers.HasHeader(HttpRequestHeaders::kOrigin));
381 DCHECK(headers.HasHeader(websockets::kUpgrade));
382 DCHECK(headers.HasHeader(HttpRequestHeaders::kConnection));
383 DCHECK(headers.HasHeader(websockets::kSecWebSocketVersion));
384 DCHECK(parser());
386 http_response_info_ = response;
388 // Create a copy of the headers object, so that we can add the
389 // Sec-WebSockey-Key header.
390 HttpRequestHeaders enriched_headers;
391 enriched_headers.CopyFrom(headers);
392 std::string handshake_challenge;
393 if (handshake_challenge_for_testing_) {
394 handshake_challenge = *handshake_challenge_for_testing_;
395 handshake_challenge_for_testing_.reset();
396 } else {
397 handshake_challenge = GenerateHandshakeChallenge();
399 enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
401 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
402 requested_extensions_,
403 &enriched_headers);
404 AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
405 requested_sub_protocols_,
406 &enriched_headers);
408 ComputeSecWebSocketAccept(handshake_challenge,
409 &handshake_challenge_response_);
411 DCHECK(connect_delegate_);
412 scoped_ptr<WebSocketHandshakeRequestInfo> request(
413 new WebSocketHandshakeRequestInfo(url_, base::Time::Now()));
414 request->headers.CopyFrom(enriched_headers);
415 connect_delegate_->OnStartOpeningHandshake(request.Pass());
417 return parser()->SendRequest(
418 state_.GenerateRequestLine(), enriched_headers, response, callback);
421 int WebSocketBasicHandshakeStream::ReadResponseHeaders(
422 const CompletionCallback& callback) {
423 // HttpStreamParser uses a weak pointer when reading from the
424 // socket, so it won't be called back after being destroyed. The
425 // HttpStreamParser is owned by HttpBasicState which is owned by this object,
426 // so this use of base::Unretained() is safe.
427 int rv = parser()->ReadResponseHeaders(
428 base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback,
429 base::Unretained(this),
430 callback));
431 if (rv == ERR_IO_PENDING)
432 return rv;
433 return ValidateResponse(rv);
436 int WebSocketBasicHandshakeStream::ReadResponseBody(
437 IOBuffer* buf,
438 int buf_len,
439 const CompletionCallback& callback) {
440 return parser()->ReadResponseBody(buf, buf_len, callback);
443 void WebSocketBasicHandshakeStream::Close(bool not_reusable) {
444 // This class ignores the value of |not_reusable| and never lets the socket be
445 // re-used.
446 if (parser())
447 parser()->Close(true);
450 bool WebSocketBasicHandshakeStream::IsResponseBodyComplete() const {
451 return parser()->IsResponseBodyComplete();
454 bool WebSocketBasicHandshakeStream::CanFindEndOfResponse() const {
455 return parser() && parser()->CanFindEndOfResponse();
458 bool WebSocketBasicHandshakeStream::IsConnectionReused() const {
459 return parser()->IsConnectionReused();
462 void WebSocketBasicHandshakeStream::SetConnectionReused() {
463 parser()->SetConnectionReused();
466 bool WebSocketBasicHandshakeStream::IsConnectionReusable() const {
467 return false;
470 int64 WebSocketBasicHandshakeStream::GetTotalReceivedBytes() const {
471 return 0;
474 bool WebSocketBasicHandshakeStream::GetLoadTimingInfo(
475 LoadTimingInfo* load_timing_info) const {
476 return state_.connection()->GetLoadTimingInfo(IsConnectionReused(),
477 load_timing_info);
480 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) {
481 parser()->GetSSLInfo(ssl_info);
484 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo(
485 SSLCertRequestInfo* cert_request_info) {
486 parser()->GetSSLCertRequestInfo(cert_request_info);
489 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; }
491 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) {
492 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this);
493 drainer->Start(session);
494 // |drainer| will delete itself.
497 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
498 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is
499 // gone, then copy whatever has happened there over here.
502 scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
503 // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
504 // sure it does not touch it again before it is destroyed.
505 state_.DeleteParser();
506 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection());
507 scoped_ptr<WebSocketStream> basic_stream(
508 new WebSocketBasicStream(state_.ReleaseConnection(),
509 state_.read_buf(),
510 sub_protocol_,
511 extensions_));
512 DCHECK(extension_params_.get());
513 if (extension_params_->deflate_enabled) {
514 UMA_HISTOGRAM_ENUMERATION(
515 "Net.WebSocket.DeflateMode",
516 extension_params_->deflate_mode,
517 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES);
519 return scoped_ptr<WebSocketStream>(
520 new WebSocketDeflateStream(basic_stream.Pass(),
521 extension_params_->deflate_mode,
522 extension_params_->client_window_bits,
523 scoped_ptr<WebSocketDeflatePredictor>(
524 new WebSocketDeflatePredictorImpl)));
525 } else {
526 return basic_stream.Pass();
530 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
531 const std::string& key) {
532 handshake_challenge_for_testing_.reset(new std::string(key));
535 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
536 const CompletionCallback& callback,
537 int result) {
538 callback.Run(ValidateResponse(result));
541 void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() {
542 DCHECK(http_response_info_);
543 WebSocketDispatchOnFinishOpeningHandshake(connect_delegate_,
544 url_,
545 http_response_info_->headers,
546 http_response_info_->response_time);
549 int WebSocketBasicHandshakeStream::ValidateResponse(int rv) {
550 DCHECK(http_response_info_);
551 // Most net errors happen during connection, so they are not seen by this
552 // method. The histogram for error codes is created in
553 // Delegate::OnResponseStarted in websocket_stream.cc instead.
554 if (rv >= 0) {
555 const HttpResponseHeaders* headers = http_response_info_->headers.get();
556 const int response_code = headers->response_code();
557 UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code);
558 switch (response_code) {
559 case HTTP_SWITCHING_PROTOCOLS:
560 OnFinishOpeningHandshake();
561 return ValidateUpgradeResponse(headers);
563 // We need to pass these through for authentication to work.
564 case HTTP_UNAUTHORIZED:
565 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
566 return OK;
568 // Other status codes are potentially risky (see the warnings in the
569 // WHATWG WebSocket API spec) and so are dropped by default.
570 default:
571 // A WebSocket server cannot be using HTTP/0.9, so if we see version
572 // 0.9, it means the response was garbage.
573 // Reporting "Unexpected response code: 200" in this case is not
574 // helpful, so use a different error message.
575 if (headers->GetHttpVersion() == HttpVersion(0, 9)) {
576 set_failure_message(
577 "Error during WebSocket handshake: Invalid status line");
578 } else {
579 set_failure_message(base::StringPrintf(
580 "Error during WebSocket handshake: Unexpected response code: %d",
581 headers->response_code()));
583 OnFinishOpeningHandshake();
584 return ERR_INVALID_RESPONSE;
586 } else {
587 if (rv == ERR_EMPTY_RESPONSE) {
588 set_failure_message(
589 "Connection closed before receiving a handshake response");
590 return rv;
592 set_failure_message(std::string("Error during WebSocket handshake: ") +
593 ErrorToString(rv));
594 OnFinishOpeningHandshake();
595 return rv;
599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
600 const HttpResponseHeaders* headers) {
601 extension_params_.reset(new WebSocketExtensionParams);
602 std::string failure_message;
603 if (ValidateUpgrade(headers, &failure_message) &&
604 ValidateSecWebSocketAccept(
605 headers, handshake_challenge_response_, &failure_message) &&
606 ValidateConnection(headers, &failure_message) &&
607 ValidateSubProtocol(headers,
608 requested_sub_protocols_,
609 &sub_protocol_,
610 &failure_message) &&
611 ValidateExtensions(headers,
612 requested_extensions_,
613 &extensions_,
614 &failure_message,
615 extension_params_.get())) {
616 return OK;
618 set_failure_message("Error during WebSocket handshake: " + failure_message);
619 return ERR_INVALID_RESPONSE;
622 void WebSocketBasicHandshakeStream::set_failure_message(
623 const std::string& failure_message) {
624 *failure_message_ = failure_message;
627 } // namespace net