vfs: check userland buffers before reading them.
[haiku.git] / src / system / boot / loader / net / UDP.cpp
blob5c4adbdc355920dabb150ce7b0943a9e9a40d40e
1 /*
2 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
3 * All rights reserved. Distributed under the terms of the MIT License.
4 */
7 #include <boot/net/UDP.h>
9 #include <stdio.h>
11 #include <KernelExport.h>
13 #include <boot/net/ChainBuffer.h>
14 #include <boot/net/NetStack.h>
//#define TRACE_UDP
#ifdef TRACE_UDP
#	define TRACE(x) dprintf x
#else
	// Expand to a single statement, so "if (a) TRACE((...)); else ..." stays
	// valid with tracing disabled, just as it is with tracing enabled.
#	define TRACE(x) do {} while (false)
#endif
25 using std::nothrow;
28 // #pragma mark - UDPPacket
31 UDPPacket::UDPPacket()
33 fNext(NULL),
34 fData(NULL),
35 fSize(0)
40 UDPPacket::~UDPPacket()
42 free(fData);
46 status_t
47 UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress,
48 uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort)
50 if (data == NULL)
51 return B_BAD_VALUE;
53 // clone the data
54 fData = malloc(size);
55 if (fData == NULL)
56 return B_NO_MEMORY;
57 memcpy(fData, data, size);
59 fSize = size;
60 fSourceAddress = sourceAddress;
61 fDestinationAddress = destinationAddress;
62 fSourcePort = sourcePort;
63 fDestinationPort = destinationPort;
65 return B_OK;
69 UDPPacket *
70 UDPPacket::Next() const
72 return fNext;
76 void
77 UDPPacket::SetNext(UDPPacket *next)
79 fNext = next;
83 const void *
84 UDPPacket::Data() const
86 return fData;
90 size_t
91 UDPPacket::DataSize() const
93 return fSize;
97 ip_addr_t
98 UDPPacket::SourceAddress() const
100 return fSourceAddress;
104 uint16
105 UDPPacket::SourcePort() const
107 return fSourcePort;
111 ip_addr_t
112 UDPPacket::DestinationAddress() const
114 return fDestinationAddress;
118 uint16
119 UDPPacket::DestinationPort() const
121 return fDestinationPort;
125 // #pragma mark - UDPSocket
128 UDPSocket::UDPSocket()
130 fUDPService(NetStack::Default()->GetUDPService()),
131 fFirstPacket(NULL),
132 fLastPacket(NULL),
133 fAddress(INADDR_ANY),
134 fPort(0)
139 UDPSocket::~UDPSocket()
141 if (fPort != 0 && fUDPService != NULL)
142 fUDPService->UnbindSocket(this);
146 status_t
147 UDPSocket::Bind(ip_addr_t address, uint16 port)
149 if (fUDPService == NULL) {
150 printf("UDPSocket::Bind(): no UDP service\n");
151 return B_NO_INIT;
154 if (address == INADDR_BROADCAST || port == 0) {
155 printf("UDPSocket::Bind(): broadcast IP or port 0\n");
156 return B_BAD_VALUE;
159 if (fPort != 0) {
160 printf("UDPSocket::Bind(): already bound\n");
161 return EALREADY;
162 // correct code?
165 status_t error = fUDPService->BindSocket(this, address, port);
166 if (error != B_OK) {
167 printf("UDPSocket::Bind(): service BindSocket() failed\n");
168 return error;
171 fAddress = address;
172 fPort = port;
174 return B_OK;
178 void
179 UDPSocket::Detach()
181 fUDPService = NULL;
182 // This will lead to subsequent methods returning B_NO_INIT
187 status_t
188 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
189 ChainBuffer *buffer)
191 if (fUDPService == NULL)
192 return B_NO_INIT;
194 return fUDPService->Send(fPort, destinationAddress, destinationPort,
195 buffer);
199 status_t
200 UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
201 const void *data, size_t size)
203 if (data == NULL)
204 return B_BAD_VALUE;
206 ChainBuffer buffer((void*)data, size);
207 return Send(destinationAddress, destinationPort, &buffer);
211 status_t
212 UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout)
214 if (fUDPService == NULL)
215 return B_NO_INIT;
217 if (_packet == NULL)
218 return B_BAD_VALUE;
220 bigtime_t startTime = system_time();
221 for (;;) {
222 fUDPService->ProcessIncomingPackets();
223 *_packet = PopPacket();
224 if (*_packet != NULL)
225 return B_OK;
227 if (system_time() - startTime > timeout)
228 return (timeout == 0 ? B_WOULD_BLOCK : B_TIMED_OUT);
233 void
234 UDPSocket::PushPacket(UDPPacket *packet)
236 if (fLastPacket != NULL)
237 fLastPacket->SetNext(packet);
238 else
239 fFirstPacket = packet;
241 fLastPacket = packet;
242 packet->SetNext(NULL);
246 UDPPacket *
247 UDPSocket::PopPacket()
249 if (fFirstPacket == NULL)
250 return NULL;
252 UDPPacket *packet = fFirstPacket;
253 fFirstPacket = packet->Next();
255 if (fFirstPacket == NULL)
256 fLastPacket = NULL;
258 packet->SetNext(NULL);
259 return packet;
263 // #pragma mark - UDPService
266 UDPService::UDPService(IPService *ipService)
268 IPSubService(kUDPServiceName),
269 fIPService(ipService)
274 UDPService::~UDPService()
276 int count = fSockets.Count();
277 for (int i = 0; i < count; i++) {
278 UDPSocket *socket = fSockets.ElementAt(i);
279 socket->Detach();
282 if (fIPService != NULL)
283 fIPService->UnregisterIPSubService(this);
287 status_t
288 UDPService::Init()
290 if (fIPService == NULL)
291 return B_BAD_VALUE;
292 if (!fIPService->RegisterIPSubService(this))
293 return B_NO_MEMORY;
294 return B_OK;
298 uint8
299 UDPService::IPProtocol() const
301 return IPPROTO_UDP;
305 void
306 UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP,
307 ip_addr_t destinationIP, const void *data, size_t size)
309 TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, "
310 "%lu - %lu bytes\n", sourceIP, destinationIP, size,
311 sizeof(udp_header)));
313 if (data == NULL || size < sizeof(udp_header))
314 return;
316 const udp_header *header = (const udp_header*)data;
317 uint16 source = ntohs(header->source);
318 uint16 destination = ntohs(header->destination);
319 uint16 length = ntohs(header->length);
321 // check the header
322 if (length < sizeof(udp_header) || length > size
323 || (header->checksum != 0 // 0 => checksum disabled
324 && _ChecksumData(data, length, sourceIP, destinationIP) != 0)) {
325 TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size "
326 "or checksum\n"));
327 return;
330 // find the target socket
331 UDPSocket *socket = _FindSocket(destinationIP, destination);
332 if (socket == NULL)
333 return;
335 // create a UDPPacket and queue it in the socket
336 UDPPacket *packet = new(nothrow) UDPPacket;
337 if (packet == NULL)
338 return;
339 status_t error = packet->SetTo((uint8*)data + sizeof(udp_header),
340 length - sizeof(udp_header), sourceIP, source, destinationIP,
341 destination);
342 if (error == B_OK)
343 socket->PushPacket(packet);
344 else
345 delete packet;
349 status_t
350 UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress,
351 uint16 destinationPort, ChainBuffer *buffer)
353 TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n",
354 sourcePort, destinationAddress, destinationPort,
355 (buffer != NULL ? buffer->TotalSize() : 0)));
357 if (fIPService == NULL)
358 return B_NO_INIT;
360 if (buffer == NULL)
361 return B_BAD_VALUE;
363 // prepend the UDP header
364 udp_header header;
365 ChainBuffer headerBuffer(&header, sizeof(header), buffer);
366 header.source = htons(sourcePort);
367 header.destination = htons(destinationPort);
368 header.length = htons(headerBuffer.TotalSize());
370 // compute the checksum
371 header.checksum = 0;
372 header.checksum = htons(_ChecksumBuffer(&headerBuffer,
373 fIPService->IPAddress(), destinationAddress,
374 headerBuffer.TotalSize()));
375 // 0 means checksum disabled; 0xffff is equivalent in this case
376 if (header.checksum == 0)
377 header.checksum = 0xffff;
379 return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer);
383 void
384 UDPService::ProcessIncomingPackets()
386 if (fIPService != NULL)
387 fIPService->ProcessIncomingPackets();
391 status_t
392 UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port)
394 if (socket == NULL)
395 return B_BAD_VALUE;
397 if (_FindSocket(address, port) != NULL) {
398 printf("UDPService::BindSocket(): address in use\n");
399 return EADDRINUSE;
402 return fSockets.Add(socket);
406 void
407 UDPService::UnbindSocket(UDPSocket *socket)
409 fSockets.Remove(socket);
413 uint16
414 UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source,
415 ip_addr_t destination, uint16 length)
417 // The checksum is calculated over a pseudo-header plus the UDP packet.
418 // So we temporarily prepend the pseudo-header.
419 struct pseudo_header {
420 ip_addr_t source;
421 ip_addr_t destination;
422 uint8 pad;
423 uint8 protocol;
424 uint16 length;
425 } __attribute__ ((__packed__));
426 pseudo_header header = {
427 htonl(source),
428 htonl(destination),
430 IPPROTO_UDP,
431 htons(length)
434 ChainBuffer headerBuffer(&header, sizeof(header), buffer);
435 uint16 checksum = ip_checksum(&headerBuffer);
436 headerBuffer.DetachNext();
437 return checksum;
441 uint16
442 UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source,
443 ip_addr_t destination)
445 ChainBuffer buffer((void*)data, length);
446 return _ChecksumBuffer(&buffer, source, destination, length);
450 UDPSocket *
451 UDPService::_FindSocket(ip_addr_t address, uint16 port)
453 int count = fSockets.Count();
454 for (int i = 0; i < count; i++) {
455 UDPSocket *socket = fSockets.ElementAt(i);
456 if ((address == INADDR_ANY || socket->Address() == INADDR_ANY
457 || socket->Address() == address)
458 && port == socket->Port()) {
459 return socket;
463 return NULL;