1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_handshake_handler.h"
9 #include "base/base64.h"
10 #include "base/sha1.h"
11 #include "base/strings/string_number_conversions.h"
12 #include "base/strings/string_piece.h"
13 #include "base/strings/string_tokenizer.h"
14 #include "base/strings/string_util.h"
15 #include "base/strings/stringprintf.h"
16 #include "net/http/http_request_headers.h"
17 #include "net/http/http_response_headers.h"
18 #include "net/http/http_util.h"
19 #include "net/websockets/websocket_handshake_constants.h"
// The only Sec-WebSocket-Version value accepted in an opening handshake
// request (the version defined by RFC 6455).
const int kVersionHeaderValueForRFC6455 = 13;
// Splits |handshake_message| (|len| bytes) into the Status-Line or
// Request-Line, including its line terminator, and the headers, excluding the
// final terminator of the empty line that ends a handshake message.
//
// Both CRLF and bare LF terminators are recognized, and no byte at or beyond
// |len| is ever read. If the message contains no CR or LF at all, the whole
// message becomes |*request_line| and |*headers| is cleared. A non-positive
// |len| clears both outputs.
void ParseHandshakeHeader(
    const char* handshake_message, int len,
    std::string* request_line,
    std::string* headers) {
  if (len <= 0) {
    request_line->clear();
    headers->clear();
    return;
  }
  const size_t length = static_cast<size_t>(len);

  // Locate the first CR or LF, i.e. the end of the first line.
  size_t eol = 0;
  while (eol < length && handshake_message[eol] != '\r' &&
         handshake_message[eol] != '\n') {
    ++eol;
  }
  if (eol == length) {
    // No line terminator at all: everything is the first line.
    request_line->assign(handshake_message, length);
    headers->clear();
    return;
  }

  // |request_line| includes its terminator: CRLF when present, otherwise the
  // single CR or LF. Bounded by |length| so a terminator that is the last
  // byte of the buffer does not cause an over-read.
  size_t headers_begin = eol + 1;
  if (handshake_message[eol] == '\r' && headers_begin < length &&
      handshake_message[headers_begin] == '\n') {
    ++headers_begin;
  }
  request_line->assign(handshake_message, headers_begin);

  // The message normally ends with the empty line terminating the header
  // section. |headers| keeps the terminator of the last header line but drops
  // that final CRLF (or LF).
  size_t headers_end = length;
  if (headers_end > headers_begin &&
      handshake_message[headers_end - 1] == '\n') {
    --headers_end;
    if (headers_end > headers_begin &&
        handshake_message[headers_end - 1] == '\r') {
      --headers_end;
    }
  }
  headers->assign(handshake_message + headers_begin,
                  headers_end - headers_begin);
}
53 void FetchHeaders(const std::string
& headers
,
54 const char* const headers_to_get
[],
55 size_t headers_to_get_len
,
56 std::vector
<std::string
>* values
) {
57 net::HttpUtil::HeadersIterator
iter(headers
.begin(), headers
.end(), "\r\n");
58 while (iter
.GetNext()) {
59 for (size_t i
= 0; i
< headers_to_get_len
; i
++) {
60 if (LowerCaseEqualsASCII(iter
.name_begin(), iter
.name_end(),
62 values
->push_back(iter
.values());
68 bool GetHeaderName(std::string::const_iterator line_begin
,
69 std::string::const_iterator line_end
,
70 std::string::const_iterator
* name_begin
,
71 std::string::const_iterator
* name_end
) {
72 std::string::const_iterator colon
= std::find(line_begin
, line_end
, ':');
73 if (colon
== line_end
) {
76 *name_begin
= line_begin
;
78 if (*name_begin
== *name_end
|| net::HttpUtil::IsLWS(**name_begin
))
80 net::HttpUtil::TrimLWS(name_begin
, name_end
);
84 // Similar to HttpUtil::StripHeaders, but it preserves malformed headers, that
85 // is, lines that are not formatted as "<name>: <value>\r\n".
86 std::string
FilterHeaders(
87 const std::string
& headers
,
88 const char* const headers_to_remove
[],
89 size_t headers_to_remove_len
) {
90 std::string filtered_headers
;
92 base::StringTokenizer
lines(headers
.begin(), headers
.end(), "\r\n");
93 while (lines
.GetNext()) {
94 std::string::const_iterator line_begin
= lines
.token_begin();
95 std::string::const_iterator line_end
= lines
.token_end();
96 std::string::const_iterator name_begin
;
97 std::string::const_iterator name_end
;
98 bool should_remove
= false;
99 if (GetHeaderName(line_begin
, line_end
, &name_begin
, &name_end
)) {
100 for (size_t i
= 0; i
< headers_to_remove_len
; ++i
) {
101 if (LowerCaseEqualsASCII(name_begin
, name_end
, headers_to_remove
[i
])) {
102 should_remove
= true;
107 if (!should_remove
) {
108 filtered_headers
.append(line_begin
, line_end
);
109 filtered_headers
.append("\r\n");
112 return filtered_headers
;
115 bool CheckVersionInRequest(const std::string
& request_headers
) {
116 std::vector
<std::string
> values
;
117 const char* const headers_to_get
[1] = {
118 websockets::kSecWebSocketVersionLowercase
};
119 FetchHeaders(request_headers
, headers_to_get
, 1, &values
);
120 DCHECK_LE(values
.size(), 1U);
125 bool conversion_success
= base::StringToInt(values
[0], &version
);
126 if (!conversion_success
)
129 return version
== kVersionHeaderValueForRFC6455
;
132 // Append a header to a string. Equivalent to
133 // response_message += header + ": " + value + "\r\n"
134 // but avoids unnecessary allocations and copies.
135 void AppendHeader(const base::StringPiece
& header
,
136 const base::StringPiece
& value
,
137 std::string
* response_message
) {
138 static const char kColonSpace
[] = ": ";
139 const size_t kColonSpaceSize
= sizeof(kColonSpace
) - 1;
140 static const char kCrNl
[] = "\r\n";
141 const size_t kCrNlSize
= sizeof(kCrNl
) - 1;
144 header
.size() + kColonSpaceSize
+ value
.size() + kCrNlSize
;
145 response_message
->reserve(response_message
->size() + extra_size
);
146 response_message
->append(header
.begin(), header
.end());
147 response_message
->append(kColonSpace
, kColonSpace
+ kColonSpaceSize
);
148 response_message
->append(value
.begin(), value
.end());
149 response_message
->append(kCrNl
, kCrNl
+ kCrNlSize
);
154 WebSocketHandshakeRequestHandler::WebSocketHandshakeRequestHandler()
155 : original_length_(0),
158 bool WebSocketHandshakeRequestHandler::ParseRequest(
159 const char* data
, int length
) {
160 DCHECK_GT(length
, 0);
161 std::string
input(data
, length
);
162 int input_header_length
=
163 HttpUtil::LocateEndOfHeaders(input
.data(), input
.size(), 0);
164 if (input_header_length
<= 0)
167 ParseHandshakeHeader(input
.data(),
172 if (!CheckVersionInRequest(headers_
)) {
177 original_length_
= input_header_length
;
181 size_t WebSocketHandshakeRequestHandler::original_length() const {
182 return original_length_
;
185 void WebSocketHandshakeRequestHandler::AppendHeaderIfMissing(
186 const std::string
& name
, const std::string
& value
) {
187 DCHECK(!headers_
.empty());
188 HttpUtil::AppendHeaderIfMissing(name
.c_str(), value
, &headers_
);
191 void WebSocketHandshakeRequestHandler::RemoveHeaders(
192 const char* const headers_to_remove
[],
193 size_t headers_to_remove_len
) {
194 DCHECK(!headers_
.empty());
195 headers_
= FilterHeaders(
196 headers_
, headers_to_remove
, headers_to_remove_len
);
199 HttpRequestInfo
WebSocketHandshakeRequestHandler::GetRequestInfo(
200 const GURL
& url
, std::string
* challenge
) {
201 HttpRequestInfo request_info
;
202 request_info
.url
= url
;
203 size_t method_end
= base::StringPiece(request_line_
).find_first_of(" ");
204 if (method_end
!= base::StringPiece::npos
)
205 request_info
.method
= std::string(request_line_
.data(), method_end
);
207 request_info
.extra_headers
.Clear();
208 request_info
.extra_headers
.AddHeadersFromString(headers_
);
210 request_info
.extra_headers
.RemoveHeader(websockets::kUpgrade
);
211 request_info
.extra_headers
.RemoveHeader(HttpRequestHeaders::kConnection
);
214 bool header_present
= request_info
.extra_headers
.GetHeader(
215 websockets::kSecWebSocketKey
, &key
);
216 DCHECK(header_present
);
217 request_info
.extra_headers
.RemoveHeader(websockets::kSecWebSocketKey
);
222 bool WebSocketHandshakeRequestHandler::GetRequestHeaderBlock(
224 SpdyHeaderBlock
* headers
,
225 std::string
* challenge
,
226 int spdy_protocol_version
) {
227 // Construct opening handshake request headers as a SPDY header block.
228 // For details, see WebSocket Layering over SPDY/3 Draft 8.
229 if (spdy_protocol_version
<= 2) {
230 (*headers
)["path"] = url
.path();
231 (*headers
)["version"] = "WebSocket/13";
232 (*headers
)["scheme"] = url
.scheme();
234 (*headers
)[":path"] = url
.path();
235 (*headers
)[":version"] = "WebSocket/13";
236 (*headers
)[":scheme"] = url
.scheme();
239 HttpUtil::HeadersIterator
iter(headers_
.begin(), headers_
.end(), "\r\n");
240 while (iter
.GetNext()) {
241 if (LowerCaseEqualsASCII(iter
.name_begin(),
243 websockets::kUpgradeLowercase
) ||
244 LowerCaseEqualsASCII(
245 iter
.name_begin(), iter
.name_end(), "connection") ||
246 LowerCaseEqualsASCII(iter
.name_begin(),
248 websockets::kSecWebSocketVersionLowercase
)) {
249 // These headers must be ignored.
251 } else if (LowerCaseEqualsASCII(iter
.name_begin(),
253 websockets::kSecWebSocketKeyLowercase
)) {
254 *challenge
= iter
.values();
255 // Sec-WebSocket-Key is not sent to the server.
257 } else if (LowerCaseEqualsASCII(
258 iter
.name_begin(), iter
.name_end(), "host") ||
259 LowerCaseEqualsASCII(
260 iter
.name_begin(), iter
.name_end(), "origin") ||
261 LowerCaseEqualsASCII(
264 websockets::kSecWebSocketProtocolLowercase
) ||
265 LowerCaseEqualsASCII(
268 websockets::kSecWebSocketExtensionsLowercase
)) {
269 // TODO(toyoshim): Some WebSocket extensions may not be compatible with
270 // SPDY. We should omit them from a Sec-WebSocket-Extension header.
272 if (spdy_protocol_version
<= 2)
273 name
= base::StringToLowerASCII(iter
.name());
275 name
= ":" + base::StringToLowerASCII(iter
.name());
276 (*headers
)[name
] = iter
.values();
279 // Others should be sent out to |headers|.
280 std::string name
= base::StringToLowerASCII(iter
.name());
281 SpdyHeaderBlock::iterator found
= headers
->find(name
);
282 if (found
== headers
->end()) {
283 (*headers
)[name
] = iter
.values();
285 // For now, websocket doesn't use multiple headers, but follows to http.
286 found
->second
.append(1, '\0'); // +=() doesn't append 0's
287 found
->second
.append(iter
.values());
294 std::string
WebSocketHandshakeRequestHandler::GetRawRequest() {
295 DCHECK(!request_line_
.empty());
296 DCHECK(!headers_
.empty());
298 std::string raw_request
= request_line_
+ headers_
+ "\r\n";
299 raw_length_
= raw_request
.size();
303 size_t WebSocketHandshakeRequestHandler::raw_length() const {
304 DCHECK_GT(raw_length_
, 0);
308 WebSocketHandshakeResponseHandler::WebSocketHandshakeResponseHandler()
309 : original_header_length_(0) {}
311 WebSocketHandshakeResponseHandler::~WebSocketHandshakeResponseHandler() {}
313 size_t WebSocketHandshakeResponseHandler::ParseRawResponse(
314 const char* data
, int length
) {
315 DCHECK_GT(length
, 0);
317 DCHECK(!status_line_
.empty());
318 // headers_ might be empty for wrong response from server.
323 size_t old_original_length
= original_
.size();
325 original_
.append(data
, length
);
326 // TODO(ukai): fail fast when response gives wrong status code.
327 original_header_length_
= HttpUtil::LocateEndOfHeaders(
328 original_
.data(), original_
.size(), 0);
332 ParseHandshakeHeader(original_
.data(),
333 original_header_length_
,
336 int header_size
= status_line_
.size() + headers_
.size();
337 DCHECK_GE(original_header_length_
, header_size
);
338 header_separator_
= std::string(original_
.data() + header_size
,
339 original_header_length_
- header_size
);
340 return original_header_length_
- old_original_length
;
343 bool WebSocketHandshakeResponseHandler::HasResponse() const {
344 return original_header_length_
> 0 &&
345 static_cast<size_t>(original_header_length_
) <= original_
.size();
348 void ComputeSecWebSocketAccept(const std::string
& key
,
349 std::string
* accept
) {
353 base::SHA1HashString(key
+ websockets::kWebSocketGuid
);
354 base::Base64Encode(hash
, accept
);
357 bool WebSocketHandshakeResponseHandler::ParseResponseInfo(
358 const HttpResponseInfo
& response_info
,
359 const std::string
& challenge
) {
360 if (!response_info
.headers
.get())
363 // TODO(ricea): Eliminate all the reallocations and string copies.
364 std::string response_message
;
365 response_message
= response_info
.headers
->GetStatusLine();
366 response_message
+= "\r\n";
368 AppendHeader(websockets::kUpgrade
,
369 websockets::kWebSocketLowercase
,
373 HttpRequestHeaders::kConnection
, websockets::kUpgrade
, &response_message
);
375 std::string websocket_accept
;
376 ComputeSecWebSocketAccept(challenge
, &websocket_accept
);
378 websockets::kSecWebSocketAccept
, websocket_accept
, &response_message
);
383 while (response_info
.headers
->EnumerateHeaderLines(&iter
, &name
, &value
)) {
384 AppendHeader(name
, value
, &response_message
);
386 response_message
+= "\r\n";
388 return ParseRawResponse(response_message
.data(),
389 response_message
.size()) == response_message
.size();
392 bool WebSocketHandshakeResponseHandler::ParseResponseHeaderBlock(
393 const SpdyHeaderBlock
& headers
,
394 const std::string
& challenge
,
395 int spdy_protocol_version
) {
396 SpdyHeaderBlock::const_iterator status
;
397 if (spdy_protocol_version
<= 2)
398 status
= headers
.find("status");
400 status
= headers
.find(":status");
401 if (status
== headers
.end())
405 base::SHA1HashString(challenge
+ websockets::kWebSocketGuid
);
406 std::string websocket_accept
;
407 base::Base64Encode(hash
, &websocket_accept
);
409 std::string response_message
= base::StringPrintf(
410 "%s %s\r\n", websockets::kHttpProtocolVersion
, status
->second
.c_str());
413 websockets::kUpgrade
, websockets::kWebSocketLowercase
, &response_message
);
415 HttpRequestHeaders::kConnection
, websockets::kUpgrade
, &response_message
);
417 websockets::kSecWebSocketAccept
, websocket_accept
, &response_message
);
419 for (SpdyHeaderBlock::const_iterator iter
= headers
.begin();
420 iter
!= headers
.end();
422 // For each value, if the server sends a NUL-separated list of values,
423 // we separate that back out into individual headers for each value
425 if ((spdy_protocol_version
<= 2 &&
426 LowerCaseEqualsASCII(iter
->first
, "status")) ||
427 (spdy_protocol_version
>= 3 &&
428 LowerCaseEqualsASCII(iter
->first
, ":status"))) {
429 // The status value is already handled as the first line of
430 // |response_message|. Just skip here.
433 const std::string
& value
= iter
->second
;
437 end
= value
.find('\0', start
);
439 if (end
!= std::string::npos
)
440 tval
= value
.substr(start
, (end
- start
));
442 tval
= value
.substr(start
);
443 if (spdy_protocol_version
>= 3 &&
444 (LowerCaseEqualsASCII(iter
->first
,
445 websockets::kSecWebSocketProtocolSpdy3
) ||
446 LowerCaseEqualsASCII(iter
->first
,
447 websockets::kSecWebSocketExtensionsSpdy3
)))
448 AppendHeader(iter
->first
.substr(1), tval
, &response_message
);
450 AppendHeader(iter
->first
, tval
, &response_message
);
452 } while (end
!= std::string::npos
);
454 response_message
+= "\r\n";
456 return ParseRawResponse(response_message
.data(),
457 response_message
.size()) == response_message
.size();
460 void WebSocketHandshakeResponseHandler::GetHeaders(
461 const char* const headers_to_get
[],
462 size_t headers_to_get_len
,
463 std::vector
<std::string
>* values
) {
464 DCHECK(HasResponse());
465 DCHECK(!status_line_
.empty());
466 // headers_ might be empty for wrong response from server.
467 if (headers_
.empty())
470 FetchHeaders(headers_
, headers_to_get
, headers_to_get_len
, values
);
473 void WebSocketHandshakeResponseHandler::RemoveHeaders(
474 const char* const headers_to_remove
[],
475 size_t headers_to_remove_len
) {
476 DCHECK(HasResponse());
477 DCHECK(!status_line_
.empty());
478 // headers_ might be empty for wrong response from server.
479 if (headers_
.empty())
482 headers_
= FilterHeaders(headers_
, headers_to_remove
, headers_to_remove_len
);
485 std::string
WebSocketHandshakeResponseHandler::GetRawResponse() const {
486 DCHECK(HasResponse());
487 return original_
.substr(0, original_header_length_
);
490 std::string
WebSocketHandshakeResponseHandler::GetResponse() {
491 DCHECK(HasResponse());
492 DCHECK(!status_line_
.empty());
493 // headers_ might be empty for wrong response from server.
495 return status_line_
+ headers_
+ header_separator_
;