1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
// Window-bits value handed to inflater_.Initialize() for incoming messages.
const int kWindowBits = 15;
// Byte granularity used when sizing the inflater's buffers and when slicing
// deflater/inflater output into outgoing/incoming frames (4 KiB).
const size_t kChunkSize = 4096;
36 WebSocketDeflateStream::WebSocketDeflateStream(
37 scoped_ptr
<WebSocketStream
> stream
,
38 WebSocketDeflater::ContextTakeOverMode mode
,
39 int client_window_bits
,
40 scoped_ptr
<WebSocketDeflatePredictor
> predictor
)
41 : stream_(stream
.Pass()),
43 inflater_(kChunkSize
, kChunkSize
),
44 reading_state_(NOT_READING
),
45 writing_state_(NOT_WRITING
),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText
),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText
),
48 predictor_(predictor
.Pass()) {
50 DCHECK_GE(client_window_bits
, 8);
51 DCHECK_LE(client_window_bits
, 15);
52 deflater_
.Initialize(client_window_bits
);
53 inflater_
.Initialize(kWindowBits
);
56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
58 int WebSocketDeflateStream::ReadFrames(ScopedVector
<WebSocketFrame
>* frames
,
59 const CompletionCallback
& callback
) {
60 int result
= stream_
->ReadFrames(
62 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
63 base::Unretained(this),
64 base::Unretained(frames
),
68 DCHECK_EQ(OK
, result
);
69 DCHECK(!frames
->empty());
71 return InflateAndReadIfNecessary(frames
, callback
);
74 int WebSocketDeflateStream::WriteFrames(ScopedVector
<WebSocketFrame
>* frames
,
75 const CompletionCallback
& callback
) {
76 int result
= Deflate(frames
);
81 return stream_
->WriteFrames(frames
, callback
);
84 void WebSocketDeflateStream::Close() { stream_
->Close(); }
86 std::string
WebSocketDeflateStream::GetSubProtocol() const {
87 return stream_
->GetSubProtocol();
90 std::string
WebSocketDeflateStream::GetExtensions() const {
91 return stream_
->GetExtensions();
94 void WebSocketDeflateStream::OnReadComplete(
95 ScopedVector
<WebSocketFrame
>* frames
,
96 const CompletionCallback
& callback
,
100 callback
.Run(result
);
104 int r
= InflateAndReadIfNecessary(frames
, callback
);
105 if (r
!= ERR_IO_PENDING
)
109 int WebSocketDeflateStream::Deflate(ScopedVector
<WebSocketFrame
>* frames
) {
110 ScopedVector
<WebSocketFrame
> frames_to_write
;
111 // Store frames of the currently processed message if writing_state_ equals to
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113 ScopedVector
<WebSocketFrame
> frames_of_message
;
114 for (size_t i
= 0; i
< frames
->size(); ++i
) {
115 DCHECK(!(*frames
)[i
]->header
.reserved1
);
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames
)[i
]->header
.opcode
)) {
117 frames_to_write
.push_back((*frames
)[i
]);
121 if (writing_state_
== NOT_WRITING
)
122 OnMessageStart(*frames
, i
);
124 scoped_ptr
<WebSocketFrame
> frame((*frames
)[i
]);
126 predictor_
->RecordInputDataFrame(frame
.get());
128 if (writing_state_
== WRITING_UNCOMPRESSED_MESSAGE
) {
129 if (frame
->header
.final
)
130 writing_state_
= NOT_WRITING
;
131 predictor_
->RecordWrittenDataFrame(frame
.get());
132 frames_to_write
.push_back(frame
.release());
133 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
135 if (frame
->data
&& !deflater_
.AddBytes(frame
->data
->data(),
136 frame
->header
.payload_length
)) {
137 DVLOG(1) << "WebSocket protocol error. "
138 << "deflater_.AddBytes() returns an error.";
139 return ERR_WS_PROTOCOL_ERROR
;
141 if (frame
->header
.final
&& !deflater_
.Finish()) {
142 DVLOG(1) << "WebSocket protocol error. "
143 << "deflater_.Finish() returns an error.";
144 return ERR_WS_PROTOCOL_ERROR
;
147 if (writing_state_
== WRITING_COMPRESSED_MESSAGE
) {
148 if (deflater_
.CurrentOutputSize() >= kChunkSize
||
149 frame
->header
.final
) {
150 int result
= AppendCompressedFrame(frame
->header
, &frames_to_write
);
154 if (frame
->header
.final
)
155 writing_state_
= NOT_WRITING
;
157 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
158 bool final
= frame
->header
.final
;
159 frames_of_message
.push_back(frame
.release());
161 int result
= AppendPossiblyCompressedMessage(&frames_of_message
,
165 frames_of_message
.clear();
166 writing_state_
= NOT_WRITING
;
171 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
172 frames
->swap(frames_to_write
);
176 void WebSocketDeflateStream::OnMessageStart(
177 const ScopedVector
<WebSocketFrame
>& frames
, size_t index
) {
178 WebSocketFrame
* frame
= frames
[index
];
179 current_writing_opcode_
= frame
->header
.opcode
;
180 DCHECK(current_writing_opcode_
== WebSocketFrameHeader::kOpCodeText
||
181 current_writing_opcode_
== WebSocketFrameHeader::kOpCodeBinary
);
182 WebSocketDeflatePredictor::Result prediction
=
183 predictor_
->Predict(frames
, index
);
185 switch (prediction
) {
186 case WebSocketDeflatePredictor::DEFLATE
:
187 writing_state_
= WRITING_COMPRESSED_MESSAGE
;
189 case WebSocketDeflatePredictor::DO_NOT_DEFLATE
:
190 writing_state_
= WRITING_UNCOMPRESSED_MESSAGE
;
192 case WebSocketDeflatePredictor::TRY_DEFLATE
:
193 writing_state_
= WRITING_POSSIBLY_COMPRESSED_MESSAGE
;
199 int WebSocketDeflateStream::AppendCompressedFrame(
200 const WebSocketFrameHeader
& header
,
201 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
202 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
203 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
204 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
205 if (!compressed_payload
) {
206 DVLOG(1) << "WebSocket protocol error. "
207 << "deflater_.GetOutput() returns an error.";
208 return ERR_WS_PROTOCOL_ERROR
;
210 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
211 compressed
->header
.CopyFrom(header
);
212 compressed
->header
.opcode
= opcode
;
213 compressed
->header
.final
= header
.final
;
214 compressed
->header
.reserved1
=
215 (opcode
!= WebSocketFrameHeader::kOpCodeContinuation
);
216 compressed
->data
= compressed_payload
;
217 compressed
->header
.payload_length
= compressed_payload
->size();
219 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
220 predictor_
->RecordWrittenDataFrame(compressed
.get());
221 frames_to_write
->push_back(compressed
.release());
225 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
226 ScopedVector
<WebSocketFrame
>* frames
,
227 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
228 DCHECK(!frames
->empty());
230 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
231 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
232 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
233 if (!compressed_payload
) {
234 DVLOG(1) << "WebSocket protocol error. "
235 << "deflater_.GetOutput() returns an error.";
236 return ERR_WS_PROTOCOL_ERROR
;
239 uint64 original_payload_length
= 0;
240 for (size_t i
= 0; i
< frames
->size(); ++i
) {
241 WebSocketFrame
* frame
= (*frames
)[i
];
242 // Asserts checking that frames represent one whole data message.
243 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
));
245 WebSocketFrameHeader::kOpCodeContinuation
!=
246 frame
->header
.opcode
);
247 DCHECK_EQ(i
== frames
->size() - 1, frame
->header
.final
);
248 original_payload_length
+= frame
->header
.payload_length
;
250 if (original_payload_length
<=
251 static_cast<uint64
>(compressed_payload
->size())) {
252 // Compression is not effective. Use the original frames.
253 for (size_t i
= 0; i
< frames
->size(); ++i
) {
254 WebSocketFrame
* frame
= (*frames
)[i
];
255 frames_to_write
->push_back(frame
);
256 predictor_
->RecordWrittenDataFrame(frame
);
259 frames
->weak_clear();
262 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
263 compressed
->header
.CopyFrom((*frames
)[0]->header
);
264 compressed
->header
.opcode
= opcode
;
265 compressed
->header
.final
= true;
266 compressed
->header
.reserved1
= true;
267 compressed
->data
= compressed_payload
;
268 compressed
->header
.payload_length
= compressed_payload
->size();
270 predictor_
->RecordWrittenDataFrame(compressed
.get());
271 frames_to_write
->push_back(compressed
.release());
275 int WebSocketDeflateStream::Inflate(ScopedVector
<WebSocketFrame
>* frames
) {
276 ScopedVector
<WebSocketFrame
> frames_to_output
;
277 ScopedVector
<WebSocketFrame
> frames_passed
;
278 frames
->swap(frames_passed
);
279 for (size_t i
= 0; i
< frames_passed
.size(); ++i
) {
280 scoped_ptr
<WebSocketFrame
> frame(frames_passed
[i
]);
281 frames_passed
[i
] = NULL
;
282 DVLOG(3) << "Input frame: opcode=" << frame
->header
.opcode
283 << " final=" << frame
->header
.final
284 << " reserved1=" << frame
->header
.reserved1
285 << " payload_length=" << frame
->header
.payload_length
;
287 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
)) {
288 frames_to_output
.push_back(frame
.release());
292 if (reading_state_
== NOT_READING
) {
293 if (frame
->header
.reserved1
)
294 reading_state_
= READING_COMPRESSED_MESSAGE
;
296 reading_state_
= READING_UNCOMPRESSED_MESSAGE
;
297 current_reading_opcode_
= frame
->header
.opcode
;
299 if (frame
->header
.reserved1
) {
300 DVLOG(1) << "WebSocket protocol error. "
301 << "Receiving a non-first frame with RSV1 flag set.";
302 return ERR_WS_PROTOCOL_ERROR
;
306 if (reading_state_
== READING_UNCOMPRESSED_MESSAGE
) {
307 if (frame
->header
.final
)
308 reading_state_
= NOT_READING
;
309 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
310 frames_to_output
.push_back(frame
.release());
312 DCHECK_EQ(reading_state_
, READING_COMPRESSED_MESSAGE
);
313 if (frame
->data
&& !inflater_
.AddBytes(frame
->data
->data(),
314 frame
->header
.payload_length
)) {
315 DVLOG(1) << "WebSocket protocol error. "
316 << "inflater_.AddBytes() returns an error.";
317 return ERR_WS_PROTOCOL_ERROR
;
319 if (frame
->header
.final
) {
320 if (!inflater_
.Finish()) {
321 DVLOG(1) << "WebSocket protocol error. "
322 << "inflater_.Finish() returns an error.";
323 return ERR_WS_PROTOCOL_ERROR
;
326 // TODO(yhirano): Many frames can be generated by the inflater and
327 // memory consumption can grow.
328 // We could avoid it, but avoiding it makes this class much more
330 while (inflater_
.CurrentOutputSize() >= kChunkSize
||
331 frame
->header
.final
) {
332 size_t size
= std::min(kChunkSize
, inflater_
.CurrentOutputSize());
333 scoped_ptr
<WebSocketFrame
> inflated(
334 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText
));
335 scoped_refptr
<IOBufferWithSize
> data
= inflater_
.GetOutput(size
);
336 bool is_final
= !inflater_
.CurrentOutputSize();
337 // |is_final| can't be true if |frame->header.final| is false.
338 DCHECK(!(is_final
&& !frame
->header
.final
));
340 DVLOG(1) << "WebSocket protocol error. "
341 << "inflater_.GetOutput() returns an error.";
342 return ERR_WS_PROTOCOL_ERROR
;
344 inflated
->header
.CopyFrom(frame
->header
);
345 inflated
->header
.opcode
= current_reading_opcode_
;
346 inflated
->header
.final
= is_final
;
347 inflated
->header
.reserved1
= false;
348 inflated
->data
= data
;
349 inflated
->header
.payload_length
= data
->size();
350 DVLOG(3) << "Inflated frame: opcode=" << inflated
->header
.opcode
351 << " final=" << inflated
->header
.final
352 << " reserved1=" << inflated
->header
.reserved1
353 << " payload_length=" << inflated
->header
.payload_length
;
354 frames_to_output
.push_back(inflated
.release());
355 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
359 if (frame
->header
.final
)
360 reading_state_
= NOT_READING
;
363 frames
->swap(frames_to_output
);
364 return frames
->empty() ? ERR_IO_PENDING
: OK
;
367 int WebSocketDeflateStream::InflateAndReadIfNecessary(
368 ScopedVector
<WebSocketFrame
>* frames
,
369 const CompletionCallback
& callback
) {
370 int result
= Inflate(frames
);
371 while (result
== ERR_IO_PENDING
) {
372 DCHECK(frames
->empty());
374 result
= stream_
->ReadFrames(
376 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
377 base::Unretained(this),
378 base::Unretained(frames
),
382 DCHECK_EQ(OK
, result
);
383 DCHECK(!frames
->empty());
385 result
= Inflate(frames
);