vfs: check userland buffers before reading them.
[haiku.git] / src / add-ons / kernel / network / protocols / tcp / BufferQueue.cpp
blobde83ba3b312256e7d6a76854e6ac5e88216426b5
1 /*
2 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
5 * Authors:
6 * Axel Dörfler, axeld@pinc-software.de
7 */
10 #include "BufferQueue.h"
12 #include <KernelExport.h>
// Uncomment to trace every queue operation via dprintf().
//#define TRACE_BUFFER_QUEUE
#ifdef TRACE_BUFFER_QUEUE
#	define TRACE(x) dprintf x
#else
#	define TRACE(x)
#endif

// VERIFY() runs the (expensive) Verify() sanity check in debug builds only.
// The expansion deliberately carries no trailing semicolon, so that
// "VERIFY();" is exactly one statement in either configuration and can be
// used safely as the body of an unbraced if/else.
#if DEBUG_BUFFER_QUEUE
#	define VERIFY() Verify()
#else
#	define VERIFY() do { } while (false)
#endif
29 BufferQueue::BufferQueue(size_t maxBytes)
31 fMaxBytes(maxBytes),
32 fNumBytes(0),
33 fContiguousBytes(0),
34 fFirstSequence(0),
35 fLastSequence(0),
36 fPushPointer(0)
41 BufferQueue::~BufferQueue()
43 // free up any buffers left in the queue
45 net_buffer *buffer;
46 while ((buffer = fList.RemoveHead()) != NULL) {
47 gBufferModule->free(buffer);
52 void
53 BufferQueue::SetMaxBytes(size_t maxBytes)
55 fMaxBytes = maxBytes;
59 void
60 BufferQueue::SetInitialSequence(tcp_sequence sequence)
62 TRACE(("BufferQueue@%p::SetInitialSequence(%lu)\n", this,
63 sequence.Number()));
65 fFirstSequence = fLastSequence = sequence;
70 void
71 BufferQueue::Add(net_buffer *buffer)
73 Add(buffer, fLastSequence);
77 void
78 BufferQueue::Add(net_buffer *buffer, tcp_sequence sequence)
80 TRACE(("BufferQueue@%p::Add(buffer %p, size %lu, sequence %lu)\n",
81 this, buffer, buffer->size, sequence.Number()));
82 TRACE((" in: first: %lu, last: %lu, num: %lu, cont: %lu\n",
83 fFirstSequence.Number(), fLastSequence.Number(), fNumBytes,
84 fContiguousBytes));
85 VERIFY();
87 if (tcp_sequence(sequence + buffer->size) <= fFirstSequence
88 || buffer->size == 0) {
89 // This buffer does not contain any data of interest
90 gBufferModule->free(buffer);
91 return;
93 if (sequence < fFirstSequence) {
94 // this buffer contains data that is already long gone - trim it
95 gBufferModule->remove_header(buffer,
96 (fFirstSequence - sequence).Number());
97 sequence = fFirstSequence;
100 if (fList.IsEmpty() || sequence >= fLastSequence) {
101 // we usually just add the buffer to the end of the queue
102 fList.Add(buffer);
103 buffer->sequence = sequence.Number();
105 if (sequence == fLastSequence
106 && fLastSequence - fFirstSequence == fNumBytes) {
107 // there is no hole in the buffer, we can make the whole buffer
108 // available
109 fContiguousBytes += buffer->size;
112 fLastSequence = sequence + buffer->size;
113 fNumBytes += buffer->size;
115 TRACE((" out0: first: %lu, last: %lu, num: %lu, cont: %lu\n",
116 fFirstSequence.Number(), fLastSequence.Number(), fNumBytes,
117 fContiguousBytes));
118 VERIFY();
119 return;
122 if (fLastSequence < sequence + buffer->size)
123 fLastSequence = sequence + buffer->size;
125 // find the place where to insert the buffer into the queue
127 SegmentList::ReverseIterator iterator = fList.GetReverseIterator();
128 net_buffer *previous = NULL;
129 net_buffer *next = NULL;
130 while ((previous = iterator.Next()) != NULL) {
131 if (sequence >= previous->sequence) {
132 // The new fragment can be inserted after this one
133 break;
136 next = previous;
139 // check if we have duplicate data, and remove it if that is the case
140 if (previous != NULL) {
141 if (sequence == previous->sequence) {
142 // we already have at least part of this data - ignore new data
143 // whenever it makes sense (because some TCP implementations send
144 // bogus data when probing the window)
145 if (previous->size >= buffer->size) {
146 gBufferModule->free(buffer);
147 buffer = NULL;
148 } else {
149 fList.Remove(previous);
150 fNumBytes -= previous->size;
151 gBufferModule->free(previous);
153 } else if (tcp_sequence(previous->sequence + previous->size)
154 >= sequence + buffer->size) {
155 // We already know this data
156 gBufferModule->free(buffer);
157 buffer = NULL;
158 } else if (tcp_sequence(previous->sequence + previous->size)
159 > sequence) {
160 // We already have the first part of this buffer
161 gBufferModule->remove_header(buffer,
162 (previous->sequence + previous->size - sequence).Number());
163 sequence = previous->sequence + previous->size;
167 // "next" always starts at or after the buffer sequence
168 ASSERT(next == NULL || buffer == NULL || next->sequence >= sequence);
170 while (buffer != NULL && next != NULL
171 && tcp_sequence(sequence + buffer->size) > next->sequence) {
172 // we already have at least part of this data
173 if (tcp_sequence(next->sequence + next->size)
174 <= sequence + buffer->size) {
175 net_buffer *remove = next;
176 next = (net_buffer *)next->link.next;
178 fList.Remove(remove);
179 fNumBytes -= remove->size;
180 gBufferModule->free(remove);
181 } else if (tcp_sequence(next->sequence) > sequence) {
182 // We have the end of this buffer already
183 gBufferModule->remove_trailer(buffer,
184 (sequence + buffer->size - next->sequence).Number());
185 } else {
186 // We already have this data
187 gBufferModule->free(buffer);
188 buffer = NULL;
192 if (buffer == NULL) {
193 TRACE((" out1: first: %lu, last: %lu, num: %lu, cont: %lu\n",
194 fFirstSequence.Number(), fLastSequence.Number(), fNumBytes,
195 fContiguousBytes));
196 VERIFY();
197 return;
200 fList.Insert(next, buffer);
201 buffer->sequence = sequence.Number();
202 fNumBytes += buffer->size;
204 // we might need to update the number of bytes available
206 if (fLastSequence - fFirstSequence == fNumBytes)
207 fContiguousBytes = fNumBytes;
208 else if (fFirstSequence + fContiguousBytes == sequence) {
209 // the complicated case: the new segment may have connected almost all
210 // buffers in the queue (but not all, or the above would be true)
212 do {
213 fContiguousBytes += buffer->size;
215 buffer = (struct net_buffer *)buffer->link.next;
216 } while (buffer != NULL
217 && fFirstSequence + fContiguousBytes == buffer->sequence);
220 TRACE((" out2: first: %lu, last: %lu, num: %lu, cont: %lu\n",
221 fFirstSequence.Number(), fLastSequence.Number(), fNumBytes,
222 fContiguousBytes));
223 VERIFY();
227 /*! Removes all data in the queue up to the \a sequence number as specified.
229 NOTE: If there are missing segments in the buffers to be removed,
230 fContiguousBytes is not maintained correctly!
232 status_t
233 BufferQueue::RemoveUntil(tcp_sequence sequence)
235 TRACE(("BufferQueue@%p::RemoveUntil(sequence %lu)\n", this,
236 sequence.Number()));
237 VERIFY();
239 if (sequence < fFirstSequence)
240 return B_OK;
242 SegmentList::Iterator iterator = fList.GetIterator();
243 tcp_sequence lastRemoved = fFirstSequence;
244 net_buffer *buffer = NULL;
245 while ((buffer = iterator.Next()) != NULL && buffer->sequence < sequence) {
246 ASSERT(lastRemoved == buffer->sequence);
247 // This assures that the queue has no holes, and fContiguousBytes
248 // is maintained correctly.
250 if (sequence >= buffer->sequence + buffer->size) {
251 // remove this buffer completely
252 iterator.Remove();
253 fNumBytes -= buffer->size;
255 fContiguousBytes -= buffer->size;
256 lastRemoved = buffer->sequence + buffer->size;
257 gBufferModule->free(buffer);
258 } else {
259 // remove the header as far as needed
260 size_t size = (sequence - buffer->sequence).Number();
261 gBufferModule->remove_header(buffer, size);
263 buffer->sequence += size;
264 fNumBytes -= size;
265 fContiguousBytes -= size;
266 break;
270 if (fList.IsEmpty())
271 fFirstSequence = fLastSequence;
272 else
273 fFirstSequence = fList.Head()->sequence;
275 VERIFY();
276 return B_OK;
280 /*! Clones the requested data in the buffer queue into the provided \a buffer.
282 status_t
283 BufferQueue::Get(net_buffer *buffer, tcp_sequence sequence, size_t bytes)
285 TRACE(("BufferQueue@%p::Get(sequence %lu, bytes %lu)\n", this,
286 sequence.Number(), bytes));
287 VERIFY();
289 if (bytes == 0)
290 return B_OK;
292 if (sequence >= fLastSequence || sequence < fFirstSequence) {
293 // we don't have the requested data
294 return B_BAD_VALUE;
296 if (tcp_sequence(sequence + bytes) > fLastSequence)
297 bytes = (fLastSequence - sequence).Number();
299 size_t bytesLeft = bytes;
301 // find first buffer matching the sequence
303 SegmentList::Iterator iterator = fList.GetIterator();
304 net_buffer *source = NULL;
305 while ((source = iterator.Next()) != NULL) {
306 if (sequence < source->sequence + source->size)
307 break;
310 if (source == NULL)
311 panic("we should have had that data...");
312 if (tcp_sequence(source->sequence) > sequence) {
313 panic("source %p, sequence = %" B_PRIu32 " (%" B_PRIu32 ")\n", source,
314 source->sequence, sequence.Number());
317 // clone the data
319 uint32 offset = (sequence - source->sequence).Number();
321 while (source != NULL && bytesLeft > 0) {
322 size_t size = min_c(source->size - offset, bytesLeft);
323 status_t status = gBufferModule->append_cloned(buffer, source, offset,
324 size);
325 if (status < B_OK)
326 return status;
328 bytesLeft -= size;
329 offset = 0;
330 source = iterator.Next();
333 VERIFY();
334 return B_OK;
338 /*! Creates a new buffer containing \a bytes bytes from the start of the
339 buffer queue. If \a remove is \c true, the data is removed from the
340 queue, if not, the data is cloned from the queue.
342 status_t
343 BufferQueue::Get(size_t bytes, bool remove, net_buffer **_buffer)
345 if (bytes > Available())
346 bytes = Available();
348 if (bytes == 0) {
349 // we don't need to create a buffer when there is no data
350 *_buffer = NULL;
351 return B_OK;
354 net_buffer *buffer = fList.First();
355 size_t bytesLeft = bytes;
356 ASSERT(buffer != NULL);
358 if (!remove || buffer->size > bytes) {
359 // we need a new buffer
360 buffer = gBufferModule->create(256);
361 if (buffer == NULL)
362 return B_NO_MEMORY;
363 } else {
364 // we can reuse this buffer
365 bytesLeft -= buffer->size;
366 fFirstSequence += buffer->size;
368 fList.Remove(buffer);
371 // clone/copy the remaining data
373 SegmentList::Iterator iterator = fList.GetIterator();
374 net_buffer *source = NULL;
375 status_t status = B_OK;
376 while (bytesLeft > 0 && (source = iterator.Next()) != NULL) {
377 size_t size = min_c(source->size, bytesLeft);
378 status = gBufferModule->append_cloned(buffer, source, 0, size);
379 if (status < B_OK)
380 break;
382 bytesLeft -= size;
384 if (!remove)
385 continue;
387 // remove either the whole buffer or only the part we cloned
389 fFirstSequence += size;
391 if (size == source->size) {
392 iterator.Remove();
393 gBufferModule->free(source);
394 } else {
395 gBufferModule->remove_header(source, size);
396 source->sequence += size;
400 if (remove && buffer->size) {
401 fNumBytes -= buffer->size;
402 fContiguousBytes -= buffer->size;
405 // We always return what we got, or else we would lose data
406 if (status < B_OK && buffer->size == 0) {
407 // We could not remove any bytes from the buffer, so
408 // let this call fail.
409 gBufferModule->free(buffer);
410 VERIFY();
411 return status;
414 *_buffer = buffer;
415 VERIFY();
416 return B_OK;
420 size_t
421 BufferQueue::Available(tcp_sequence sequence) const
423 if (sequence > (fFirstSequence + fContiguousBytes).Number())
424 return 0;
426 return (fContiguousBytes + fFirstSequence - sequence).Number();
430 void
431 BufferQueue::SetPushPointer()
433 if (fList.IsEmpty())
434 fPushPointer = 0;
435 else
436 fPushPointer = fList.Tail()->sequence + fList.Tail()->size;
439 #if DEBUG_BUFFER_QUEUE
441 /*! Perform a sanity check of the whole queue.
443 void
444 BufferQueue::Verify() const
446 ASSERT(Available() == 0 || fList.First() != NULL);
448 if (fList.First() == NULL) {
449 ASSERT(fNumBytes == 0);
450 return;
453 SegmentList::ConstIterator iterator = fList.GetIterator();
454 size_t numBytes = 0;
455 size_t contiguousBytes = 0;
456 bool contiguous = true;
457 tcp_sequence last = fFirstSequence;
459 while (net_buffer* buffer = iterator.Next()) {
460 if (contiguous && buffer->sequence == last)
461 contiguousBytes += buffer->size;
462 else
463 contiguous = false;
465 ASSERT(last <= buffer->sequence);
466 ASSERT(buffer->size > 0);
468 numBytes += buffer->size;
469 last = buffer->sequence + buffer->size;
472 ASSERT(last == fLastSequence);
473 ASSERT(contiguousBytes == fContiguousBytes);
474 ASSERT(numBytes == fNumBytes);
478 void
479 BufferQueue::Dump() const
481 SegmentList::ConstIterator iterator = fList.GetIterator();
482 int32 number = 0;
483 while (net_buffer* buffer = iterator.Next()) {
484 kprintf(" %" B_PRId32 ". buffer %p, sequence %" B_PRIu32 ", size %"
485 B_PRIu32 "\n", ++number, buffer, buffer->sequence, buffer->size);
489 #endif // DEBUG_BUFFER_QUEUE