1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/server/web_socket.h"
9 #include "base/base64.h"
10 #include "base/rand_util.h"
11 #include "base/logging.h"
13 #include "base/sha1.h"
14 #include "base/strings/string_number_conversions.h"
15 #include "base/strings/stringprintf.h"
16 #include "base/sys_byteorder.h"
17 #include "net/server/http_connection.h"
18 #include "net/server/http_server.h"
19 #include "net/server/http_server_request_info.h"
20 #include "net/server/http_server_response_info.h"
// Computes the handshake fingerprint of one Hixie-76 key header value
// (Sec-WebSocket-Key1/Key2): the ASCII digits of |str| are concatenated in
// order, parsed as an int64, divided by |spaces|, and the 32-bit quotient is
// returned in network byte order (HostToNet32).
// NOTE(review): several original lines of this function are missing from
// this view — the declarations of |result|, |spaces| and |number|, the body
// of the ' ' branch, the statement guarded by the StringToInt64 check, and
// the closing brace. Presumably the ' ' branch counts spaces into |spaces|;
// confirm, and confirm that a key with zero spaces is rejected before the
// division below (otherwise it divides by zero).
26 static uint32
WebSocketKeyFingerprint(const std::string
& str
) {
28 const char* p_char
= str
.c_str();
29 int length
= str
.length();
31 for (int i
= 0; i
< length
; ++i
) {
// Keep only the decimal digits, preserving their order.
32 if (p_char
[i
] >= '0' && p_char
[i
] <= '9')
33 result
.append(&p_char
[i
], 1);
34 else if (p_char
[i
] == ' ')
// NOTE(review): the statement this check guards is not visible here;
// presumably a digit string that does not parse yields a sentinel value.
40 if (!base::StringToInt64(result
, &number
))
// Divide by the (assumed) space count and convert to network byte order.
42 return base::HostToNet32(static_cast<uint32
>(number
/ spaces
));
// WebSocket implementation for the Hixie-76 draft, as the class name
// indicates. It uses the Sec-WebSocket-Key1/Key2 headers plus an
// 8-byte key3 body for the handshake, and frames messages as
// 0x00 <payload> 0xFF.
// NOTE(review): many original lines of this class are missing from this view
// (access specifiers, the |pos| parameter lines, several statements and the
// closing "};").
45 class WebSocketHixie76
: public net::WebSocket
{
// Factory: constructs a WebSocketHixie76 once the connection's read buffer
// holds at least kWebSocketHandshakeBodyLen bytes past *pos.
// NOTE(review): |pos| is presumably an int* offset into the read buffer
// (its declaration is not visible), and the statement taken when the buffer
// is too short (original line 53) is missing — presumably it returns NULL.
47 static net::WebSocket
* Create(HttpServer
* server
,
48 HttpConnection
* connection
,
49 const HttpServerRequestInfo
& request
,
51 if (connection
->read_buf()->GetSize() <
52 static_cast<int>(*pos
+ kWebSocketHandshakeBodyLen
))
54 return new WebSocketHixie76(server
, connection
, request
, pos
);
// Completes the handshake. It builds the 16-byte challenge
// fp1 | fp2 | key3_ (4 + 4 + 8 bytes), sends the 101 response headers, then
// sends the MD5 digest of the challenge as 16 raw bytes.
// NOTE(review): the declaration of |data| and the SendRaw call that carries
// the header format's arguments (original lines 63-64, 75-76, 82-84) are
// missing from this view.
57 virtual void Accept(const HttpServerRequestInfo
& request
) OVERRIDE
{
58 std::string key1
= request
.GetHeaderValue("sec-websocket-key1");
59 std::string key2
= request
.GetHeaderValue("sec-websocket-key2");
61 uint32 fp1
= WebSocketKeyFingerprint(key1
);
62 uint32 fp2
= WebSocketKeyFingerprint(key2
);
// Challenge layout: [0..4) fp1, [4..8) fp2, [8..16) key3_.
65 memcpy(data
, &fp1
, 4);
66 memcpy(data
+ 4, &fp2
, 4);
67 memcpy(data
+ 8, &key3_
[0], 8);
69 base::MD5Digest digest
;
70 base::MD5Sum(data
, 16, &digest
);
72 std::string origin
= request
.GetHeaderValue("origin");
73 std::string host
= request
.GetHeaderValue("host");
// The ws:// location echoes the request's Host header and path.
74 std::string location
= "ws://" + host
+ request
.path
;
77 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
78 "Upgrade: WebSocket\r\n"
79 "Connection: Upgrade\r\n"
80 "Sec-WebSocket-Origin: %s\r\n"
81 "Sec-WebSocket-Location: %s\r\n"
// The 16-byte MD5 digest is sent raw after the response headers.
85 server_
->SendRaw(connection_
->id(),
86 std::string(reinterpret_cast<char*>(digest
.a
), 16));
// Reads one 0x00 <payload> 0xFF frame from the start of the read buffer.
// Returns FRAME_INCOMPLETE until a 0xFF terminator ('\377') is found after
// the first byte. It then copies the bytes between the two markers into
// |*message| and consumes the whole frame, terminator included.
// NOTE(review): the statement guarded by the non-zero first-byte check
// (original lines 93-94) and the final return are not visible; presumably a
// first byte other than 0x00 is reported as an error.
89 virtual ParseResult
Read(std::string
* message
) OVERRIDE
{
91 HttpConnection::ReadIOBuffer
* read_buf
= connection_
->read_buf();
92 if (read_buf
->StartOfBuffer()[0])
95 base::StringPiece
data(read_buf
->StartOfBuffer(), read_buf
->GetSize());
// Search from offset 1 to skip the leading 0x00 frame marker.
96 size_t pos
= data
.find('\377', 1);
97 if (pos
== base::StringPiece::npos
)
98 return FRAME_INCOMPLETE
;
100 message
->assign(data
.data() + 1, pos
- 1);
101 read_buf
->DidConsume(pos
+ 1);
// Sends |message| framed as 0x00 <message> 0xFF using three SendRaw calls.
// char(-1) is the 0xFF end marker.
106 virtual void Send(const std::string
& message
) OVERRIDE
{
107 char message_start
= 0;
108 char message_end
= -1;
109 server_
->SendRaw(connection_
->id(), std::string(1, message_start
));
110 server_
->SendRaw(connection_
->id(), message
);
111 server_
->SendRaw(connection_
->id(), std::string(1, message_end
));
// Size of the key3 handshake body read from the buffer at *pos (8, see the
// out-of-class definition).
115 static const int kWebSocketHandshakeBodyLen
;
// Validates the key headers and captures key3_. When Sec-WebSocket-Key1 or
// Key2 is invalid, it sends a 500 response. Otherwise it copies
// kWebSocketHandshakeBodyLen bytes at *pos from the read buffer into key3_
// and advances *pos past them.
// NOTE(review): the |pos| parameter line, the validity conditions, the rest
// of the 500 messages and the early returns are missing from this view.
117 WebSocketHixie76(HttpServer
* server
,
118 HttpConnection
* connection
,
119 const HttpServerRequestInfo
& request
,
121 : WebSocket(server
, connection
) {
122 std::string key1
= request
.GetHeaderValue("sec-websocket-key1");
123 std::string key2
= request
.GetHeaderValue("sec-websocket-key2");
126 server
->SendResponse(
128 HttpServerResponseInfo::CreateFor500(
129 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't "
135 server
->SendResponse(
137 HttpServerResponseInfo::CreateFor500(
138 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't "
143 key3_
.assign(connection
->read_buf()->StartOfBuffer() + *pos
,
144 kWebSocketHandshakeBodyLen
);
145 *pos
+= kWebSocketHandshakeBodyLen
;
150 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76
);
// Number of key3 bytes that Hixie-76 handshake parsing reads from the read
// buffer at *pos. Accept() copies exactly 8 bytes of key3_ into the MD5
// challenge.
153 const int WebSocketHixie76::kWebSocketHandshakeBodyLen
= 8;
156 // Constants for hybi-10 frame format.
// NOTE(review): the heading says hybi-10, but the codec using these constants
// is WebSocketHybi17, which accepts Sec-WebSocket-Version 8 or 13. The OpCode
// type declaration is not visible in this view.
// Frame opcodes, extracted as first_byte & kOpCodeMask.
160 const OpCode kOpCodeContinuation
= 0x0;
161 const OpCode kOpCodeText
= 0x1;
162 const OpCode kOpCodeBinary
= 0x2;
163 const OpCode kOpCodeClose
= 0x8;
164 const OpCode kOpCodePing
= 0x9;
165 const OpCode kOpCodePong
= 0xA;
// Bit masks applied to the first frame byte: FIN, RSV1-3 and the opcode.
167 const unsigned char kFinalBit
= 0x80;
168 const unsigned char kReserved1Bit
= 0x40;
169 const unsigned char kReserved2Bit
= 0x20;
170 const unsigned char kReserved3Bit
= 0x10;
171 const unsigned char kOpCodeMask
= 0xF;
// Bit masks applied to the second frame byte: MASK flag and 7-bit length.
172 const unsigned char kMaskBit
= 0x80;
173 const unsigned char kPayloadLengthMask
= 0x7F;
// Payload-length encoding. Values up to 125 are stored inline. A 7-bit
// length of 126 means a 2-byte extended length follows; 127 means an 8-byte
// extended length follows.
175 const size_t kMaxSingleBytePayloadLength
= 125;
176 const size_t kTwoBytePayloadLengthField
= 126;
177 const size_t kEightBytePayloadLengthField
= 127;
// Size of the masking key that precedes a masked payload.
178 const size_t kMaskingKeyWidthInBytes
= 4;
// WebSocket implementation for the hybi framing. It accepts
// Sec-WebSocket-Version 8 or 13 and answers the handshake with
// Sec-WebSocket-Accept. Frames are decoded and encoded by
// WebSocket::DecodeFrameHybi17 / EncodeFrameHybi17.
// NOTE(review): many original lines of this class are missing from this view
// (access specifiers, the |pos| parameter lines, branch bodies, most of the
// constructor and the closing "};").
180 class WebSocketHybi17
: public WebSocket
{
// Factory. Rejects requests whose Sec-WebSocket-Version is neither "8" nor
// "13", and sends a 500 response when Sec-WebSocket-Key is invalid.
// Otherwise it constructs a WebSocketHybi17.
// NOTE(review): the statements guarded by the version check and the key
// validity condition are not visible; presumably they return NULL so that
// CreateWebSocket can fall back to Hixie-76.
182 static WebSocket
* Create(HttpServer
* server
,
183 HttpConnection
* connection
,
184 const HttpServerRequestInfo
& request
,
186 std::string version
= request
.GetHeaderValue("sec-websocket-version");
187 if (version
!= "8" && version
!= "13")
190 std::string key
= request
.GetHeaderValue("sec-websocket-key");
192 server
->SendResponse(
194 HttpServerResponseInfo::CreateFor500(
195 "Invalid request format. Sec-WebSocket-Key is empty or isn't "
199 return new WebSocketHybi17(server
, connection
, request
, pos
);
// Completes the handshake. Sec-WebSocket-Accept is
// base64(SHA1(Sec-WebSocket-Key + kWebSocketGuid)).
// NOTE(review): the SendRaw call wrapping the response format and the
// header terminator (original lines 209-211, 216) are missing from this view.
202 virtual void Accept(const HttpServerRequestInfo
& request
) OVERRIDE
{
// Fixed GUID concatenated with the client key (the value RFC 6455
// specifies for computing Sec-WebSocket-Accept).
203 static const char* const kWebSocketGuid
=
204 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
205 std::string key
= request
.GetHeaderValue("sec-websocket-key");
206 std::string data
= base::StringPrintf("%s%s", key
.c_str(), kWebSocketGuid
);
207 std::string encoded_hash
;
208 base::Base64Encode(base::SHA1HashString(data
), &encoded_hash
);
212 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
213 "Upgrade: WebSocket\r\n"
214 "Connection: Upgrade\r\n"
215 "Sec-WebSocket-Accept: %s\r\n"
217 encoded_hash
.c_str()));
// Decodes one frame from the read buffer, treating it as a client frame
// (second argument true). The consumed bytes are dropped from the buffer
// only on FRAME_OK.
// NOTE(review): the declaration receiving the decode result (original line
// 224), the FRAME_CLOSE handling and the final return are missing from this
// view.
220 virtual ParseResult
Read(std::string
* message
) OVERRIDE
{
221 HttpConnection::ReadIOBuffer
* read_buf
= connection_
->read_buf();
222 base::StringPiece
frame(read_buf
->StartOfBuffer(), read_buf
->GetSize());
223 int bytes_consumed
= 0;
225 WebSocket::DecodeFrameHybi17(frame
, true, &bytes_consumed
, message
);
226 if (result
== FRAME_OK
)
227 read_buf
->DidConsume(bytes_consumed
);
228 if (result
== FRAME_CLOSE
)
// Sends |message| as a single frame. A masking key of 0 makes
// EncodeFrameHybi17 emit an unmasked frame (server-to-client).
233 virtual void Send(const std::string
& message
) OVERRIDE
{
236 server_
->SendRaw(connection_
->id(),
237 WebSocket::EncodeFrameHybi17(message
, 0));
// NOTE(review): the |pos| parameter and the rest of the initializer list
// and body (original lines 244, 246-263) are missing from this view.
241 WebSocketHybi17(HttpServer
* server
,
242 HttpConnection
* connection
,
243 const HttpServerRequestInfo
& request
,
245 : WebSocket(server
, connection
),
// NOTE(review): none of the visible code reads these members; confirm they
// are still used elsewhere before keeping them.
264 const char* payload_
;
265 size_t payload_length_
;
266 const char* frame_end_
;
269 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17
);
272 } // anonymous namespace
// Creates the protocol handler for an upgrade request. It tries the hybi
// implementation first and otherwise falls back to Hixie-76.
// NOTE(review): the |pos| parameter line and the check that returns |socket|
// when the hybi attempt succeeds (original lines 277, 279-281) are missing
// from this view.
274 WebSocket
* WebSocket::CreateWebSocket(HttpServer
* server
,
275 HttpConnection
* connection
,
276 const HttpServerRequestInfo
& request
,
278 WebSocket
* socket
= WebSocketHybi17::Create(server
, connection
, request
, pos
);
282 return WebSocketHixie76::Create(server
, connection
, request
, pos
);
// Decodes one hybi frame from the start of |frame| into |*output|.
// Returns:
//   - FRAME_INCOMPLETE when the buffer does not yet hold the extended length
//     or the full (mask + payload) body;
//   - FRAME_ERROR when FIN is clear or any RSV bit is set;
//   - otherwise FRAME_CLOSE or FRAME_OK depending on |closed|, after storing
//     the total frame size in |*bytes_consumed|.
// A masked payload is unmasked with its 4-byte key; an unmasked payload is
// copied as-is.
// NOTE(review): the |client_frame| and |bytes_consumed| parameter lines, the
// minimum-length check before the first return, the opcode switch head and
// bodies, and the declaration of |closed| are missing from this view.
// Presumably the client_frame && !masked and too-large branches return
// FRAME_ERROR.
286 WebSocket::ParseResult
WebSocket::DecodeFrameHybi17(
287 const base::StringPiece
& frame
,
290 std::string
* output
) {
291 size_t data_length
= frame
.length();
293 return FRAME_INCOMPLETE
;
// NOTE(review): this const_cast is unnecessary; frame.data() is already
// assignable to const char*.
295 const char* buffer_begin
= const_cast<char*>(frame
.data());
296 const char* p
= buffer_begin
;
297 const char* buffer_end
= p
+ data_length
;
299 unsigned char first_byte
= *p
++;
300 unsigned char second_byte
= *p
++;
302 bool final
= (first_byte
& kFinalBit
) != 0;
303 bool reserved1
= (first_byte
& kReserved1Bit
) != 0;
304 bool reserved2
= (first_byte
& kReserved2Bit
) != 0;
305 bool reserved3
= (first_byte
& kReserved3Bit
) != 0;
306 int op_code
= first_byte
& kOpCodeMask
;
307 bool masked
= (second_byte
& kMaskBit
) != 0;
// Fragmented frames and RSV bits (extensions) are rejected.
308 if (!final
|| reserved1
|| reserved2
|| reserved3
)
309 return FRAME_ERROR
; // Extensions are not supported.
318 case kOpCodeBinary
: // We don't support binary frames yet.
319 case kOpCodeContinuation
: // We don't support continuation frames yet.
320 case kOpCodePing
: // We don't support ping frames yet.
321 case kOpCodePong
: // We don't support pong frames yet.
326 if (client_frame
&& !masked
) // In Hybi-17 spec client MUST mask its frame.
// Lengths 126/127 select a 2- or 8-byte big-endian extended length.
329 uint64 payload_length64
= second_byte
& kPayloadLengthMask
;
330 if (payload_length64
> kMaxSingleBytePayloadLength
) {
331 int extended_payload_length_size
;
332 if (payload_length64
== kTwoBytePayloadLengthField
)
333 extended_payload_length_size
= 2;
335 DCHECK(payload_length64
== kEightBytePayloadLengthField
);
336 extended_payload_length_size
= 8;
338 if (buffer_end
- p
< extended_payload_length_size
)
339 return FRAME_INCOMPLETE
;
340 payload_length64
= 0;
// Accumulate the extended length most-significant byte first.
341 for (int i
= 0; i
< extended_payload_length_size
; ++i
) {
342 payload_length64
<<= 8;
343 payload_length64
|= static_cast<unsigned char>(*p
++);
347 size_t actual_masking_key_length
= masked
? kMaskingKeyWidthInBytes
: 0;
// Reject lengths with the top bit set, and lengths whose total with the mask
// key would not fit in size_t.
// NOTE(review): |max_length| is never modified and could be declared const.
348 static const uint64 max_payload_length
= 0x7FFFFFFFFFFFFFFFull
;
349 static size_t max_length
= std::numeric_limits
<size_t>::max();
350 if (payload_length64
> max_payload_length
||
351 payload_length64
+ actual_masking_key_length
> max_length
) {
352 // WebSocket frame length too large.
355 size_t payload_length
= static_cast<size_t>(payload_length64
);
357 size_t total_length
= actual_masking_key_length
+ payload_length
;
358 if (static_cast<size_t>(buffer_end
- p
) < total_length
)
359 return FRAME_INCOMPLETE
;
// Masked branch (its condition is not visible here; presumably if (masked)):
// the 4-byte key at |p| is followed by the payload.
362 output
->resize(payload_length
);
363 const char* masking_key
= p
;
// NOTE(review): the const_cast is unnecessary; |payload| is only read.
364 char* payload
= const_cast<char*>(p
+ kMaskingKeyWidthInBytes
);
365 for (size_t i
= 0; i
< payload_length
; ++i
) // Unmask the payload.
366 (*output
)[i
] = payload
[i
] ^ masking_key
[i
% kMaskingKeyWidthInBytes
];
// Unmasked branch: the payload is copied verbatim.
368 output
->assign(p
, p
+ payload_length
);
// Total frame size: header + extended length + mask key + payload.
371 size_t pos
= p
+ actual_masking_key_length
+ payload_length
- buffer_begin
;
372 *bytes_consumed
= pos
;
373 return closed
? FRAME_CLOSE
: FRAME_OK
;
// Encodes |message| as a single final text frame. The length uses the
// smallest form: inline up to 125, a 2-byte field up to 0xFFFF, otherwise an
// 8-byte big-endian field. When |masking_key| is non-zero, the MASK bit is
// set, the key's bytes are appended and the payload is XOR-masked with them.
// With a zero key, the payload is appended unmasked.
// NOTE(review): the |masking_key| parameter line (original 378) is missing;
// from its use here it is presumably a 4-byte integer. Its mask bytes are
// taken from its in-memory representation, so the on-wire key order depends
// on host endianness. That is harmless as long as the same bytes are used for
// the header and the XOR, which they are.
377 std::string
WebSocket::EncodeFrameHybi17(const std::string
& message
,
379 std::vector
<char> frame
;
380 OpCode op_code
= kOpCodeText
;
381 size_t data_length
= message
.length();
383 frame
.push_back(kFinalBit
| op_code
);
384 char mask_key_bit
= masking_key
!= 0 ? kMaskBit
: 0;
385 if (data_length
<= kMaxSingleBytePayloadLength
)
386 frame
.push_back(data_length
| mask_key_bit
);
387 else if (data_length
<= 0xFFFF) {
388 frame
.push_back(kTwoBytePayloadLengthField
| mask_key_bit
);
389 frame
.push_back((data_length
& 0xFF00) >> 8);
390 frame
.push_back(data_length
& 0xFF);
// 8-byte extended length branch (the "} else {" line is not visible here).
392 frame
.push_back(kEightBytePayloadLengthField
| mask_key_bit
);
393 char extended_payload_length
[8];
394 size_t remaining
= data_length
;
395 // Fill the length into extended_payload_length in the network byte order.
// NOTE(review): the statement shifting |remaining| right by 8 each iteration
// (original line 398) is missing from this view.
396 for (int i
= 0; i
< 8; ++i
) {
397 extended_payload_length
[7 - i
] = remaining
& 0xFF;
400 frame
.insert(frame
.end(),
401 extended_payload_length
,
402 extended_payload_length
+ 8);
// NOTE(review): this const_cast is unnecessary; message.data() is already
// assignable to const char*.
406 const char* data
= const_cast<char*>(message
.data());
407 if (masking_key
!= 0) {
408 const char* mask_bytes
= reinterpret_cast<char*>(&masking_key
);
409 frame
.insert(frame
.end(), mask_bytes
, mask_bytes
+ 4);
410 for (size_t i
= 0; i
< data_length
; ++i
) // Mask the payload.
411 frame
.push_back(data
[i
] ^ mask_bytes
[i
% kMaskingKeyWidthInBytes
]);
// Unmasked branch: the payload bytes are appended as-is.
413 frame
.insert(frame
.end(), data
, data
+ data_length
);
// |frame| always holds at least the header byte, so &frame[0] is valid.
415 return std::string(&frame
[0], frame
.size());
418 WebSocket::WebSocket(HttpServer
* server
, HttpConnection
* connection
)
420 connection_(connection
) {