Add include.
[chromium-blink-merge.git] / net / websockets / websocket_deflate_stream.cc
blob6666bef0aea169029ca50f028988fc0069f83286
1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/websockets/websocket_deflate_stream.h"
7 #include <algorithm>
8 #include <string>
10 #include "base/bind.h"
11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
19 #include "net/websockets/websocket_deflater.h"
20 #include "net/websockets/websocket_errors.h"
21 #include "net/websockets/websocket_frame.h"
22 #include "net/websockets/websocket_inflater.h"
23 #include "net/websockets/websocket_stream.h"
25 class GURL;
27 namespace net {
namespace {

// Window size exponent passed to inflater_.Initialize(). 15 is the top of the
// 8..15 range the constructor accepts for the client side, presumably chosen
// so that data compressed with any allowed window can be decoded.
const int kWindowBits = 15;

// Passed as both constructor arguments of inflater_, and used as the amount
// of pending deflater/inflater output (in bytes) at which a new frame is cut.
const size_t kChunkSize = 4096;

}  // namespace
36 WebSocketDeflateStream::WebSocketDeflateStream(
37 scoped_ptr<WebSocketStream> stream,
38 WebSocketDeflater::ContextTakeOverMode mode,
39 int client_window_bits,
40 scoped_ptr<WebSocketDeflatePredictor> predictor)
41 : stream_(stream.Pass()),
42 deflater_(mode),
43 inflater_(kChunkSize, kChunkSize),
44 reading_state_(NOT_READING),
45 writing_state_(NOT_WRITING),
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
48 predictor_(predictor.Pass()) {
49 DCHECK(stream_);
50 DCHECK_GE(client_window_bits, 8);
51 DCHECK_LE(client_window_bits, 15);
52 deflater_.Initialize(client_window_bits);
53 inflater_.Initialize(kWindowBits);
56 WebSocketDeflateStream::~WebSocketDeflateStream() {}
58 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
59 const CompletionCallback& callback) {
60 int result = stream_->ReadFrames(
61 frames,
62 base::Bind(&WebSocketDeflateStream::OnReadComplete,
63 base::Unretained(this),
64 base::Unretained(frames),
65 callback));
66 if (result < 0)
67 return result;
68 DCHECK_EQ(OK, result);
69 DCHECK(!frames->empty());
71 return InflateAndReadIfNecessary(frames, callback);
74 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames,
75 const CompletionCallback& callback) {
76 int result = Deflate(frames);
77 if (result != OK)
78 return result;
79 if (frames->empty())
80 return OK;
81 return stream_->WriteFrames(frames, callback);
84 void WebSocketDeflateStream::Close() { stream_->Close(); }
86 std::string WebSocketDeflateStream::GetSubProtocol() const {
87 return stream_->GetSubProtocol();
90 std::string WebSocketDeflateStream::GetExtensions() const {
91 return stream_->GetExtensions();
94 void WebSocketDeflateStream::OnReadComplete(
95 ScopedVector<WebSocketFrame>* frames,
96 const CompletionCallback& callback,
97 int result) {
98 if (result != OK) {
99 frames->clear();
100 callback.Run(result);
101 return;
104 int r = InflateAndReadIfNecessary(frames, callback);
105 if (r != ERR_IO_PENDING)
106 callback.Run(r);
109 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
110 ScopedVector<WebSocketFrame> frames_to_write;
111 // Store frames of the currently processed message if writing_state_ equals to
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
113 ScopedVector<WebSocketFrame> frames_of_message;
114 for (size_t i = 0; i < frames->size(); ++i) {
115 DCHECK(!(*frames)[i]->header.reserved1);
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
117 frames_to_write.push_back((*frames)[i]);
118 (*frames)[i] = NULL;
119 continue;
121 if (writing_state_ == NOT_WRITING)
122 OnMessageStart(*frames, i);
124 scoped_ptr<WebSocketFrame> frame((*frames)[i]);
125 (*frames)[i] = NULL;
126 predictor_->RecordInputDataFrame(frame.get());
128 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
129 if (frame->header.final)
130 writing_state_ = NOT_WRITING;
131 predictor_->RecordWrittenDataFrame(frame.get());
132 frames_to_write.push_back(frame.release());
133 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
134 } else {
135 if (frame->data.get() &&
136 !deflater_.AddBytes(
137 frame->data->data(),
138 static_cast<size_t>(frame->header.payload_length))) {
139 DVLOG(1) << "WebSocket protocol error. "
140 << "deflater_.AddBytes() returns an error.";
141 return ERR_WS_PROTOCOL_ERROR;
143 if (frame->header.final && !deflater_.Finish()) {
144 DVLOG(1) << "WebSocket protocol error. "
145 << "deflater_.Finish() returns an error.";
146 return ERR_WS_PROTOCOL_ERROR;
149 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
150 if (deflater_.CurrentOutputSize() >= kChunkSize ||
151 frame->header.final) {
152 int result = AppendCompressedFrame(frame->header, &frames_to_write);
153 if (result != OK)
154 return result;
156 if (frame->header.final)
157 writing_state_ = NOT_WRITING;
158 } else {
159 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
160 bool final = frame->header.final;
161 frames_of_message.push_back(frame.release());
162 if (final) {
163 int result = AppendPossiblyCompressedMessage(&frames_of_message,
164 &frames_to_write);
165 if (result != OK)
166 return result;
167 frames_of_message.clear();
168 writing_state_ = NOT_WRITING;
173 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
174 frames->swap(frames_to_write);
175 return OK;
178 void WebSocketDeflateStream::OnMessageStart(
179 const ScopedVector<WebSocketFrame>& frames, size_t index) {
180 WebSocketFrame* frame = frames[index];
181 current_writing_opcode_ = frame->header.opcode;
182 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
183 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
184 WebSocketDeflatePredictor::Result prediction =
185 predictor_->Predict(frames, index);
187 switch (prediction) {
188 case WebSocketDeflatePredictor::DEFLATE:
189 writing_state_ = WRITING_COMPRESSED_MESSAGE;
190 return;
191 case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
192 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
193 return;
194 case WebSocketDeflatePredictor::TRY_DEFLATE:
195 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
196 return;
198 NOTREACHED();
201 int WebSocketDeflateStream::AppendCompressedFrame(
202 const WebSocketFrameHeader& header,
203 ScopedVector<WebSocketFrame>* frames_to_write) {
204 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
205 scoped_refptr<IOBufferWithSize> compressed_payload =
206 deflater_.GetOutput(deflater_.CurrentOutputSize());
207 if (!compressed_payload.get()) {
208 DVLOG(1) << "WebSocket protocol error. "
209 << "deflater_.GetOutput() returns an error.";
210 return ERR_WS_PROTOCOL_ERROR;
212 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
213 compressed->header.CopyFrom(header);
214 compressed->header.opcode = opcode;
215 compressed->header.final = header.final;
216 compressed->header.reserved1 =
217 (opcode != WebSocketFrameHeader::kOpCodeContinuation);
218 compressed->data = compressed_payload;
219 compressed->header.payload_length = compressed_payload->size();
221 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
222 predictor_->RecordWrittenDataFrame(compressed.get());
223 frames_to_write->push_back(compressed.release());
224 return OK;
227 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
228 ScopedVector<WebSocketFrame>* frames,
229 ScopedVector<WebSocketFrame>* frames_to_write) {
230 DCHECK(!frames->empty());
232 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
233 scoped_refptr<IOBufferWithSize> compressed_payload =
234 deflater_.GetOutput(deflater_.CurrentOutputSize());
235 if (!compressed_payload.get()) {
236 DVLOG(1) << "WebSocket protocol error. "
237 << "deflater_.GetOutput() returns an error.";
238 return ERR_WS_PROTOCOL_ERROR;
241 uint64 original_payload_length = 0;
242 for (size_t i = 0; i < frames->size(); ++i) {
243 WebSocketFrame* frame = (*frames)[i];
244 // Asserts checking that frames represent one whole data message.
245 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
246 DCHECK_EQ(i == 0,
247 WebSocketFrameHeader::kOpCodeContinuation !=
248 frame->header.opcode);
249 DCHECK_EQ(i == frames->size() - 1, frame->header.final);
250 original_payload_length += frame->header.payload_length;
252 if (original_payload_length <=
253 static_cast<uint64>(compressed_payload->size())) {
254 // Compression is not effective. Use the original frames.
255 for (size_t i = 0; i < frames->size(); ++i) {
256 WebSocketFrame* frame = (*frames)[i];
257 frames_to_write->push_back(frame);
258 predictor_->RecordWrittenDataFrame(frame);
259 (*frames)[i] = NULL;
261 frames->weak_clear();
262 return OK;
264 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
265 compressed->header.CopyFrom((*frames)[0]->header);
266 compressed->header.opcode = opcode;
267 compressed->header.final = true;
268 compressed->header.reserved1 = true;
269 compressed->data = compressed_payload;
270 compressed->header.payload_length = compressed_payload->size();
272 predictor_->RecordWrittenDataFrame(compressed.get());
273 frames_to_write->push_back(compressed.release());
274 return OK;
277 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
278 ScopedVector<WebSocketFrame> frames_to_output;
279 ScopedVector<WebSocketFrame> frames_passed;
280 frames->swap(frames_passed);
281 for (size_t i = 0; i < frames_passed.size(); ++i) {
282 scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
283 frames_passed[i] = NULL;
284 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode
285 << " final=" << frame->header.final
286 << " reserved1=" << frame->header.reserved1
287 << " payload_length=" << frame->header.payload_length;
289 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
290 frames_to_output.push_back(frame.release());
291 continue;
294 if (reading_state_ == NOT_READING) {
295 if (frame->header.reserved1)
296 reading_state_ = READING_COMPRESSED_MESSAGE;
297 else
298 reading_state_ = READING_UNCOMPRESSED_MESSAGE;
299 current_reading_opcode_ = frame->header.opcode;
300 } else {
301 if (frame->header.reserved1) {
302 DVLOG(1) << "WebSocket protocol error. "
303 << "Receiving a non-first frame with RSV1 flag set.";
304 return ERR_WS_PROTOCOL_ERROR;
308 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) {
309 if (frame->header.final)
310 reading_state_ = NOT_READING;
311 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
312 frames_to_output.push_back(frame.release());
313 } else {
314 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE);
315 if (frame->data.get() &&
316 !inflater_.AddBytes(
317 frame->data->data(),
318 static_cast<size_t>(frame->header.payload_length))) {
319 DVLOG(1) << "WebSocket protocol error. "
320 << "inflater_.AddBytes() returns an error.";
321 return ERR_WS_PROTOCOL_ERROR;
323 if (frame->header.final) {
324 if (!inflater_.Finish()) {
325 DVLOG(1) << "WebSocket protocol error. "
326 << "inflater_.Finish() returns an error.";
327 return ERR_WS_PROTOCOL_ERROR;
330 // TODO(yhirano): Many frames can be generated by the inflater and
331 // memory consumption can grow.
332 // We could avoid it, but avoiding it makes this class much more
333 // complicated.
334 while (inflater_.CurrentOutputSize() >= kChunkSize ||
335 frame->header.final) {
336 size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize());
337 scoped_ptr<WebSocketFrame> inflated(
338 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText));
339 scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size);
340 bool is_final = !inflater_.CurrentOutputSize() && frame->header.final;
341 if (!data.get()) {
342 DVLOG(1) << "WebSocket protocol error. "
343 << "inflater_.GetOutput() returns an error.";
344 return ERR_WS_PROTOCOL_ERROR;
346 inflated->header.CopyFrom(frame->header);
347 inflated->header.opcode = current_reading_opcode_;
348 inflated->header.final = is_final;
349 inflated->header.reserved1 = false;
350 inflated->data = data;
351 inflated->header.payload_length = data->size();
352 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode
353 << " final=" << inflated->header.final
354 << " reserved1=" << inflated->header.reserved1
355 << " payload_length=" << inflated->header.payload_length;
356 frames_to_output.push_back(inflated.release());
357 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
358 if (is_final)
359 break;
361 if (frame->header.final)
362 reading_state_ = NOT_READING;
365 frames->swap(frames_to_output);
366 return frames->empty() ? ERR_IO_PENDING : OK;
369 int WebSocketDeflateStream::InflateAndReadIfNecessary(
370 ScopedVector<WebSocketFrame>* frames,
371 const CompletionCallback& callback) {
372 int result = Inflate(frames);
373 while (result == ERR_IO_PENDING) {
374 DCHECK(frames->empty());
376 result = stream_->ReadFrames(
377 frames,
378 base::Bind(&WebSocketDeflateStream::OnReadComplete,
379 base::Unretained(this),
380 base::Unretained(frames),
381 callback));
382 if (result < 0)
383 break;
384 DCHECK_EQ(OK, result);
385 DCHECK(!frames->empty());
387 result = Inflate(frames);
389 if (result < 0)
390 frames->clear();
391 return result;
394 } // namespace net