// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
5 #include "net/websockets/websocket_inflater.h"
11 #include "base/logging.h"
12 #include "net/base/io_buffer.h"
13 #include "third_party/zlib/zlib.h"
// An IOBufferWithSize whose reported size can be reduced after allocation.
// GetOutput() allocates one of these at the requested size, then calls
// Shrink() with the number of bytes it actually copied into it.
19 class ShrinkableIOBufferWithSize
: public IOBufferWithSize
{
// Allocates a buffer of |size| bytes.
21 explicit ShrinkableIOBufferWithSize(int size
)
22 : IOBufferWithSize(size
) {}
// Reduces the reported size to |new_size|, which must not exceed size_.
// NOTE(review): only the DCHECK survives in this excerpt. The statement that
// actually updates size_ (presumably `size_ = new_size;`), the closing braces
// and the access specifiers are missing here and need to be restored;
// without the assignment GetOutput() would report the full requested size.
24 void Shrink(int new_size
) {
25 DCHECK_LE(new_size
, size_
);
// Instances are held through scoped_refptr (see GetOutput()).
30 virtual ~ShrinkableIOBufferWithSize() {}
35 WebSocketInflater::WebSocketInflater()
36 : input_queue_(kDefaultInputIOBufferCapacity
),
37 output_buffer_(kDefaultBufferCapacity
) {}
39 WebSocketInflater::WebSocketInflater(size_t input_queue_capacity
,
40 size_t output_buffer_capacity
)
41 : input_queue_(input_queue_capacity
),
42 output_buffer_(output_buffer_capacity
) {
43 DCHECK_GT(input_queue_capacity
, 0u);
44 DCHECK_GT(output_buffer_capacity
, 0u);
// Allocates and initialises the zlib inflate stream.
// |window_bits| is the base-2 log of the LZ77 window and must be in [8, 15].
// It is passed to inflateInit2() negated, which selects raw deflate input
// (no zlib header or trailer).
// NOTE(review): stream_ is reset() without inflateEnd() on any previous
// stream, so calling Initialize() twice would leak the old zlib internal
// state (assumes re-initialisation is possible — TODO confirm callers).
47 bool WebSocketInflater::Initialize(int window_bits
) {
48 DCHECK_LE(8, window_bits
);
49 DCHECK_GE(15, window_bits
);
50 stream_
.reset(new z_stream
);
// Zero-fill so zalloc/zfree/opaque are Z_NULL (zlib's default allocator)
// and next_in/avail_in are defined before inflateInit2().
51 memset(stream_
.get(), 0, sizeof(*stream_
));
52 int result
= inflateInit2(stream_
.get(), -window_bits
);
// NOTE(review): this excerpt appears truncated. The embedded numbering skips
// 53 and stops at 54, so the check on |result| that presumably guards this
// cleanup, the rest of the failure path and the success return are not
// visible. As written, inflateEnd() would run unconditionally and tear down
// the stream that was just initialised.
54 inflateEnd(stream_
.get());
61 WebSocketInflater::~WebSocketInflater() {
63 inflateEnd(stream_
.get());
// Feeds |size| bytes of compressed data to the inflater.
// If earlier input is still queued ("choked"), the new bytes are appended to
// input_queue_ behind it. Otherwise they are inflated immediately, and any
// bytes zlib did not consume are queued for a later InflateChokedInput().
// Returns true unless zlib reported an error. Z_BUF_ERROR only means no
// progress was possible, so it is treated as success.
68 bool WebSocketInflater::AddBytes(const char* data
, size_t size
) {
72 if (!input_queue_
.IsEmpty()) {
74 input_queue_
.Push(data
, size
);
// NOTE(review): the embedded numbering jumps 68->72 and 74->78. Any
// early-out before this branch and the end of this branch (presumably
// `return true;` and a closing brace) are not visible. Without that return,
// the same bytes would also be inflated below.
78 int result
= InflateWithFlush(data
, size
);
// Queue whatever zlib left unconsumed: the last avail_in bytes of |data|
// (presumably because output_buffer_ ran out of room).
79 if (stream_
->avail_in
> 0)
80 input_queue_
.Push(&data
[size
- stream_
->avail_in
], stream_
->avail_in
);
82 return result
== Z_OK
|| result
== Z_BUF_ERROR
;
85 bool WebSocketInflater::Finish() {
86 return AddBytes("\x00\x00\xff\xff", 4);
// Returns up to |size| bytes of decompressed data drained from
// output_buffer_. Each copy frees room in output_buffer_, so queued input
// is inflated again via InflateChokedInput() before the next iteration.
// The returned buffer is shrunk to the number of bytes actually copied.
89 scoped_refptr
<IOBufferWithSize
> WebSocketInflater::GetOutput(size_t size
) {
90 scoped_refptr
<ShrinkableIOBufferWithSize
> buffer
=
91 new ShrinkableIOBufferWithSize(size
);
92 size_t num_bytes_copied
= 0;
94 while (num_bytes_copied
< size
&& output_buffer_
.Size() > 0) {
95 size_t num_bytes_to_copy
=
96 std::min(output_buffer_
.Size(), size
- num_bytes_copied
);
97 output_buffer_
.Read(&buffer
->data()[num_bytes_copied
], num_bytes_to_copy
);
98 num_bytes_copied
+= num_bytes_to_copy
;
99 int result
= InflateChokedInput();
// NOTE(review): the statement governed by this `if` (embedded line 101,
// presumably an error return) and the loop's closing brace are missing from
// this excerpt. As written, the `if` would govern the Shrink() call below.
100 if (result
!= Z_OK
&& result
!= Z_BUF_ERROR
)
103 buffer
->Shrink(num_bytes_copied
);
// NOTE(review): the final `return buffer;` and closing brace are not
// visible here.
// Inflates |avail_in| bytes at |next_in| with Z_NO_FLUSH. If that produced
// no output at all, it retries the remaining input (stream_->next_in) with
// Z_SYNC_FLUSH so zlib writes out whatever it can. Returns a zlib status
// code. CurrentOutputSize() is defined elsewhere and presumably reports the
// bytes held in output_buffer_ — TODO confirm.
107 int WebSocketInflater::InflateWithFlush(const char* next_in
, size_t avail_in
) {
108 int result
= Inflate(next_in
, avail_in
, Z_NO_FLUSH
);
// NOTE(review): the bodies of this `if` and the next one (embedded lines
// 110 and 113, presumably `return result;`) are missing from this excerpt.
109 if (result
!= Z_OK
&& result
!= Z_BUF_ERROR
)
112 if (CurrentOutputSize() > 0)
114 // CurrentOutputSize() == 0 means nothing has been output yet, so retry
115 // with Z_SYNC_FLUSH to make sure any pending output gets flushed.
116 return Inflate(reinterpret_cast<const char*>(stream_
->next_in
),
// NOTE(review): the remaining Inflate() arguments (presumably
// stream_->avail_in and Z_SYNC_FLUSH) and the closing brace are not visible.
// Runs zlib inflate() over the |avail_in| bytes at |next_in| using |flush|.
// Output is written directly into the free tail region of output_buffer_.
// The loop repeats while zlib reports Z_OK or Z_BUF_ERROR, and the function
// returns the last zlib status. Input left unconsumed stays described by
// stream_->next_in / stream_->avail_in.
// NOTE(review): the rest of the parameter list (embedded lines 122-123;
// |avail_in| and |flush| are used below) is missing from this excerpt.
121 int WebSocketInflater::Inflate(const char* next_in
,
// zlib's next_in is a non-const Bytef*; inflate() only reads through it.
124 stream_
->next_in
= reinterpret_cast<Bytef
*>(const_cast<char*>(next_in
));
125 stream_
->avail_in
= avail_in
;
127 int result
= Z_BUF_ERROR
;
// NOTE(review): the `do {` matching the `while` at the end (embedded line
// 128) is missing. So is, presumably, a check that the tail region is
// non-empty (lines 130-132).
129 std::pair
<char*, size_t> tail
= output_buffer_
.GetTail();
133 stream_
->next_out
= reinterpret_cast<Bytef
*>(tail
.first
);
134 stream_
->avail_out
= tail
.second
;
135 result
= inflate(stream_
.get(), flush
);
// Commit the bytes inflate() wrote into output_buffer_.
136 output_buffer_
.AdvanceTail(tail
.second
- stream_
->avail_out
);
137 if (result
== Z_STREAM_END
) {
138 // Received a block with BFINAL set to 1. Reset the decompression state.
139 result
= inflateReset(stream_
.get());
// avail_out unchanged means inflate() wrote nothing on this pass.
140 } else if (tail
.second
== stream_
->avail_out
) {
// NOTE(review): the body of this else-if (embedded lines 141-142,
// presumably `break;` plus the closing brace) is not visible.
143 } while (result
== Z_OK
|| result
== Z_BUF_ERROR
);
// NOTE(review): the final `return result;` and closing brace are not visible.
// Re-inflates input that had to be queued earlier. With nothing queued it
// still calls InflateWithFlush() with no input, presumably so zlib can emit
// output it was holding back. Otherwise it inflates queued chunks front to
// back, drops the consumed bytes from the queue, and stops once zlib leaves
// input unconsumed. Returns the last zlib status.
147 int WebSocketInflater::InflateChokedInput() {
148 if (input_queue_
.IsEmpty())
149 return InflateWithFlush(NULL
, 0);
151 int result
= Z_BUF_ERROR
;
152 while (!input_queue_
.IsEmpty()) {
153 std::pair
<char*, size_t> top
= input_queue_
.Top();
155 result
= InflateWithFlush(top
.first
, top
.second
);
// Drop only what zlib actually read; avail_in is what it left behind.
156 input_queue_
.Consume(top
.second
- stream_
->avail_in
);
// NOTE(review): the body of this `if` (embedded lines 159-160, presumably
// `return result;`) is missing from this excerpt.
158 if (result
!= Z_OK
&& result
!= Z_BUF_ERROR
)
161 if (stream_
->avail_in
> 0) {
162 // Some of the data was not consumed.
// NOTE(review): the rest of this branch (presumably `break;`), the closing
// braces and the final `return result;` are not visible.
169 WebSocketInflater::OutputBuffer::OutputBuffer(size_t capacity
)
170 : capacity_(capacity
),
171 buffer_(capacity_
+ 1), // 1 for sentinel
175 WebSocketInflater::OutputBuffer::~OutputBuffer() {}
177 size_t WebSocketInflater::OutputBuffer::Size() const {
178 return (tail_
+ buffer_
.size() - head_
) % buffer_
.size();
181 std::pair
<char*, size_t> WebSocketInflater::OutputBuffer::GetTail() {
182 DCHECK_LT(tail_
, buffer_
.size());
183 return std::make_pair(&buffer_
[tail_
],
184 std::min(capacity_
- Size(), buffer_
.size() - tail_
));
// Copies |size| bytes (at most Size()) from the front of the ring into
// |dest| and advances head_. Stored data may wrap past the end of buffer_,
// so the copy is done in up to two contiguous pieces: first from head_ to
// the end of buffer_, then from the (wrapped) head_ up to tail_.
187 void WebSocketInflater::OutputBuffer::Read(char* dest
, size_t size
) {
188 DCHECK_LE(size
, Size());
190 size_t num_bytes_copied
= 0;
// NOTE(review): embedded lines 191 and 197 are missing. num_bytes_to_copy is
// declared again at 202, which implies this first copy sat in its own scope,
// presumably `if (tail_ < head_) { ... }` for the wrapped case.
192 size_t num_bytes_to_copy
= std::min(size
, buffer_
.size() - head_
);
193 DCHECK_LT(head_
, buffer_
.size());
194 memcpy(&dest
[num_bytes_copied
], &buffer_
[head_
], num_bytes_to_copy
);
195 AdvanceHead(num_bytes_to_copy
);
196 num_bytes_copied
+= num_bytes_to_copy
;
// NOTE(review): the body of this `if` (embedded line 200, presumably
// `return;`) is missing. As written, the DCHECK below would become its body.
199 if (num_bytes_copied
== size
)
// Second piece: the remaining data is contiguous from head_ up to tail_.
201 DCHECK_LE(head_
, tail_
);
202 size_t num_bytes_to_copy
= size
- num_bytes_copied
;
203 DCHECK_LE(num_bytes_to_copy
, tail_
- head_
);
204 DCHECK_LT(head_
, buffer_
.size());
205 memcpy(&dest
[num_bytes_copied
], &buffer_
[head_
], num_bytes_to_copy
);
206 AdvanceHead(num_bytes_to_copy
);
207 num_bytes_copied
+= num_bytes_to_copy
;
208 DCHECK_EQ(size
, num_bytes_copied
);
// NOTE(review): the closing brace is not visible here.
212 void WebSocketInflater::OutputBuffer::AdvanceHead(size_t advance
) {
213 DCHECK_LE(advance
, Size());
214 head_
= (head_
+ advance
) % buffer_
.size();
217 void WebSocketInflater::OutputBuffer::AdvanceTail(size_t advance
) {
218 DCHECK_LE(advance
+ Size(), capacity_
);
219 tail_
= (tail_
+ advance
) % buffer_
.size();
222 WebSocketInflater::InputQueue::InputQueue(size_t capacity
)
223 : capacity_(capacity
), head_of_first_buffer_(0), tail_of_last_buffer_(0) {}
225 WebSocketInflater::InputQueue::~InputQueue() {}
227 std::pair
<char*, size_t> WebSocketInflater::InputQueue::Top() {
229 if (buffers_
.size() == 1) {
230 return std::make_pair(&buffers_
.front()->data()[head_of_first_buffer_
],
231 tail_of_last_buffer_
- head_of_first_buffer_
);
233 return std::make_pair(&buffers_
.front()->data()[head_of_first_buffer_
],
234 capacity_
- head_of_first_buffer_
);
237 void WebSocketInflater::InputQueue::Push(const char* data
, size_t size
) {
241 size_t num_copied_bytes
= 0;
243 num_copied_bytes
+= PushToLastBuffer(data
, size
);
245 while (num_copied_bytes
< size
) {
246 DCHECK(IsEmpty() || tail_of_last_buffer_
== capacity_
);
248 buffers_
.push_back(new IOBufferWithSize(capacity_
));
249 tail_of_last_buffer_
= 0;
251 PushToLastBuffer(&data
[num_copied_bytes
], size
- num_copied_bytes
);
255 void WebSocketInflater::InputQueue::Consume(size_t size
) {
257 DCHECK_LE(size
+ head_of_first_buffer_
, capacity_
);
259 head_of_first_buffer_
+= size
;
260 if (head_of_first_buffer_
== capacity_
) {
261 buffers_
.pop_front();
262 head_of_first_buffer_
= 0;
264 if (buffers_
.size() == 1 && head_of_first_buffer_
== tail_of_last_buffer_
) {
265 buffers_
.pop_front();
266 head_of_first_buffer_
= 0;
267 tail_of_last_buffer_
= 0;
271 size_t WebSocketInflater::InputQueue::PushToLastBuffer(const char* data
,
274 size_t num_bytes_to_copy
= std::min(size
, capacity_
- tail_of_last_buffer_
);
275 if (!num_bytes_to_copy
)
277 IOBufferWithSize
* buffer
= buffers_
.back().get();
278 memcpy(&buffer
->data()[tail_of_last_buffer_
], data
, num_bytes_to_copy
);
279 tail_of_last_buffer_
+= num_bytes_to_copy
;
280 return num_bytes_to_copy
;