1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
31 const int kWindowBits
= 15;
32 const size_t kChunkSize
= 4 * 1024;
36 WebSocketDeflateStream::WebSocketDeflateStream(
37 scoped_ptr
<WebSocketStream
> stream
,
38 WebSocketDeflater::ContextTakeOverMode mode
,
39 int client_window_bits
,
40 scoped_ptr
<WebSocketDeflatePredictor
> predictor
)
41 : stream_(stream
.Pass()),
43 inflater_(kChunkSize
, kChunkSize
),
44 reading_state_(NOT_READING
),
45 writing_state_(NOT_WRITING
),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText
),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText
),
48 predictor_(predictor
.Pass()) {
50 DCHECK_GE(client_window_bits
, 8);
51 DCHECK_LE(client_window_bits
, 15);
52 deflater_
.Initialize(client_window_bits
);
53 inflater_
.Initialize(kWindowBits
);
56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
58 int WebSocketDeflateStream::ReadFrames(ScopedVector
<WebSocketFrame
>* frames
,
59 const CompletionCallback
& callback
) {
60 int result
= stream_
->ReadFrames(
62 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
63 base::Unretained(this),
64 base::Unretained(frames
),
68 DCHECK_EQ(OK
, result
);
69 DCHECK(!frames
->empty());
71 return InflateAndReadIfNecessary(frames
, callback
);
74 int WebSocketDeflateStream::WriteFrames(ScopedVector
<WebSocketFrame
>* frames
,
75 const CompletionCallback
& callback
) {
76 int result
= Deflate(frames
);
81 return stream_
->WriteFrames(frames
, callback
);
84 void WebSocketDeflateStream::Close() { stream_
->Close(); }
86 std::string
WebSocketDeflateStream::GetSubProtocol() const {
87 return stream_
->GetSubProtocol();
90 std::string
WebSocketDeflateStream::GetExtensions() const {
91 return stream_
->GetExtensions();
94 void WebSocketDeflateStream::OnReadComplete(
95 ScopedVector
<WebSocketFrame
>* frames
,
96 const CompletionCallback
& callback
,
100 callback
.Run(result
);
104 int r
= InflateAndReadIfNecessary(frames
, callback
);
105 if (r
!= ERR_IO_PENDING
)
109 int WebSocketDeflateStream::Deflate(ScopedVector
<WebSocketFrame
>* frames
) {
110 ScopedVector
<WebSocketFrame
> frames_to_write
;
111 // Store frames of the currently processed message if writing_state_ equals to
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113 ScopedVector
<WebSocketFrame
> frames_of_message
;
114 for (size_t i
= 0; i
< frames
->size(); ++i
) {
115 DCHECK(!(*frames
)[i
]->header
.reserved1
);
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames
)[i
]->header
.opcode
)) {
117 frames_to_write
.push_back((*frames
)[i
]);
121 if (writing_state_
== NOT_WRITING
)
122 OnMessageStart(*frames
, i
);
124 scoped_ptr
<WebSocketFrame
> frame((*frames
)[i
]);
126 predictor_
->RecordInputDataFrame(frame
.get());
128 if (writing_state_
== WRITING_UNCOMPRESSED_MESSAGE
) {
129 if (frame
->header
.final
)
130 writing_state_
= NOT_WRITING
;
131 predictor_
->RecordWrittenDataFrame(frame
.get());
132 frames_to_write
.push_back(frame
.Pass());
133 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
135 if (frame
->data
.get() &&
138 static_cast<size_t>(frame
->header
.payload_length
))) {
139 DVLOG(1) << "WebSocket protocol error. "
140 << "deflater_.AddBytes() returns an error.";
141 return ERR_WS_PROTOCOL_ERROR
;
143 if (frame
->header
.final
&& !deflater_
.Finish()) {
144 DVLOG(1) << "WebSocket protocol error. "
145 << "deflater_.Finish() returns an error.";
146 return ERR_WS_PROTOCOL_ERROR
;
149 if (writing_state_
== WRITING_COMPRESSED_MESSAGE
) {
150 if (deflater_
.CurrentOutputSize() >= kChunkSize
||
151 frame
->header
.final
) {
152 int result
= AppendCompressedFrame(frame
->header
, &frames_to_write
);
156 if (frame
->header
.final
)
157 writing_state_
= NOT_WRITING
;
159 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
160 bool final
= frame
->header
.final
;
161 frames_of_message
.push_back(frame
.Pass());
163 int result
= AppendPossiblyCompressedMessage(&frames_of_message
,
167 frames_of_message
.clear();
168 writing_state_
= NOT_WRITING
;
173 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
174 frames
->swap(frames_to_write
);
178 void WebSocketDeflateStream::OnMessageStart(
179 const ScopedVector
<WebSocketFrame
>& frames
, size_t index
) {
180 WebSocketFrame
* frame
= frames
[index
];
181 current_writing_opcode_
= frame
->header
.opcode
;
182 DCHECK(current_writing_opcode_
== WebSocketFrameHeader::kOpCodeText
||
183 current_writing_opcode_
== WebSocketFrameHeader::kOpCodeBinary
);
184 WebSocketDeflatePredictor::Result prediction
=
185 predictor_
->Predict(frames
, index
);
187 switch (prediction
) {
188 case WebSocketDeflatePredictor::DEFLATE
:
189 writing_state_
= WRITING_COMPRESSED_MESSAGE
;
191 case WebSocketDeflatePredictor::DO_NOT_DEFLATE
:
192 writing_state_
= WRITING_UNCOMPRESSED_MESSAGE
;
194 case WebSocketDeflatePredictor::TRY_DEFLATE
:
195 writing_state_
= WRITING_POSSIBLY_COMPRESSED_MESSAGE
;
201 int WebSocketDeflateStream::AppendCompressedFrame(
202 const WebSocketFrameHeader
& header
,
203 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
204 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
205 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
206 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
207 if (!compressed_payload
.get()) {
208 DVLOG(1) << "WebSocket protocol error. "
209 << "deflater_.GetOutput() returns an error.";
210 return ERR_WS_PROTOCOL_ERROR
;
212 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
213 compressed
->header
.CopyFrom(header
);
214 compressed
->header
.opcode
= opcode
;
215 compressed
->header
.final
= header
.final
;
216 compressed
->header
.reserved1
=
217 (opcode
!= WebSocketFrameHeader::kOpCodeContinuation
);
218 compressed
->data
= compressed_payload
;
219 compressed
->header
.payload_length
= compressed_payload
->size();
221 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
222 predictor_
->RecordWrittenDataFrame(compressed
.get());
223 frames_to_write
->push_back(compressed
.Pass());
227 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
228 ScopedVector
<WebSocketFrame
>* frames
,
229 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
230 DCHECK(!frames
->empty());
232 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
233 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
234 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
235 if (!compressed_payload
.get()) {
236 DVLOG(1) << "WebSocket protocol error. "
237 << "deflater_.GetOutput() returns an error.";
238 return ERR_WS_PROTOCOL_ERROR
;
241 uint64 original_payload_length
= 0;
242 for (size_t i
= 0; i
< frames
->size(); ++i
) {
243 WebSocketFrame
* frame
= (*frames
)[i
];
244 // Asserts checking that frames represent one whole data message.
245 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
));
247 WebSocketFrameHeader::kOpCodeContinuation
!=
248 frame
->header
.opcode
);
249 DCHECK_EQ(i
== frames
->size() - 1, frame
->header
.final
);
250 original_payload_length
+= frame
->header
.payload_length
;
252 if (original_payload_length
<=
253 static_cast<uint64
>(compressed_payload
->size())) {
254 // Compression is not effective. Use the original frames.
255 for (size_t i
= 0; i
< frames
->size(); ++i
) {
256 WebSocketFrame
* frame
= (*frames
)[i
];
257 frames_to_write
->push_back(frame
);
258 predictor_
->RecordWrittenDataFrame(frame
);
261 frames
->weak_clear();
264 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
265 compressed
->header
.CopyFrom((*frames
)[0]->header
);
266 compressed
->header
.opcode
= opcode
;
267 compressed
->header
.final
= true;
268 compressed
->header
.reserved1
= true;
269 compressed
->data
= compressed_payload
;
270 compressed
->header
.payload_length
= compressed_payload
->size();
272 predictor_
->RecordWrittenDataFrame(compressed
.get());
273 frames_to_write
->push_back(compressed
.Pass());
277 int WebSocketDeflateStream::Inflate(ScopedVector
<WebSocketFrame
>* frames
) {
278 ScopedVector
<WebSocketFrame
> frames_to_output
;
279 ScopedVector
<WebSocketFrame
> frames_passed
;
280 frames
->swap(frames_passed
);
281 for (size_t i
= 0; i
< frames_passed
.size(); ++i
) {
282 scoped_ptr
<WebSocketFrame
> frame(frames_passed
[i
]);
283 frames_passed
[i
] = NULL
;
284 DVLOG(3) << "Input frame: opcode=" << frame
->header
.opcode
285 << " final=" << frame
->header
.final
286 << " reserved1=" << frame
->header
.reserved1
287 << " payload_length=" << frame
->header
.payload_length
;
289 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
)) {
290 frames_to_output
.push_back(frame
.Pass());
294 if (reading_state_
== NOT_READING
) {
295 if (frame
->header
.reserved1
)
296 reading_state_
= READING_COMPRESSED_MESSAGE
;
298 reading_state_
= READING_UNCOMPRESSED_MESSAGE
;
299 current_reading_opcode_
= frame
->header
.opcode
;
301 if (frame
->header
.reserved1
) {
302 DVLOG(1) << "WebSocket protocol error. "
303 << "Receiving a non-first frame with RSV1 flag set.";
304 return ERR_WS_PROTOCOL_ERROR
;
308 if (reading_state_
== READING_UNCOMPRESSED_MESSAGE
) {
309 if (frame
->header
.final
)
310 reading_state_
= NOT_READING
;
311 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
312 frames_to_output
.push_back(frame
.Pass());
314 DCHECK_EQ(reading_state_
, READING_COMPRESSED_MESSAGE
);
315 if (frame
->data
.get() &&
318 static_cast<size_t>(frame
->header
.payload_length
))) {
319 DVLOG(1) << "WebSocket protocol error. "
320 << "inflater_.AddBytes() returns an error.";
321 return ERR_WS_PROTOCOL_ERROR
;
323 if (frame
->header
.final
) {
324 if (!inflater_
.Finish()) {
325 DVLOG(1) << "WebSocket protocol error. "
326 << "inflater_.Finish() returns an error.";
327 return ERR_WS_PROTOCOL_ERROR
;
330 // TODO(yhirano): Many frames can be generated by the inflater and
331 // memory consumption can grow.
332 // We could avoid it, but avoiding it makes this class much more
334 while (inflater_
.CurrentOutputSize() >= kChunkSize
||
335 frame
->header
.final
) {
336 size_t size
= std::min(kChunkSize
, inflater_
.CurrentOutputSize());
337 scoped_ptr
<WebSocketFrame
> inflated(
338 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText
));
339 scoped_refptr
<IOBufferWithSize
> data
= inflater_
.GetOutput(size
);
340 bool is_final
= !inflater_
.CurrentOutputSize() && frame
->header
.final
;
342 DVLOG(1) << "WebSocket protocol error. "
343 << "inflater_.GetOutput() returns an error.";
344 return ERR_WS_PROTOCOL_ERROR
;
346 inflated
->header
.CopyFrom(frame
->header
);
347 inflated
->header
.opcode
= current_reading_opcode_
;
348 inflated
->header
.final
= is_final
;
349 inflated
->header
.reserved1
= false;
350 inflated
->data
= data
;
351 inflated
->header
.payload_length
= data
->size();
352 DVLOG(3) << "Inflated frame: opcode=" << inflated
->header
.opcode
353 << " final=" << inflated
->header
.final
354 << " reserved1=" << inflated
->header
.reserved1
355 << " payload_length=" << inflated
->header
.payload_length
;
356 frames_to_output
.push_back(inflated
.Pass());
357 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
361 if (frame
->header
.final
)
362 reading_state_
= NOT_READING
;
365 frames
->swap(frames_to_output
);
366 return frames
->empty() ? ERR_IO_PENDING
: OK
;
369 int WebSocketDeflateStream::InflateAndReadIfNecessary(
370 ScopedVector
<WebSocketFrame
>* frames
,
371 const CompletionCallback
& callback
) {
372 int result
= Inflate(frames
);
373 while (result
== ERR_IO_PENDING
) {
374 DCHECK(frames
->empty());
376 result
= stream_
->ReadFrames(
378 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
379 base::Unretained(this),
380 base::Unretained(frames
),
384 DCHECK_EQ(OK
, result
);
385 DCHECK(!frames
->empty());
387 result
= Inflate(frames
);