1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_parameters.h"
19 #include "net/websockets/websocket_deflate_predictor.h"
20 #include "net/websockets/websocket_deflater.h"
21 #include "net/websockets/websocket_errors.h"
22 #include "net/websockets/websocket_frame.h"
23 #include "net/websockets/websocket_inflater.h"
24 #include "net/websockets/websocket_stream.h"
// Window size, in bits, handed to inflater_.Initialize().
const int kWindowBits = 15;

// Chunk granularity in bytes (4 KiB). It is used as both inflater_
// constructor arguments, as the amount of buffered deflater output that
// triggers emitting a compressed frame, and as the maximum payload of each
// inflated frame.
const size_t kChunkSize = 4096;
37 WebSocketDeflateStream::WebSocketDeflateStream(
38 scoped_ptr
<WebSocketStream
> stream
,
39 const WebSocketDeflateParameters
& params
,
40 scoped_ptr
<WebSocketDeflatePredictor
> predictor
)
41 : stream_(stream
.Pass()),
42 deflater_(params
.client_context_take_over_mode()),
43 inflater_(kChunkSize
, kChunkSize
),
44 reading_state_(NOT_READING
),
45 writing_state_(NOT_WRITING
),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText
),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText
),
48 predictor_(predictor
.Pass()) {
50 DCHECK(params
.IsValidAsResponse());
51 int client_max_window_bits
= 15;
52 if (params
.is_client_max_window_bits_specified()) {
53 DCHECK(params
.has_client_max_window_bits_value());
54 client_max_window_bits
= params
.client_max_window_bits();
56 deflater_
.Initialize(client_max_window_bits
);
57 inflater_
.Initialize(kWindowBits
);
60 WebSocketDeflateStream::~WebSocketDeflateStream() {}
// Reads frames from the wrapped stream_ into |frames| and inflates any
// compressed messages among them.
//
// OnReadComplete is bound as the completion handler passed to
// stream_->ReadFrames(), with |this| and |frames| wrapped in
// base::Unretained. Unretained means the callback does not keep this object
// alive. That is presumably safe because stream_ is owned by this object and
// cannot outlive it — confirm.
// When the underlying read yields OK synchronously, the frames are passed to
// InflateAndReadIfNecessary(), whose result is returned.
//
// NOTE(review): the embedded numbering in this listing skips 65 and 69-71.
// Three things are therefore not visible here:
//   - the first argument of stream_->ReadFrames();
//   - the end of the base::Bind() call;
//   - any early return for a non-OK result (e.g. ERR_IO_PENDING).
// As listed, a pending read would reach DCHECK_EQ(OK, result). Verify
// against the original source.
62 int WebSocketDeflateStream::ReadFrames(ScopedVector
<WebSocketFrame
>* frames
,
63 const CompletionCallback
& callback
) {
64 int result
= stream_
->ReadFrames(
66 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
67 base::Unretained(this),
68 base::Unretained(frames
),
72 DCHECK_EQ(OK
, result
);
73 DCHECK(!frames
->empty());
75 return InflateAndReadIfNecessary(frames
, callback
);
// Rewrites |frames| in place via Deflate() (compressing data messages as
// decided per message), then hands the resulting frames to the wrapped
// stream_ together with |callback|.
78 int WebSocketDeflateStream::WriteFrames(ScopedVector
<WebSocketFrame
>* frames
,
79 const CompletionCallback
& callback
) {
80 int result
= Deflate(frames
);
// NOTE(review): the embedded numbering skips 81-84 here, so `result` is
// never examined in this listing, although Deflate() can return
// ERR_WS_PROTOCOL_ERROR. Presumably the missing lines return early on a
// non-OK result (and possibly when no frames remain) — confirm.
85 return stream_
->WriteFrames(frames
, callback
);
88 void WebSocketDeflateStream::Close() { stream_
->Close(); }
90 std::string
WebSocketDeflateStream::GetSubProtocol() const {
91 return stream_
->GetSubProtocol();
94 std::string
WebSocketDeflateStream::GetExtensions() const {
95 return stream_
->GetExtensions();
// Completion handler for an asynchronous stream_->ReadFrames(). It is bound
// by ReadFrames() and InflateAndReadIfNecessary(); |frames| and |callback|
// are the values bound there.
//
// There are two visible paths:
//   - The caller is notified directly via callback.Run(result).
//   - The newly read frames are inflated via InflateAndReadIfNecessary(),
//     and the code acts only when that call does not return ERR_IO_PENDING.
// A pending result presumably means another stream_->ReadFrames() was
// issued with this handler bound again.
//
// NOTE(review): the embedded numbering skips 101-103, 105-107 and 110-111.
// The following are therefore not visible here:
//   - the declaration of `result` (presumably a trailing int parameter
//     supplied by the stream);
//   - the condition guarding callback.Run(result);
//   - the statement controlled by `if (r != ERR_IO_PENDING)`.
// That last statement presumably runs |callback| with |r| — confirm.
98 void WebSocketDeflateStream::OnReadComplete(
99 ScopedVector
<WebSocketFrame
>* frames
,
100 const CompletionCallback
& callback
,
104 callback
.Run(result
);
108 int r
= InflateAndReadIfNecessary(frames
, callback
);
109 if (r
!= ERR_IO_PENDING
)
// Rewrites the outgoing |frames| in place according to a per-message
// compression decision.
//
// Incoming frames must not have RSV1 set (DCHECKed). Frames whose opcode is
// not a known data opcode are queued unchanged. For data frames, the first
// frame of each message (writing_state_ == NOT_WRITING) triggers
// OnMessageStart(), which picks the writing state:
//  - WRITING_UNCOMPRESSED_MESSAGE: frames are queued as-is.
//  - WRITING_COMPRESSED_MESSAGE: payloads go to deflater_, and a compressed
//    frame is emitted whenever at least kChunkSize bytes of output are
//    buffered or the message ends.
//  - WRITING_POSSIBLY_COMPRESSED_MESSAGE: payloads go to deflater_, and the
//    whole message is held in |frames_of_message| until its final frame.
//    AppendPossiblyCompressedMessage() then chooses what to queue.
// Every data frame is reported to predictor_ as input. At the end, |frames|
// is swapped with the queued frames. ERR_WS_PROTOCOL_ERROR is returned when
// the deflater reports an error.
//
// NOTE(review): the embedded numbering skips 122-124, 129, 138, 140-141,
// 146, 151-152, 157-159, 162, 166, 168-170, 173-176 and 179-180. Several
// statements are therefore not visible; see the notes below.
113 int WebSocketDeflateStream::Deflate(ScopedVector
<WebSocketFrame
>* frames
) {
114 ScopedVector
<WebSocketFrame
> frames_to_write
;
115 // Store frames of the currently processed message if writing_state_ equals to
116 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
117 ScopedVector
<WebSocketFrame
> frames_of_message
;
118 for (size_t i
= 0; i
< frames
->size(); ++i
) {
119 DCHECK(!(*frames
)[i
]->header
.reserved1
);
120 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames
)[i
]->header
.opcode
)) {
121 frames_to_write
.push_back((*frames
)[i
]);
// NOTE(review): lines 122-124 are missing.
// Inflate() nulls out a slot after taking ownership of its frame
// (`frames_passed[i] = NULL;`). Here, neither the analogous reset of
// (*frames)[i] nor a skip to the next frame is visible. Without the reset,
// the frame would presumably be owned by both |frames| and
// |frames_to_write|. Confirm against the original.
125 if (writing_state_
== NOT_WRITING
)
126 OnMessageStart(*frames
, i
);
128 scoped_ptr
<WebSocketFrame
> frame((*frames
)[i
]);
// NOTE(review): line 129 is missing. As above, the reset of (*frames)[i]
// after |frame| takes ownership is not visible here.
130 predictor_
->RecordInputDataFrame(frame
.get());
132 if (writing_state_
== WRITING_UNCOMPRESSED_MESSAGE
) {
133 if (frame
->header
.final
)
134 writing_state_
= NOT_WRITING
;
135 predictor_
->RecordWrittenDataFrame(frame
.get());
136 frames_to_write
.push_back(frame
.Pass());
137 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
// Compressed or possibly-compressed message; line 138, presumably the
// `else`, is missing. The payload is handed to the deflater here. The call
// itself (lines 140-141) is missing from this listing, but the error
// message below names deflater_.AddBytes(). The deflater is finished when
// the message's final frame is seen.
139 if (frame
->data
.get() &&
142 static_cast<size_t>(frame
->header
.payload_length
))) {
143 DVLOG(1) << "WebSocket protocol error. "
144 << "deflater_.AddBytes() returns an error.";
145 return ERR_WS_PROTOCOL_ERROR
;
147 if (frame
->header
.final
&& !deflater_
.Finish()) {
148 DVLOG(1) << "WebSocket protocol error. "
149 << "deflater_.Finish() returns an error.";
150 return ERR_WS_PROTOCOL_ERROR
;
// Emit a compressed frame once at least kChunkSize bytes of deflater output
// are buffered, or when the message ends. Lines 157-159 (presumably the
// handling of a non-OK `result`) are missing.
153 if (writing_state_
== WRITING_COMPRESSED_MESSAGE
) {
154 if (deflater_
.CurrentOutputSize() >= kChunkSize
||
155 frame
->header
.final
) {
156 int result
= AppendCompressedFrame(frame
->header
, &frames_to_write
);
160 if (frame
->header
.final
)
161 writing_state_
= NOT_WRITING
;
// Otherwise OnMessageStart() chose WRITING_POSSIBLY_COMPRESSED_MESSAGE
// (TRY_DEFLATE): hold the message's frames until its final frame.
// Lines 162, 166 and 168-170 are missing. They presumably contain the
// `else`, the `if (final)` guard, the remaining call arguments and the
// check of `result`.
163 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
164 bool final
= frame
->header
.final
;
165 frames_of_message
.push_back(frame
.Pass());
167 int result
= AppendPossiblyCompressedMessage(&frames_of_message
,
171 frames_of_message
.clear();
172 writing_state_
= NOT_WRITING
;
// |frames_of_message| is local to this call, so a possibly-compressed
// message must be completed within a single call.
177 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE
, writing_state_
);
178 frames
->swap(frames_to_write
);
// Called for the first data frame of an outgoing message, frames[index].
// Stores that frame's opcode, which must be text or binary (DCHECKed), as
// current_writing_opcode_. Then sets writing_state_ from
// predictor_->Predict(frames, index):
//   DEFLATE        -> WRITING_COMPRESSED_MESSAGE
//   DO_NOT_DEFLATE -> WRITING_UNCOMPRESSED_MESSAGE
//   TRY_DEFLATE    -> WRITING_POSSIBLY_COMPRESSED_MESSAGE
//
// NOTE(review): the embedded numbering skips 190, 194, 197 and 200-203, so
// no `break`/`return` is visible after any case. As listed, DEFLATE and
// DO_NOT_DEFLATE would fall through to TRY_DEFLATE. Each case presumably
// ends with a break or return — confirm against the original.
182 void WebSocketDeflateStream::OnMessageStart(
183 const ScopedVector
<WebSocketFrame
>& frames
, size_t index
) {
184 WebSocketFrame
* frame
= frames
[index
];
185 current_writing_opcode_
= frame
->header
.opcode
;
186 DCHECK(current_writing_opcode_
== WebSocketFrameHeader::kOpCodeText
||
187 current_writing_opcode_
== WebSocketFrameHeader::kOpCodeBinary
);
188 WebSocketDeflatePredictor::Result prediction
=
189 predictor_
->Predict(frames
, index
);
191 switch (prediction
) {
192 case WebSocketDeflatePredictor::DEFLATE
:
193 writing_state_
= WRITING_COMPRESSED_MESSAGE
;
195 case WebSocketDeflatePredictor::DO_NOT_DEFLATE
:
196 writing_state_
= WRITING_UNCOMPRESSED_MESSAGE
;
198 case WebSocketDeflatePredictor::TRY_DEFLATE
:
199 writing_state_
= WRITING_POSSIBLY_COMPRESSED_MESSAGE
;
205 int WebSocketDeflateStream::AppendCompressedFrame(
206 const WebSocketFrameHeader
& header
,
207 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
208 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
209 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
210 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
211 if (!compressed_payload
.get()) {
212 DVLOG(1) << "WebSocket protocol error. "
213 << "deflater_.GetOutput() returns an error.";
214 return ERR_WS_PROTOCOL_ERROR
;
216 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
217 compressed
->header
.CopyFrom(header
);
218 compressed
->header
.opcode
= opcode
;
219 compressed
->header
.final
= header
.final
;
220 compressed
->header
.reserved1
=
221 (opcode
!= WebSocketFrameHeader::kOpCodeContinuation
);
222 compressed
->data
= compressed_payload
;
223 compressed
->header
.payload_length
= compressed_payload
->size();
225 current_writing_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
226 predictor_
->RecordWrittenDataFrame(compressed
.get());
227 frames_to_write
->push_back(compressed
.Pass());
// Chooses what to write for a complete outgoing message held in |frames|,
// using the output currently buffered in deflater_.
//
// All deflater output is taken as one buffer. Every frame queued is reported
// to predictor_. There are two outcomes:
//   - If that buffer is not smaller than the sum of the original payload
//     lengths, the original frames are moved to |frames_to_write| unchanged.
//   - Otherwise a single final frame carries the compressed payload. It has
//     RSV1 set, uses current_writing_opcode_, and copies the first frame's
//     header.
// Returns ERR_WS_PROTOCOL_ERROR when deflater_.GetOutput() yields no buffer.
//
// NOTE(review): the embedded numbering skips 235, 243-244, 250, 255,
// 263-264, 266-267 and 278-280. Notably, nothing visible stops the
// "compression is not effective" branch from continuing into the code below
// that reads (*frames)[0] after frames->weak_clear(). Presumably the missing
// lines return there — confirm.
231 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
232 ScopedVector
<WebSocketFrame
>* frames
,
233 ScopedVector
<WebSocketFrame
>* frames_to_write
) {
234 DCHECK(!frames
->empty());
236 const WebSocketFrameHeader::OpCode opcode
= current_writing_opcode_
;
237 scoped_refptr
<IOBufferWithSize
> compressed_payload
=
238 deflater_
.GetOutput(deflater_
.CurrentOutputSize());
239 if (!compressed_payload
.get()) {
240 DVLOG(1) << "WebSocket protocol error. "
241 << "deflater_.GetOutput() returns an error.";
242 return ERR_WS_PROTOCOL_ERROR
;
245 uint64 original_payload_length
= 0;
246 for (size_t i
= 0; i
< frames
->size(); ++i
) {
247 WebSocketFrame
* frame
= (*frames
)[i
];
248 // Asserts checking that frames represent one whole data message.
249 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
));
// NOTE(review): line 250, the start of the next assertion, is missing. Its
// visible tail compares kOpCodeContinuation with the frame's opcode.
251 WebSocketFrameHeader::kOpCodeContinuation
!=
252 frame
->header
.opcode
);
253 DCHECK_EQ(i
== frames
->size() - 1, frame
->header
.final
);
254 original_payload_length
+= frame
->header
.payload_length
;
256 if (original_payload_length
<=
257 static_cast<uint64
>(compressed_payload
->size())) {
258 // Compression is not effective. Use the original frames.
259 for (size_t i
= 0; i
< frames
->size(); ++i
) {
260 WebSocketFrame
* frame
= (*frames
)[i
];
261 frames_to_write
->push_back(frame
);
262 predictor_
->RecordWrittenDataFrame(frame
);
// The frames are now held by |frames_to_write|. weak_clear() presumably
// empties |frames| without deleting them — confirm against ScopedVector.
265 frames
->weak_clear();
268 scoped_ptr
<WebSocketFrame
> compressed(new WebSocketFrame(opcode
));
269 compressed
->header
.CopyFrom((*frames
)[0]->header
);
270 compressed
->header
.opcode
= opcode
;
271 compressed
->header
.final
= true;
272 compressed
->header
.reserved1
= true;
273 compressed
->data
= compressed_payload
;
274 compressed
->header
.payload_length
= compressed_payload
->size();
276 predictor_
->RecordWrittenDataFrame(compressed
.get());
277 frames_to_write
->push_back(compressed
.Pass());
// Inflates the incoming |frames| in place.
//
// Each frame is taken out of |frames|: ownership moves to a scoped_ptr and
// the slot is nulled. Frames whose opcode is not a known data opcode are
// passed through. For data frames:
//   - The first frame of a message (reading_state_ == NOT_READING) decides,
//     from its RSV1 bit, whether the message is compressed. It also sets
//     current_reading_opcode_.
//   - RSV1 on a later frame is a protocol error.
//   - Frames of an uncompressed message pass through unchanged.
//   - Compressed payloads go to inflater_. Inflated output is emitted in
//     frames of at most kChunkSize bytes with RSV1 cleared. The final flag is
//     set only on the frame that drains the output of the message's final
//     frame.
//
// Returns ERR_IO_PENDING when no frame is ready to output
// (InflateAndReadIfNecessary() then reads again), ERR_WS_PROTOCOL_ERROR on
// RSV1 misuse or inflater errors, and OK otherwise.
//
// NOTE(review): the embedded numbering skips 292, 295-297, 301, 304,
// 309-311, 317, 320-321, 326, 332-333, 337, 345, 349 and 362-364. The notes
// below mark the gaps that change control flow.
281 int WebSocketDeflateStream::Inflate(ScopedVector
<WebSocketFrame
>* frames
) {
282 ScopedVector
<WebSocketFrame
> frames_to_output
;
283 ScopedVector
<WebSocketFrame
> frames_passed
;
284 frames
->swap(frames_passed
);
285 for (size_t i
= 0; i
< frames_passed
.size(); ++i
) {
286 scoped_ptr
<WebSocketFrame
> frame(frames_passed
[i
]);
287 frames_passed
[i
] = NULL
;
288 DVLOG(3) << "Input frame: opcode=" << frame
->header
.opcode
289 << " final=" << frame
->header
.final
290 << " reserved1=" << frame
->header
.reserved1
291 << " payload_length=" << frame
->header
.payload_length
;
293 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame
->header
.opcode
)) {
294 frames_to_output
.push_back(frame
.Pass());
// NOTE(review): lines 295-297 are missing. After passing a non-data frame
// through, a skip to the next frame is presumably among them — confirm.
298 if (reading_state_
== NOT_READING
) {
299 if (frame
->header
.reserved1
)
300 reading_state_
= READING_COMPRESSED_MESSAGE
;
// NOTE(review): line 301 (presumably `else`) is missing. As listed, the next
// assignment would unconditionally override READING_COMPRESSED_MESSAGE.
302 reading_state_
= READING_UNCOMPRESSED_MESSAGE
;
303 current_reading_opcode_
= frame
->header
.opcode
;
// Line 304 (presumably `} else {`) is missing. Per its error message, this
// check applies to non-first frames.
305 if (frame
->header
.reserved1
) {
306 DVLOG(1) << "WebSocket protocol error. "
307 << "Receiving a non-first frame with RSV1 flag set.";
308 return ERR_WS_PROTOCOL_ERROR
;
312 if (reading_state_
== READING_UNCOMPRESSED_MESSAGE
) {
313 if (frame
->header
.final
)
314 reading_state_
= NOT_READING
;
315 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
316 frames_to_output
.push_back(frame
.Pass());
// Line 317 (presumably `} else {`) is missing. For a compressed message,
// the payload goes to the inflater. The inflater_.AddBytes() call named in
// the error message falls in missing lines 320-321. The inflater is
// finished on the final frame.
318 DCHECK_EQ(reading_state_
, READING_COMPRESSED_MESSAGE
);
319 if (frame
->data
.get() &&
322 static_cast<size_t>(frame
->header
.payload_length
))) {
323 DVLOG(1) << "WebSocket protocol error. "
324 << "inflater_.AddBytes() returns an error.";
325 return ERR_WS_PROTOCOL_ERROR
;
327 if (frame
->header
.final
) {
328 if (!inflater_
.Finish()) {
329 DVLOG(1) << "WebSocket protocol error. "
330 << "inflater_.Finish() returns an error.";
331 return ERR_WS_PROTOCOL_ERROR
;
334 // TODO(yhirano): Many frames can be generated by the inflater and
335 // memory consumption can grow.
336 // We could avoid it, but avoiding it makes this class much more
// Emit inflated output in chunks of at most kChunkSize bytes. On the
// message's final frame, keep going until the output is drained.
338 while (inflater_
.CurrentOutputSize() >= kChunkSize
||
339 frame
->header
.final
) {
340 size_t size
= std::min(kChunkSize
, inflater_
.CurrentOutputSize());
341 scoped_ptr
<WebSocketFrame
> inflated(
342 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText
));
343 scoped_refptr
<IOBufferWithSize
> data
= inflater_
.GetOutput(size
);
344 bool is_final
= !inflater_
.CurrentOutputSize() && frame
->header
.final
;
// NOTE(review): line 345 is missing. It presumably guards this error return
// on data.get() being null, like the GetOutput() checks elsewhere in this
// file. As listed, the return is unconditional.
346 DVLOG(1) << "WebSocket protocol error. "
347 << "inflater_.GetOutput() returns an error.";
348 return ERR_WS_PROTOCOL_ERROR
;
350 inflated
->header
.CopyFrom(frame
->header
);
351 inflated
->header
.opcode
= current_reading_opcode_
;
352 inflated
->header
.final
= is_final
;
353 inflated
->header
.reserved1
= false;
354 inflated
->data
= data
;
355 inflated
->header
.payload_length
= data
->size();
356 DVLOG(3) << "Inflated frame: opcode=" << inflated
->header
.opcode
357 << " final=" << inflated
->header
.final
358 << " reserved1=" << inflated
->header
.reserved1
359 << " payload_length=" << inflated
->header
.payload_length
;
360 frames_to_output
.push_back(inflated
.Pass());
361 current_reading_opcode_
= WebSocketFrameHeader::kOpCodeContinuation
;
// NOTE(review): lines 362-364 are missing. The loop condition includes
// `|| frame->header.final`, so the loop must exit once is_final is set.
// Presumably those lines `break` on is_final — confirm.
365 if (frame
->header
.final
)
366 reading_state_
= NOT_READING
;
369 frames
->swap(frames_to_output
);
370 return frames
->empty() ? ERR_IO_PENDING
: OK
;
373 int WebSocketDeflateStream::InflateAndReadIfNecessary(
374 ScopedVector
<WebSocketFrame
>* frames
,
375 const CompletionCallback
& callback
) {
376 int result
= Inflate(frames
);
377 while (result
== ERR_IO_PENDING
) {
378 DCHECK(frames
->empty());
380 result
= stream_
->ReadFrames(
382 base::Bind(&WebSocketDeflateStream::OnReadComplete
,
383 base::Unretained(this),
384 base::Unretained(frames
),
388 DCHECK_EQ(OK
, result
);
389 DCHECK(!frames
->empty());
391 result
= Inflate(frames
);