Probably broke Win7 Tests (dbg)(6). http://build.chromium.org/p/chromium.win/builders...
[chromium-blink-merge.git] / net / websockets / websocket_deflate_stream.cc
blob38de5fa2ecab770882d3594a2ba95f7de5810e0c
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
7 #include <algorithm>
8 #include <string>
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
25 class GURL;
27 namespace net {
namespace {

// Window-bits value handed to |inflater_| in the constructor. Incoming data
// always uses this value, independently of the client window bits used for
// the deflater.
const int kWindowBits = 15;

// Byte granularity used by this file: it is passed as both WebSocketInflater
// constructor arguments (presumably buffer capacities -- confirm in
// websocket_inflater.h), it is the buffered-output threshold at which Deflate()
// flushes a compressed frame, and it is the maximum payload of each frame
// emitted by Inflate().
const size_t kChunkSize = 4096;

}  // namespace
36 WebSocketDeflateStream::WebSocketDeflateStream(
37 scoped_ptr<WebSocketStream> stream,
38 WebSocketDeflater::ContextTakeOverMode mode,
39 int client_window_bits,
40 scoped_ptr<WebSocketDeflatePredictor> predictor)
41 : stream_(stream.Pass()),
42 deflater_(mode),
43 inflater_(kChunkSize, kChunkSize),
44 reading_state_(NOT_READING),
45 writing_state_(NOT_WRITING),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
48 predictor_(predictor.Pass()) {
49 DCHECK(stream_);
50 DCHECK_GE(client_window_bits, 8);
51 DCHECK_LE(client_window_bits, 15);
52 deflater_.Initialize(client_window_bits);
53 inflater_.Initialize(kWindowBits);
56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
58 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
59 const CompletionCallback& callback) {
60 int result = stream_->ReadFrames(
61 frames,
62 base::Bind(&WebSocketDeflateStream::OnReadComplete,
63 base::Unretained(this),
64 base::Unretained(frames),
65 callback));
66 if (result < 0)
67 return result;
68 DCHECK_EQ(OK, result);
69 DCHECK(!frames->empty());
71 return InflateAndReadIfNecessary(frames, callback);
74 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
75 const CompletionCallback& callback) {
76 int result = Deflate(frames);
77 if (result != OK)
78 return result;
79 if (frames->empty())
80 return OK;
81 return stream_->WriteFrames(frames, callback);
84 void WebSocketDeflateStream::Close() { stream_->Close(); }
86 std::string WebSocketDeflateStream::GetSubProtocol() const {
87 return stream_->GetSubProtocol();
90 std::string WebSocketDeflateStream::GetExtensions() const {
91 return stream_->GetExtensions();
94 void WebSocketDeflateStream::OnReadComplete(
95 ScopedVector<WebSocketFrame>* frames,
96 const CompletionCallback& callback,
97 int result) {
98 if (result != OK) {
99 frames->clear();
100 callback.Run(result);
101 return;
104 int r = InflateAndReadIfNecessary(frames, callback);
105 if (r != ERR_IO_PENDING)
106 callback.Run(r);
109 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
110 ScopedVector<WebSocketFrame> frames_to_write;
111 // Store frames of the currently processed message if writing_state_ equals to
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113 ScopedVector<WebSocketFrame> frames_of_message;
114 for (size_t i = 0; i < frames->size(); ++i) {
115 DCHECK(!(*frames)[i]->header.reserved1);
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
117 frames_to_write.push_back((*frames)[i]);
118 (*frames)[i] = NULL;
119 continue;
121 if (writing_state_ == NOT_WRITING)
122 OnMessageStart(*frames, i);
124 scoped_ptr<WebSocketFrame> frame((*frames)[i]);
125 (*frames)[i] = NULL;
126 predictor_->RecordInputDataFrame(frame.get());
128 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
129 if (frame->header.final)
130 writing_state_ = NOT_WRITING;
131 predictor_->RecordWrittenDataFrame(frame.get());
132 frames_to_write.push_back(frame.release());
133 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
134 } else {
135 if (frame->data && !deflater_.AddBytes(frame->data->data(),
136 frame->header.payload_length)) {
137 DVLOG(1) << "WebSocket protocol error. "
138 << "deflater_.AddBytes() returns an error.";
139 return ERR_WS_PROTOCOL_ERROR;
141 if (frame->header.final && !deflater_.Finish()) {
142 DVLOG(1) << "WebSocket protocol error. "
143 << "deflater_.Finish() returns an error.";
144 return ERR_WS_PROTOCOL_ERROR;
147 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
148 if (deflater_.CurrentOutputSize() >= kChunkSize ||
149 frame->header.final) {
150 int result = AppendCompressedFrame(frame->header, &frames_to_write);
151 if (result != OK)
152 return result;
154 if (frame->header.final)
155 writing_state_ = NOT_WRITING;
156 } else {
157 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
158 bool final = frame->header.final;
159 frames_of_message.push_back(frame.release());
160 if (final) {
161 int result = AppendPossiblyCompressedMessage(&frames_of_message,
162 &frames_to_write);
163 if (result != OK)
164 return result;
165 frames_of_message.clear();
166 writing_state_ = NOT_WRITING;
171 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
172 frames->swap(frames_to_write);
173 return OK;
176 void WebSocketDeflateStream::OnMessageStart(
177 const ScopedVector<WebSocketFrame>& frames, size_t index) {
178 WebSocketFrame* frame = frames[index];
179 current_writing_opcode_ = frame->header.opcode;
180 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
181 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
182 WebSocketDeflatePredictor::Result prediction =
183 predictor_->Predict(frames, index);
185 switch (prediction) {
186 case WebSocketDeflatePredictor::DEFLATE:
187 writing_state_ = WRITING_COMPRESSED_MESSAGE;
188 return;
189 case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
190 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
191 return;
192 case WebSocketDeflatePredictor::TRY_DEFLATE:
193 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
194 return;
196 NOTREACHED();
199 int WebSocketDeflateStream::AppendCompressedFrame(
200 const WebSocketFrameHeader& header,
201 ScopedVector<WebSocketFrame>* frames_to_write) {
202 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
203 scoped_refptr<IOBufferWithSize> compressed_payload =
204 deflater_.GetOutput(deflater_.CurrentOutputSize());
205 if (!compressed_payload) {
206 DVLOG(1) << "WebSocket protocol error. "
207 << "deflater_.GetOutput() returns an error.";
208 return ERR_WS_PROTOCOL_ERROR;
210 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
211 compressed->header.CopyFrom(header);
212 compressed->header.opcode = opcode;
213 compressed->header.final = header.final;
214 compressed->header.reserved1 =
215 (opcode != WebSocketFrameHeader::kOpCodeContinuation);
216 compressed->data = compressed_payload;
217 compressed->header.payload_length = compressed_payload->size();
219 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
220 predictor_->RecordWrittenDataFrame(compressed.get());
221 frames_to_write->push_back(compressed.release());
222 return OK;
225 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
226 ScopedVector<WebSocketFrame>* frames,
227 ScopedVector<WebSocketFrame>* frames_to_write) {
228 DCHECK(!frames->empty());
230 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
231 scoped_refptr<IOBufferWithSize> compressed_payload =
232 deflater_.GetOutput(deflater_.CurrentOutputSize());
233 if (!compressed_payload) {
234 DVLOG(1) << "WebSocket protocol error. "
235 << "deflater_.GetOutput() returns an error.";
236 return ERR_WS_PROTOCOL_ERROR;
239 uint64 original_payload_length = 0;
240 for (size_t i = 0; i < frames->size(); ++i) {
241 WebSocketFrame* frame = (*frames)[i];
242 // Asserts checking that frames represent one whole data message.
243 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
244 DCHECK_EQ(i == 0,
245 WebSocketFrameHeader::kOpCodeContinuation !=
246 frame->header.opcode);
247 DCHECK_EQ(i == frames->size() - 1, frame->header.final);
248 original_payload_length += frame->header.payload_length;
250 if (original_payload_length <=
251 static_cast<uint64>(compressed_payload->size())) {
252 // Compression is not effective. Use the original frames.
253 for (size_t i = 0; i < frames->size(); ++i) {
254 WebSocketFrame* frame = (*frames)[i];
255 frames_to_write->push_back(frame);
256 predictor_->RecordWrittenDataFrame(frame);
257 (*frames)[i] = NULL;
259 frames->weak_clear();
260 return OK;
262 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
263 compressed->header.CopyFrom((*frames)[0]->header);
264 compressed->header.opcode = opcode;
265 compressed->header.final = true;
266 compressed->header.reserved1 = true;
267 compressed->data = compressed_payload;
268 compressed->header.payload_length = compressed_payload->size();
270 predictor_->RecordWrittenDataFrame(compressed.get());
271 frames_to_write->push_back(compressed.release());
272 return OK;
275 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
276 ScopedVector<WebSocketFrame> frames_to_output;
277 ScopedVector<WebSocketFrame> frames_passed;
278 frames->swap(frames_passed);
279 for (size_t i = 0; i < frames_passed.size(); ++i) {
280 scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
281 frames_passed[i] = NULL;
282 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
283 << " final=" << frame->header.final
284 << " reserved1=" << frame->header.reserved1
285 << " payload_length=" << frame->header.payload_length;
287 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
288 frames_to_output.push_back(frame.release());
289 continue;
292 if (reading_state_ == NOT_READING) {
293 if (frame->header.reserved1)
294 reading_state_ = READING_COMPRESSED_MESSAGE;
295 else
296 reading_state_ = READING_UNCOMPRESSED_MESSAGE;
297 current_reading_opcode_ = frame->header.opcode;
298 } else {
299 if (frame->header.reserved1) {
300 DVLOG(1) << "WebSocket protocol error. "
301 << "Receiving a non-first frame with RSV1 flag set.";
302 return ERR_WS_PROTOCOL_ERROR;
306 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
307 if (frame->header.final)
308 reading_state_ = NOT_READING;
309 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
310 frames_to_output.push_back(frame.release());
311 } else {
312 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
313 if (frame->data && !inflater_.AddBytes(frame->data->data(),
314 frame->header.payload_length)) {
315 DVLOG(1) << "WebSocket protocol error. "
316 << "inflater_.AddBytes() returns an error.";
317 return ERR_WS_PROTOCOL_ERROR;
319 if (frame->header.final) {
320 if (!inflater_.Finish()) {
321 DVLOG(1) << "WebSocket protocol error. "
322 << "inflater_.Finish() returns an error.";
323 return ERR_WS_PROTOCOL_ERROR;
326 // TODO(yhirano): Many frames can be generated by the inflater and
327 // memory consumption can grow.
328 // We could avoid it, but avoiding it makes this class much more
329 // complicated.
330 while (inflater_.CurrentOutputSize() >= kChunkSize ||
331 frame->header.final) {
332 size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
333 scoped_ptr<WebSocketFrame> inflated(
334 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
335 scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
336 bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
337 if (!data) {
338 DVLOG(1) << "WebSocket protocol error. "
339 << "inflater_.GetOutput() returns an error.";
340 return ERR_WS_PROTOCOL_ERROR;
342 inflated->header.CopyFrom(frame->header);
343 inflated->header.opcode = current_reading_opcode_;
344 inflated->header.final = is_final;
345 inflated->header.reserved1 = false;
346 inflated->data = data;
347 inflated->header.payload_length = data->size();
348 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
349 << " final=" << inflated->header.final
350 << " reserved1=" << inflated->header.reserved1
351 << " payload_length=" << inflated->header.payload_length;
352 frames_to_output.push_back(inflated.release());
353 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
354 if (is_final)
355 break;
357 if (frame->header.final)
358 reading_state_ = NOT_READING;
361 frames->swap(frames_to_output);
362 return frames->empty() ? ERR_IO_PENDING : OK;
365 int WebSocketDeflateStream::InflateAndReadIfNecessary(
366 ScopedVector<WebSocketFrame>* frames,
367 const CompletionCallback& callback) {
368 int result = Inflate(frames);
369 while (result == ERR_IO_PENDING) {
370 DCHECK(frames->empty());
372 result = stream_->ReadFrames(
373 frames,
374 base::Bind(&WebSocketDeflateStream::OnReadComplete,
375 base::Unretained(this),
376 base::Unretained(frames),
377 callback));
378 if (result < 0)
379 break;
380 DCHECK_EQ(OK, result);
381 DCHECK(!frames->empty());
383 result = Inflate(frames);
385 if (result < 0)
386 frames->clear();
387 return result;
390 } // namespace net