1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
// Window-bits value handed to inflater_.Initialize(); unlike the deflater,
// it does not follow client_window_bits.
const int kWindowBits = 15;
// 4 KiB. Passed as both constructor arguments of inflater_, and used as the
// output threshold at which buffered compressed / inflated bytes are turned
// into frames.
const size_t kChunkSize = 4096;
// Wraps |stream|: outgoing data frames are (possibly) compressed with
// deflater_ and incoming data frames are decompressed with inflater_.
// |client_window_bits| (DCHECKed to lie in [8, 15]) configures deflater_,
// while inflater_ always uses kWindowBits. |predictor| decides, per outgoing
// message, whether to compress. Both directions start idle (NOT_READING /
// NOT_WRITING) with kOpCodeText as the current opcode.
// NOTE(review): no visible line uses |mode|; the initializer expected between
// stream_ and inflater_ (presumably deflater_(mode)) is missing from this
// copy — confirm deflater_ is constructed with |mode|.
36 WebSocketDeflateStream::WebSocketDeflateStream(
37 scoped_ptr
<WebSocketStream
> stream
,
38 WebSocketDeflater::ContextTakeOverMode mode
,
39 int client_window_bits
,
40 scoped_ptr
<WebSocketDeflatePredictor
> predictor
)
41 : stream_(stream
.Pass()),
43 inflater_(kChunkSize
, kChunkSize
),
44 reading_state_(NOT_READING
),
45 writing_state_(NOT_WRITING
),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText
),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText
),
48 predictor_(predictor
.Pass()) {
// Debug-only range check on the window-bits value used for the deflater.
50 DCHECK_GE(client_window_bits
, 8);
51 DCHECK_LE(client_window_bits
, 15);
52 deflater_
.Initialize(client_window_bits
);
// The inflater's window-bits value does not depend on client_window_bits.
53 inflater_
.Initialize(kWindowBits
);
56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
// Reads frames from the wrapped stream_ and runs them through
// InflateAndReadIfNecessary() before they reach the caller in |frames|.
// If stream_ completes the read asynchronously, OnReadComplete() (bound
// below) continues the work and reports through |callback|.
// NOTE(review): some lines of this function (the remaining ReadFrames()
// arguments and, presumably, an early return for a failed read) are missing
// from this copy.
58 int WebSocketDeflateStream::ReadFrames(ScopedVector
<WebSocketFrame
>* frames
,
59 const CompletionCallback
& callback
) {
// base::Unretained() takes no reference, so |this| and |frames| must stay
// alive until a pending read completes.
60 int result
= stream_
->ReadFrames(
62 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
63 base::Unretained(this),
64 base::Unretained(frames
),
// A synchronously completed read is expected to be OK and non-empty.
68 DCHECK_EQ(OK
, result
);
69 DCHECK(!frames
->empty());
71 return InflateAndReadIfNecessary(frames
, callback
);
// Rewrites (possibly compresses) the data frames in |frames| in place via
// Deflate(), then forwards them to the wrapped stream_ together with
// |callback|. Returns whatever stream_->WriteFrames() returns.
// NOTE(review): |result| is not read by any line visible here; the lines
// between Deflate() and the forwarding call (presumably an error check on
// |result|) are missing from this copy.
74 int WebSocketDeflateStream::WriteFrames(ScopedVector
<WebSocketFrame
>* frames
,
75 const CompletionCallback
& callback
) {
76 int result
= Deflate(frames
);
81 return stream_
->WriteFrames(frames
, callback
);
84 void WebSocketDeflateStream::Close() { stream_
->Close(); }
// Returns the sub-protocol reported by the wrapped stream_, unchanged.
86 std::string
WebSocketDeflateStream::GetSubProtocol() const {
87 return stream_
->GetSubProtocol();
// Returns the extensions string reported by the wrapped stream_, unchanged.
90 std::string
WebSocketDeflateStream::GetExtensions() const {
91 return stream_
->GetExtensions();
// Completion handler for asynchronous reads on stream_. ReadFrames() and
// InflateAndReadIfNecessary() bind it with |frames| and |callback|.
// Depending on a condition not visible here, it either reports |result|
// straight to |callback|, or runs InflateAndReadIfNecessary() and acts on a
// result other than ERR_IO_PENDING.
// NOTE(review): much of this function is missing from this copy: the last
// parameter (presumably |int result|), the condition guarding
// callback.Run(result), and the body of the final if statement.
94 void WebSocketDeflateStream::OnReadComplete(
95 ScopedVector
<WebSocketFrame
>* frames
,
96 const CompletionCallback
& callback
,
100 callback
.Run(result
);
104 int r
= InflateAndReadIfNecessary(frames
, callback
);
105 if (r
!= ERR_IO_PENDING
)
// Rewrites the outgoing frames in |*frames| for writing. The result replaces
// the vector's contents (swap at the end).
// - Frames whose opcode is not a known data opcode are forwarded as-is.
// - For data frames, OnMessageStart() picks writing_state_ when a message
//   starts, and predictor_ is told about every input data frame.
// - WRITING_UNCOMPRESSED_MESSAGE: the frame is forwarded unchanged.
// - Otherwise the payload is fed to deflater_ (Finish() on the final frame).
//   WRITING_COMPRESSED_MESSAGE emits compressed frames via
//   AppendCompressedFrame(). WRITING_POSSIBLY_COMPRESSED_MESSAGE buffers the
//   message's frames and hands them to AppendPossiblyCompressedMessage().
// Returns ERR_WS_PROTOCOL_ERROR when deflater_ reports an error.
// NOTE(review): several lines of this function are missing from this copy:
// continue/else branches, the checks on the Append*() results, the return on
// success, and closing braces. These comments follow the visible lines only.
109 int WebSocketDeflateStream::Deflate(ScopedVector
<WebSocketFrame
>* frames
) {
110 ScopedVector
<WebSocketFrame
> frames_to_write
;
111 // Store frames of the currently processed message if writing_state_ equals
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113 ScopedVector
<WebSocketFrame
> frames_of_message
;
114 for (size_t i
= 0; i
< frames
->size(); ++i
) {
// Outgoing frames must arrive with RSV1 clear; this class sets RSV1 itself.
115 DCHECK(!(*frames
)[i
]->header
.reserved1
);
// Frames that are not known data frames bypass compression.
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames
)[i
]->header
.opcode
)) {
117 frames_to_write
.push_back((*frames
)[i
]);
// A data frame seen while idle starts a new message; decide how to write it.
121 if (writing_state_
== NOT_WRITING
)
122 OnMessageStart(*frames
, i
);
// NOTE(review): (*frames)[i] still points at this frame after |frame| takes
// it. Presumably |frames| is weak_clear()ed before the final swap (not
// visible here); otherwise the swapped-out vector would delete it again.
124 scoped_ptr
<WebSocketFrame
> frame((*frames
)[i
]);
126 predictor_
->RecordInputDataFrame(frame
.get());
// Uncompressed message: forward the frame unchanged. Later frames of the
// message are tracked as continuations.
128 if (writing_state_
== WRITING_UNCOMPRESSED_MESSAGE
) {
129 if (frame
->header
.final
)
130 writing_state_
= NOT_WRITING
;
131 predictor_
->RecordWrittenDataFrame(frame
.get());
132 frames_to_write
.push_back(frame
.release());
133 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
// (Possibly) compressed message: feed the payload to the deflater, and
// finish the deflate stream on the message's final frame.
135 if (frame
->data
.get() &&
136 !deflater_
.AddBytes(frame
->data
->data(),
137 frame
->header
.payload_length
)) {
138 DVLOG(1) << "WebSocket protocol error. "
139 << "deflater_.AddBytes() returns an error.";
140 return ERR_WS_PROTOCOL_ERROR
;
142 if (frame
->header
.final
&& !deflater_
.Finish()) {
143 DVLOG(1) << "WebSocket protocol error. "
144 << "deflater_.Finish() returns an error.";
145 return ERR_WS_PROTOCOL_ERROR
;
// Compressed message: emit buffered output once it reaches kChunkSize bytes
// or the message ends.
148 if (writing_state_
== WRITING_COMPRESSED_MESSAGE
) {
149 if (deflater_
.CurrentOutputSize() >= kChunkSize
||
150 frame
->header
.final
) {
151 int result
= AppendCompressedFrame(frame
->header
, &frames_to_write
);
155 if (frame
->header
.final
)
156 writing_state_
= NOT_WRITING
;
// Possibly-compressed message: keep the original frames until the message is
// complete, so AppendPossiblyCompressedMessage() can pick the smaller form.
158 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
159 bool final
= frame
->header
.final
;
160 frames_of_message
.push_back(frame
.release());
162 int result
= AppendPossiblyCompressedMessage(&frames_of_message
,
166 frames_of_message
.clear();
167 writing_state_
= NOT_WRITING
;
// frames_of_message is local, so a possibly-compressed message must be
// completed within a single call.
172 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
// Hand the rewritten frames back to the caller through |frames|.
173 frames
->swap(frames_to_write
);
// Called when an outgoing data message starts at frames[index]. Records the
// frame's opcode in current_writing_opcode_ (DCHECKed to be Text or Binary),
// then sets writing_state_ from predictor_->Predict(frames, index):
//   DEFLATE        -> WRITING_COMPRESSED_MESSAGE
//   DO_NOT_DEFLATE -> WRITING_UNCOMPRESSED_MESSAGE
//   TRY_DEFLATE    -> WRITING_POSSIBLY_COMPRESSED_MESSAGE
// NOTE(review): the statement after each assignment (and the end of the
// switch) is missing from this copy. As shown, the cases would fall through
// to the last one — confirm each case ends with return/break.
177 void WebSocketDeflateStream::OnMessageStart(
178 const ScopedVector
<WebSocketFrame
>& frames
, size_t index
) {
179 WebSocketFrame
* frame
= frames
[index
];
180 current_writing_opcode_
= frame
->header
.opcode
;
181 DCHECK(current_writing_opcode_
== WebSocketFrameHeader::kOpCodeText
||
182 current_writing_opcode_
== WebSocketFrameHeader::kOpCodeBinary
);
183 WebSocketDeflatePredictor::Result prediction
=
184 predictor_
->Predict(frames
, index
);
186 switch (prediction
) {
187 case WebSocketDeflatePredictor::DEFLATE
:
188 writing_state_
= WRITING_COMPRESSED_MESSAGE
;
190 case WebSocketDeflatePredictor::DO_NOT_DEFLATE
:
191 writing_state_
= WRITING_UNCOMPRESSED_MESSAGE
;
193 case WebSocketDeflatePredictor::TRY_DEFLATE
:
194 writing_state_
= WRITING_POSSIBLY_COMPRESSED_MESSAGE
;
// Drains all compressed output currently buffered in deflater_ into one new
// frame, which is appended to |frames_to_write| (ownership transferred).
// The frame copies |header|, including |final|, but takes its opcode from
// current_writing_opcode_. RSV1 is set only when that opcode is not
// Continuation, i.e. on the first frame of the message; afterwards
// current_writing_opcode_ becomes Continuation. The frame is also reported
// to predictor_.
// Returns ERR_WS_PROTOCOL_ERROR if deflater_.GetOutput() fails.
// NOTE(review): the return on success and closing braces are missing from
// this copy.
200 int WebSocketDeflateStream::AppendCompressedFrame(
201 const WebSocketFrameHeader
& header
,
202 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
203 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
// Drain everything the deflater has buffered so far.
204 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
205 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
206 if (!compressed_payload
.get()) {
207 DVLOG(1) << "WebSocket protocol error. "
208 << "deflater_.GetOutput() returns an error.";
209 return ERR_WS_PROTOCOL_ERROR
;
211 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
212 compressed
->header
.CopyFrom(header
);
213 compressed
->header
.opcode
= opcode
;
214 compressed
->header
.final
= header
.final
;
// RSV1 marks only the first frame of a compressed message, i.e. the one whose
// opcode is not Continuation.
215 compressed
->header
.reserved1
=
216 (opcode
!= WebSocketFrameHeader::kOpCodeContinuation
);
217 compressed
->data
= compressed_payload
;
218 compressed
->header
.payload_length
= compressed_payload
->size();
// Any further frames of this message are continuation frames.
220 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
221 predictor_
->RecordWrittenDataFrame(compressed
.get());
222 frames_to_write
->push_back(compressed
.release());
// Finishes a message written in WRITING_POSSIBLY_COMPRESSED_MESSAGE mode.
// |*frames| must be non-empty and hold exactly one whole data message
// (DCHECKed); its payload has already been fed to deflater_.
// All buffered deflater output is fetched. If the original payloads total no
// more than the compressed size, the original frames are moved to
// |frames_to_write| unchanged. Otherwise a single final frame with RSV1 set,
// carrying the compressed payload and the opcode from
// current_writing_opcode_, is appended instead. Every appended frame is
// reported to predictor_.
// Returns ERR_WS_PROTOCOL_ERROR if deflater_.GetOutput() fails.
// NOTE(review): some lines are missing from this copy: the start of one
// DCHECK, the boundary between the two outcomes, and the return on success.
226 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
227 ScopedVector
<WebSocketFrame
>* frames
,
228 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
229 DCHECK(!frames
->empty());
231 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
232 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
233 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
234 if (!compressed_payload
.get()) {
235 DVLOG(1) << "WebSocket protocol error. "
236 << "deflater_.GetOutput() returns an error.";
237 return ERR_WS_PROTOCOL_ERROR
;
// Sum the original payload sizes to compare against the compressed size.
240 uint64 original_payload_length
= 0;
241 for (size_t i
= 0; i
< frames
->size(); ++i
) {
242 WebSocketFrame
* frame
= (*frames
)[i
];
243 // Asserts checking that frames represent one whole data message.
244 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
));
// NOTE(review): the opening of the next DCHECK is missing from this copy;
// the following tokens only make sense as its continuation.
246 WebSocketFrameHeader::kOpCodeContinuation
!=
247 frame
->header
.opcode
);
248 DCHECK_EQ(i
== frames
->size() - 1, frame
->header
.final
);
249 original_payload_length
+= frame
->header
.payload_length
;
// Ties keep the original frames.
251 if (original_payload_length
<=
252 static_cast<uint64
>(compressed_payload
->size())) {
253 // Compression is not effective. Use the original frames.
254 for (size_t i
= 0; i
< frames
->size(); ++i
) {
255 WebSocketFrame
* frame
= (*frames
)[i
];
256 frames_to_write
->push_back(frame
);
257 predictor_
->RecordWrittenDataFrame(frame
);
// The originals are now owned by |frames_to_write|; drop them from |frames|
// without deleting them.
260 frames
->weak_clear();
// Compression pays off: replace the whole message with one final frame that
// carries the compressed payload and RSV1. On this path |frames| keeps owning
// the originals.
263 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
264 compressed
->header
.CopyFrom((*frames
)[0]->header
);
265 compressed
->header
.opcode
= opcode
;
266 compressed
->header
.final
= true;
267 compressed
->header
.reserved1
= true;
268 compressed
->data
= compressed_payload
;
269 compressed
->header
.payload_length
= compressed_payload
->size();
271 predictor_
->RecordWrittenDataFrame(compressed
.get());
272 frames_to_write
->push_back(compressed
.release());
// Decompresses the incoming frames in |*frames|. The output replaces the
// vector's contents (swap at the end).
// - Frames whose opcode is not a known data opcode are forwarded as-is.
// - The first data frame of a message selects READING_COMPRESSED_MESSAGE if
//   RSV1 is set, or READING_UNCOMPRESSED_MESSAGE otherwise, and its opcode is
//   remembered. RSV1 on any later frame is a protocol error.
// - Frames of an uncompressed message are forwarded unchanged.
// - Compressed payloads are fed to inflater_ (Finish() on the final frame)
//   and re-emitted as uncompressed frames of at most kChunkSize bytes.
// Returns ERR_IO_PENDING when no output frame was produced, OK otherwise,
// and ERR_WS_PROTOCOL_ERROR on protocol or inflater errors.
// NOTE(review): several lines of this function are missing from this copy:
// continue/else branches, closing braces, and a loop exit. These comments
// follow the visible lines only.
276 int WebSocketDeflateStream::Inflate(ScopedVector
<WebSocketFrame
>* frames
) {
277 ScopedVector
<WebSocketFrame
> frames_to_output
;
278 ScopedVector
<WebSocketFrame
> frames_passed
;
// Take the input frames out of |*frames|. Output is collected in
// frames_to_output and swapped back at the end.
279 frames
->swap(frames_passed
);
280 for (size_t i
= 0; i
< frames_passed
.size(); ++i
) {
// Move ownership of each input frame into |frame|. The NULL assignment keeps
// the owning frames_passed from deleting it as well.
281 scoped_ptr
<WebSocketFrame
> frame(frames_passed
[i
]);
282 frames_passed
[i
] = NULL
;
283 DVLOG(3) << "Input frame: opcode=" << frame
->header
.opcode
284 << " final=" << frame
->header
.final
285 << " reserved1=" << frame
->header
.reserved1
286 << " payload_length=" << frame
->header
.payload_length
;
// Frames whose opcode is not a known data opcode are forwarded untouched.
288 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
)) {
289 frames_to_output
.push_back(frame
.release());
// First frame of a message: RSV1 marks a compressed message. Remember the
// message opcode for the first output frame.
293 if (reading_state_
== NOT_READING
) {
294 if (frame
->header
.reserved1
)
295 reading_state_
= READING_COMPRESSED_MESSAGE
;
297 reading_state_
= READING_UNCOMPRESSED_MESSAGE
;
298 current_reading_opcode_
= frame
->header
.opcode
;
// RSV1 is only allowed on the first frame of a message.
300 if (frame
->header
.reserved1
) {
301 DVLOG(1) << "WebSocket protocol error. "
302 << "Receiving a non-first frame with RSV1 flag set.";
303 return ERR_WS_PROTOCOL_ERROR
;
// Uncompressed message: forward frames as-is.
307 if (reading_state_
== READING_UNCOMPRESSED_MESSAGE
) {
308 if (frame
->header
.final
)
309 reading_state_
= NOT_READING
;
310 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
311 frames_to_output
.push_back(frame
.release());
// Compressed message: feed the payload to the inflater, and call Finish()
// on the final frame.
313 DCHECK_EQ(reading_state_
, READING_COMPRESSED_MESSAGE
);
314 if (frame
->data
.get() &&
315 !inflater_
.AddBytes(frame
->data
->data(),
316 frame
->header
.payload_length
)) {
317 DVLOG(1) << "WebSocket protocol error. "
318 << "inflater_.AddBytes() returns an error.";
319 return ERR_WS_PROTOCOL_ERROR
;
321 if (frame
->header
.final
) {
322 if (!inflater_
.Finish()) {
323 DVLOG(1) << "WebSocket protocol error. "
324 << "inflater_.Finish() returns an error.";
325 return ERR_WS_PROTOCOL_ERROR
;
328 // TODO(yhirano): Many frames can be generated by the inflater and
329 // memory consumption can grow.
330 // We could avoid it, but avoiding it makes this class much more
// Emit inflated output in frames of at most kChunkSize bytes whenever a full
// chunk is available, and drain everything once the final input frame has
// been consumed.
// NOTE(review): when the input frame is final, this loop only stops via lines
// not present in this copy (presumably a break after the final output frame).
332 while (inflater_
.CurrentOutputSize() >= kChunkSize
||
333 frame
->header
.final
) {
334 size_t size
= std::min(kChunkSize
, inflater_
.CurrentOutputSize());
335 scoped_ptr
<WebSocketFrame
> inflated(
336 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText
));
337 scoped_refptr
<IOBufferWithSize
> data
= inflater_
.GetOutput(size
);
// An output frame is final only if it drains the inflater and the input
// frame was final.
338 bool is_final
= !inflater_
.CurrentOutputSize() && frame
->header
.final
;
// NOTE(review): the condition guarding this error path (presumably a null
// check on |data|) is missing from this copy.
340 DVLOG(1) << "WebSocket protocol error. "
341 << "inflater_.GetOutput() returns an error.";
342 return ERR_WS_PROTOCOL_ERROR
;
344 inflated
->header
.CopyFrom(frame
->header
);
// Only the first output frame of a message carries the message opcode; the
// opcode becomes Continuation after each push below. RSV1 is cleared because
// the payload is no longer compressed.
345 inflated
->header
.opcode
= current_reading_opcode_
;
346 inflated
->header
.final
= is_final
;
347 inflated
->header
.reserved1
= false;
348 inflated
->data
= data
;
349 inflated
->header
.payload_length
= data
->size();
350 DVLOG(3) << "Inflated frame: opcode=" << inflated
->header
.opcode
351 << " final=" << inflated
->header
.final
352 << " reserved1=" << inflated
->header
.reserved1
353 << " payload_length=" << inflated
->header
.payload_length
;
354 frames_to_output
.push_back(inflated
.release());
355 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
// A final input frame ends the current message.
359 if (frame
->header
.final
)
360 reading_state_
= NOT_READING
;
// No output frames yet means the caller must read more input.
363 frames
->swap(frames_to_output
);
364 return frames
->empty() ? ERR_IO_PENDING
: OK
;
367 int WebSocketDeflateStream::InflateAndReadIfNecessary(
368 ScopedVector
<WebSocketFrame
>* frames
,
369 const CompletionCallback
& callback
) {
370 int result
= Inflate(frames
);
371 while (result
== ERR_IO_PENDING
) {
372 DCHECK(frames
->empty());
374 result
= stream_
->ReadFrames(
376 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
377 base::Unretained(this),
378 base::Unretained(frames
),
382 DCHECK_EQ(OK
, result
);
383 DCHECK(!frames
->empty());
385 result
= Inflate(frames
);