1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_frame.h"
9 #include "base/basictypes.h"
10 #include "base/big_endian.h"
11 #include "base/logging.h"
12 #include "base/rand_util.h"
13 #include "net/base/io_buffer.h"
14 #include "net/base/net_errors.h"
20 const uint8 kFinalBit
= 0x80;
21 const uint8 kReserved1Bit
= 0x40;
22 const uint8 kReserved2Bit
= 0x20;
23 const uint8 kReserved3Bit
= 0x10;
24 const uint8 kOpCodeMask
= 0xF;
25 const uint8 kMaskBit
= 0x80;
26 const uint64 kMaxPayloadLengthWithoutExtendedLengthField
= 125;
27 const uint64 kPayloadLengthWithTwoByteExtendedLengthField
= 126;
28 const uint64 kPayloadLengthWithEightByteExtendedLengthField
= 127;
30 inline void MaskWebSocketFramePayloadByBytes(
31 const WebSocketMaskingKey
& masking_key
,
32 size_t masking_key_offset
,
35 for (char* masked
= begin
; masked
!= end
; ++masked
) {
36 *masked
^= masking_key
.key
[masking_key_offset
++];
37 if (masking_key_offset
== WebSocketFrameHeader::kMaskingKeyLength
)
38 masking_key_offset
= 0;
44 scoped_ptr
<WebSocketFrameHeader
> WebSocketFrameHeader::Clone() const {
45 scoped_ptr
<WebSocketFrameHeader
> ret(new WebSocketFrameHeader(opcode
));
50 void WebSocketFrameHeader::CopyFrom(const WebSocketFrameHeader
& source
) {
52 reserved1
= source
.reserved1
;
53 reserved2
= source
.reserved2
;
54 reserved3
= source
.reserved3
;
55 opcode
= source
.opcode
;
56 masked
= source
.masked
;
57 payload_length
= source
.payload_length
;
60 WebSocketFrame::WebSocketFrame(WebSocketFrameHeader::OpCode opcode
)
63 WebSocketFrame::~WebSocketFrame() {}
65 WebSocketFrameChunk::WebSocketFrameChunk() : final_chunk(false) {}
67 WebSocketFrameChunk::~WebSocketFrameChunk() {}
69 int GetWebSocketFrameHeaderSize(const WebSocketFrameHeader
& header
) {
70 int extended_length_size
= 0;
71 if (header
.payload_length
> kMaxPayloadLengthWithoutExtendedLengthField
&&
72 header
.payload_length
<= kuint16max
) {
73 extended_length_size
= 2;
74 } else if (header
.payload_length
> kuint16max
) {
75 extended_length_size
= 8;
78 return (WebSocketFrameHeader::kBaseHeaderSize
+ extended_length_size
+
79 (header
.masked
? WebSocketFrameHeader::kMaskingKeyLength
: 0));
82 int WriteWebSocketFrameHeader(const WebSocketFrameHeader
& header
,
83 const WebSocketMaskingKey
* masking_key
,
86 DCHECK((header
.opcode
& kOpCodeMask
) == header
.opcode
)
87 << "header.opcode must fit to kOpCodeMask.";
88 DCHECK(header
.payload_length
<= static_cast<uint64
>(kint64max
))
89 << "WebSocket specification doesn't allow a frame longer than "
90 << "kint64max (0x7FFFFFFFFFFFFFFF) bytes.";
91 DCHECK_GE(buffer_size
, 0);
93 // WebSocket frame format is as follows:
94 // - Common header (2 bytes)
95 // - Optional extended payload length
96 // (2 or 8 bytes, present if actual payload length is more than 125 bytes)
97 // - Optional masking key (4 bytes, present if MASK bit is on)
98 // - Actual payload (XOR masked with masking key if MASK bit is on)
100 // This function constructs frame header (the first three in the list
103 int header_size
= GetWebSocketFrameHeaderSize(header
);
104 if (header_size
> buffer_size
)
105 return ERR_INVALID_ARGUMENT
;
107 int buffer_index
= 0;
109 uint8 first_byte
= 0u;
110 first_byte
|= header
.final
? kFinalBit
: 0u;
111 first_byte
|= header
.reserved1
? kReserved1Bit
: 0u;
112 first_byte
|= header
.reserved2
? kReserved2Bit
: 0u;
113 first_byte
|= header
.reserved3
? kReserved3Bit
: 0u;
114 first_byte
|= header
.opcode
& kOpCodeMask
;
115 buffer
[buffer_index
++] = first_byte
;
117 int extended_length_size
= 0;
118 uint8 second_byte
= 0u;
119 second_byte
|= header
.masked
? kMaskBit
: 0u;
120 if (header
.payload_length
<= kMaxPayloadLengthWithoutExtendedLengthField
) {
121 second_byte
|= header
.payload_length
;
122 } else if (header
.payload_length
<= kuint16max
) {
123 second_byte
|= kPayloadLengthWithTwoByteExtendedLengthField
;
124 extended_length_size
= 2;
126 second_byte
|= kPayloadLengthWithEightByteExtendedLengthField
;
127 extended_length_size
= 8;
129 buffer
[buffer_index
++] = second_byte
;
131 // Writes "extended payload length" field.
132 if (extended_length_size
== 2) {
133 uint16 payload_length_16
= static_cast<uint16
>(header
.payload_length
);
134 base::WriteBigEndian(buffer
+ buffer_index
, payload_length_16
);
135 buffer_index
+= sizeof(payload_length_16
);
136 } else if (extended_length_size
== 8) {
137 base::WriteBigEndian(buffer
+ buffer_index
, header
.payload_length
);
138 buffer_index
+= sizeof(header
.payload_length
);
141 // Writes "masking key" field, if needed.
144 std::copy(masking_key
->key
,
145 masking_key
->key
+ WebSocketFrameHeader::kMaskingKeyLength
,
146 buffer
+ buffer_index
);
147 buffer_index
+= WebSocketFrameHeader::kMaskingKeyLength
;
149 DCHECK(!masking_key
);
152 DCHECK_EQ(header_size
, buffer_index
);
156 WebSocketMaskingKey
GenerateWebSocketMaskingKey() {
157 // Masking keys should be generated from a cryptographically secure random
158 // number generator, which means web application authors should not be able
159 // to guess the next value of masking key.
160 WebSocketMaskingKey masking_key
;
161 base::RandBytes(masking_key
.key
, WebSocketFrameHeader::kMaskingKeyLength
);
165 void MaskWebSocketFramePayload(const WebSocketMaskingKey
& masking_key
,
169 static const size_t kMaskingKeyLength
=
170 WebSocketFrameHeader::kMaskingKeyLength
;
172 DCHECK_GE(data_size
, 0);
174 // Most of the masking is done one word at a time, except for the beginning
175 // and the end of the buffer which may be unaligned. We use size_t to get the
176 // word size for this architecture. We require it be a multiple of
177 // kMaskingKeyLength in size.
178 typedef size_t PackedMaskType
;
179 PackedMaskType packed_mask_key
= 0;
180 static const size_t kPackedMaskKeySize
= sizeof(packed_mask_key
);
181 static_assert((kPackedMaskKeySize
>= kMaskingKeyLength
&&
182 kPackedMaskKeySize
% kMaskingKeyLength
== 0),
183 "word size is not a multiple of mask length");
184 char* const end
= data
+ data_size
;
185 // If the buffer is too small for the vectorised version to be useful, revert
186 // to the byte-at-a-time implementation early.
187 if (data_size
<= static_cast<int>(kPackedMaskKeySize
* 2)) {
188 MaskWebSocketFramePayloadByBytes(
189 masking_key
, frame_offset
% kMaskingKeyLength
, data
, end
);
192 const size_t data_modulus
=
193 reinterpret_cast<size_t>(data
) % kPackedMaskKeySize
;
194 char* const aligned_begin
=
195 data_modulus
== 0 ? data
: (data
+ kPackedMaskKeySize
- data_modulus
);
196 // Guaranteed by the above check for small data_size.
197 DCHECK(aligned_begin
< end
);
198 MaskWebSocketFramePayloadByBytes(
199 masking_key
, frame_offset
% kMaskingKeyLength
, data
, aligned_begin
);
200 const size_t end_modulus
= reinterpret_cast<size_t>(end
) % kPackedMaskKeySize
;
201 char* const aligned_end
= end
- end_modulus
;
202 // Guaranteed by the above check for small data_size.
203 DCHECK(aligned_end
> aligned_begin
);
204 // Create a version of the mask which is rotated by the appropriate offset
205 // for our alignment. The "trick" here is that 0 XORed with the mask will
206 // give the value of the mask for the appropriate byte.
207 char realigned_mask
[kMaskingKeyLength
] = {};
208 MaskWebSocketFramePayloadByBytes(
210 (frame_offset
+ aligned_begin
- data
) % kMaskingKeyLength
,
212 realigned_mask
+ kMaskingKeyLength
);
214 for (size_t i
= 0; i
< kPackedMaskKeySize
; i
+= kMaskingKeyLength
) {
215 // memcpy() is allegedly blessed by the C++ standard for type-punning.
216 memcpy(reinterpret_cast<char*>(&packed_mask_key
) + i
,
222 for (char* merged
= aligned_begin
; merged
!= aligned_end
;
223 merged
+= kPackedMaskKeySize
) {
224 // This is not quite standard-compliant C++. However, the standard-compliant
225 // equivalent (using memcpy()) compiles to slower code using g++. In
226 // practice, this will work for the compilers and architectures currently
227 // supported by Chromium, and the tests are extremely unlikely to pass if a
228 // future compiler/architecture breaks it.
229 *reinterpret_cast<PackedMaskType
*>(merged
) ^= packed_mask_key
;
232 MaskWebSocketFramePayloadByBytes(
234 (frame_offset
+ (aligned_end
- data
)) % kMaskingKeyLength
,