-Add UDP ephemeral port allocation
[newos.git] / kernel / net / udp.c
blob b8cd523fc4b0619c613ad865d0f081566096173d
1 /*
2 ** Copyright 2001, Travis Geiselbrecht. All rights reserved.
3 ** Distributed under the terms of the NewOS License.
4 */
5 #include <kernel/kernel.h>
6 #include <kernel/cbuf.h>
7 #include <kernel/lock.h>
8 #include <kernel/debug.h>
9 #include <kernel/heap.h>
10 #include <kernel/khash.h>
11 #include <kernel/sem.h>
12 #include <kernel/arch/cpu.h>
13 #include <kernel/net/udp.h>
14 #include <kernel/net/ipv4.h>
15 #include <kernel/net/misc.h>
16 #include <stdlib.h>
18 typedef struct udp_header {
19 uint16 source_port;
20 uint16 dest_port;
21 uint16 length;
22 uint16 checksum;
23 } _PACKED udp_header;
25 typedef struct udp_pseudo_header {
26 ipv4_addr source_addr;
27 ipv4_addr dest_addr;
28 uint8 zero;
29 uint8 protocol;
30 uint16 udp_length;
31 } _PACKED udp_pseudo_header;
33 typedef struct udp_queue_elem {
34 struct udp_queue_elem *next;
35 struct udp_queue_elem *prev;
36 ipv4_addr src_address;
37 ipv4_addr target_address;
38 uint16 src_port;
39 uint16 target_port;
40 int len;
41 cbuf *buf;
42 } udp_queue_elem;
44 typedef struct udp_queue {
45 udp_queue_elem *next;
46 udp_queue_elem *prev;
47 int count;
48 } udp_queue;
50 typedef struct udp_endpoint {
51 struct udp_endpoint *next;
52 mutex lock;
53 sem_id blocking_sem;
54 uint16 port;
55 udp_queue q;
56 int ref_count;
57 } udp_endpoint;
59 static udp_endpoint *endpoints;
60 static mutex endpoints_lock;
61 static int next_ephemeral_port;
63 static int udp_endpoint_compare_func(void *_e, const void *_key)
65 udp_endpoint *e = _e;
66 const uint16 *port = _key;
68 if(e->port == *port)
69 return 0;
70 else
71 return 1;
74 static unsigned int udp_endpoint_hash_func(void *_e, const void *_key, unsigned int range)
76 udp_endpoint *e = _e;
77 const uint16 *port = _key;
79 if(e)
80 return e->port % range;
81 else
82 return *port % range;
85 static void udp_init_queue(udp_queue *q)
87 q->count = 0;
88 q->next = q->prev = (udp_queue_elem *)q;
91 static udp_queue_elem *udp_queue_pop(udp_queue *q)
93 if(q->next != (udp_queue_elem *)q) {
94 udp_queue_elem *e = q->next;
96 q->next = e->next;
97 e->next->prev = (udp_queue_elem *)q;
98 q->count--;
99 e->next = e->prev = NULL;
100 return e;
101 } else {
102 return NULL;
106 static void udp_queue_push(udp_queue *q, udp_queue_elem *e)
108 e->prev = q->prev;
109 e->next = (udp_queue_elem *)q;
110 q->prev->next = e;
111 q->prev = e;
112 q->count++;
115 static void udp_endpoint_acquire_ref(udp_endpoint *e)
117 atomic_add(&e->ref_count, 1);
120 static void udp_endpoint_release_ref(udp_endpoint *e)
122 if(atomic_add(&e->ref_count, -1) == 1) {
123 udp_queue_elem *qe;
125 mutex_destroy(&e->lock);
126 sem_delete(e->blocking_sem);
128 // clear out the queue of packets
129 for(qe = udp_queue_pop(&e->q); qe; qe = udp_queue_pop(&e->q)) {
130 if(qe->buf)
131 cbuf_free_chain(qe->buf);
132 kfree(qe);
137 static int udp_allocate_ephemeral_port(void)
139 return atomic_add(&next_ephemeral_port, 1) % 0x10000;
142 int udp_input(cbuf *buf, ifnet *i, ipv4_addr source_address, ipv4_addr target_address)
144 udp_header *header;
145 udp_endpoint *e;
146 udp_queue_elem *qe;
147 uint16 port;
148 int err;
150 header = cbuf_get_ptr(buf, 0);
152 #if NET_CHATTY
153 dprintf("udp_input: src port %d, dest port %d, len %d, buf len %d, checksum 0x%x\n",
154 ntohs(header->source_port), ntohs(header->dest_port), ntohs(header->length), (int)cbuf_get_len(buf), ntohs(header->checksum));
155 #endif
156 if(ntohs(header->length) > (uint16)cbuf_get_len(buf)) {
157 err = ERR_NET_BAD_PACKET;
158 goto ditch_packet;
161 // deal with the checksum check
162 if(header->checksum) {
163 udp_pseudo_header pheader;
164 uint16 checksum;
166 // set up the pseudo header for checksum purposes
167 pheader.source_addr = htonl(source_address);
168 pheader.dest_addr = htonl(target_address);
169 pheader.zero = 0;
170 pheader.protocol = IP_PROT_UDP;
171 pheader.udp_length = header->length;
173 checksum = cbuf_ones_cksum16_2(buf, 0, ntohs(header->length), &pheader, sizeof(pheader));
174 if(checksum != 0) {
175 #if NET_CHATTY
176 dprintf("udp_receive: packet failed checksum\n");
177 #endif
178 err = ERR_NET_BAD_PACKET;
179 goto ditch_packet;
183 // see if we have an endpoint
184 port = ntohs(header->dest_port);
185 mutex_lock(&endpoints_lock);
186 e = hash_lookup(endpoints, &port);
187 if(e)
188 udp_endpoint_acquire_ref(e);
189 mutex_unlock(&endpoints_lock);
191 if(!e) {
192 err = NO_ERROR;
193 goto ditch_packet;
196 // okay, we have an endpoint, lets queue our stuff up and move on
197 qe = kmalloc(sizeof(udp_queue_elem));
198 if(!qe) {
199 udp_endpoint_release_ref(e);
200 err = ERR_NO_MEMORY;
201 goto ditch_packet;
203 qe->src_port = ntohs(header->source_port);
204 qe->target_port = port;
205 qe->src_address = source_address;
206 qe->target_address = target_address;
207 qe->len = ntohs(header->length) - sizeof(udp_header);
209 // trim off the udp header
210 buf = cbuf_truncate_head(buf, sizeof(udp_header), true);
211 qe->buf = buf;
213 mutex_lock(&e->lock);
214 udp_queue_push(&e->q, qe);
215 mutex_unlock(&e->lock);
217 sem_release(e->blocking_sem, 1);
219 udp_endpoint_release_ref(e);
221 err = NO_ERROR;
222 return err;
224 ditch_packet:
225 cbuf_free_chain(buf);
227 return err;
230 int udp_open(void **prot_data)
232 udp_endpoint *e;
234 e = kmalloc(sizeof(udp_endpoint));
235 if(!e)
236 return ERR_NO_MEMORY;
238 mutex_init(&e->lock, "udp endpoint lock");
239 e->blocking_sem = sem_create(0, "udp endpoint sem");
240 e->port = 0;
241 e->ref_count = 1;
242 udp_init_queue(&e->q);
244 mutex_lock(&endpoints_lock);
245 hash_insert(endpoints, e);
246 mutex_unlock(&endpoints_lock);
248 *prot_data = e;
250 return 0;
253 static int _udp_bind(udp_endpoint *e, int port)
255 int err;
257 mutex_lock(&e->lock);
259 if(e->port == 0) {
261 // make up a port number if one isn't passed in
262 if (port == 0)
263 port = udp_allocate_ephemeral_port();
265 dprintf("_udp_bind: setting endprint %p to port %d\n", e, port);
267 if(port != e->port) {
269 // XXX search to make sure this port isn't used already
271 // remove it from the hashtable, stick it back with the new port
272 mutex_lock(&endpoints_lock);
273 hash_remove(endpoints, e);
274 e->port = port;
275 hash_insert(endpoints, e);
276 mutex_unlock(&endpoints_lock);
278 err = NO_ERROR;
279 } else {
280 err = ERR_NET_SOCKET_ALREADY_BOUND;
283 mutex_unlock(&e->lock);
285 return err;
288 int udp_bind(void *prot_data, sockaddr *addr)
290 udp_endpoint *e = prot_data;
292 // XXX does not support binding src ip address
293 return _udp_bind(e, addr->port);
296 int udp_connect(void *prot_data, sockaddr *addr)
298 return ERR_NOT_ALLOWED;
301 int udp_listen(void *prot_data)
303 return ERR_NOT_ALLOWED;
306 int udp_accept(void *prot_data, sockaddr *addr, void **new_socket)
308 return ERR_NOT_ALLOWED;
311 int udp_close(void *prot_data)
313 udp_endpoint *e = prot_data;
315 mutex_lock(&endpoints_lock);
316 hash_remove(endpoints, e);
317 mutex_unlock(&endpoints_lock);
319 udp_endpoint_release_ref(e);
321 return 0;
324 ssize_t udp_recvfrom(void *prot_data, void *buf, ssize_t len, sockaddr *saddr, int flags, bigtime_t timeout)
326 udp_endpoint *e = prot_data;
327 udp_queue_elem *qe;
328 int err;
329 ssize_t ret;
331 retry:
332 if(flags & SOCK_FLAG_TIMEOUT)
333 err = sem_acquire_etc(e->blocking_sem, 1, SEM_FLAG_TIMEOUT, timeout, NULL);
334 else
335 err = sem_acquire(e->blocking_sem, 1);
336 if(err < 0)
337 return err;
339 // pop an item off the list, if there are any
340 mutex_lock(&e->lock);
341 qe = udp_queue_pop(&e->q);
342 mutex_unlock(&e->lock);
344 if(!qe)
345 goto retry;
347 // we have the data, copy it out
348 err = cbuf_user_memcpy_from_chain(buf, qe->buf, 0, min(qe->len, len));
349 if(err < 0) {
350 ret = err;
351 goto out;
353 ret = qe->len;
355 // copy the address out
356 if(saddr) {
357 saddr->addr.len = 4;
358 saddr->addr.type = ADDR_TYPE_IP;
359 NETADDR_TO_IPV4(saddr->addr) = qe->src_address;
360 saddr->port = qe->src_port;
363 out:
364 // free this queue entry
365 cbuf_free_chain(qe->buf);
366 kfree(qe);
368 return ret;
371 ssize_t udp_sendto(void *prot_data, const void *inbuf, ssize_t len, sockaddr *toaddr)
373 udp_endpoint *e = prot_data;
374 udp_header *header;
375 int total_len;
376 cbuf *buf;
377 udp_pseudo_header pheader;
378 ipv4_addr srcaddr;
379 int err;
381 // make sure the args make sense
382 if(len < 0 || len + sizeof(udp_header) > 0xffff)
383 return ERR_INVALID_ARGS;
384 if(toaddr->port < 0 || toaddr->port > 0xffff)
385 return ERR_INVALID_ARGS;
387 // find us a local port if no one has already
388 if (e->port == 0)
389 _udp_bind(e, 0);
391 // allocate a buffer to hold the data + header
392 total_len = len + sizeof(udp_header);
393 buf = cbuf_get_chain(total_len);
394 if(!buf)
395 return ERR_NO_MEMORY;
397 // copy the data to this new buffer
398 err = cbuf_user_memcpy_to_chain(buf, sizeof(udp_header), inbuf, len);
399 if(err < 0) {
400 cbuf_free_chain(buf);
401 return ERR_VM_BAD_USER_MEMORY;
404 // set up the udp pseudo header
405 if(ipv4_lookup_srcaddr_for_dest(NETADDR_TO_IPV4(toaddr->addr), &srcaddr) < 0) {
406 cbuf_free_chain(buf);
407 return ERR_NET_NO_ROUTE;
409 pheader.source_addr = htonl(srcaddr);
410 pheader.dest_addr = htonl(NETADDR_TO_IPV4(toaddr->addr));
411 pheader.zero = 0;
412 pheader.protocol = IP_PROT_UDP;
413 pheader.udp_length = htons(total_len);
415 // start setting up the header
416 header = cbuf_get_ptr(buf, 0);
417 header->source_port = htons(e->port);
418 header->dest_port = htons(toaddr->port);
419 header->length = htons(total_len);
420 header->checksum = 0;
421 header->checksum = cbuf_ones_cksum16_2(buf, 0, total_len, &pheader, sizeof(pheader));
422 if(header->checksum == 0)
423 header->checksum = 0xffff;
425 // send it away
426 err = ipv4_output(buf, NETADDR_TO_IPV4(toaddr->addr), IP_PROT_UDP);
428 // if it returns ARP_QUEUED, then it's actually okay
429 if(err == ERR_NET_ARP_QUEUED) {
430 err = 0;
433 return err;
436 int udp_init(void)
438 mutex_init(&endpoints_lock, "udp_endpoints lock");
440 next_ephemeral_port = rand() % 32000 + 1024;
442 endpoints = hash_init(256, offsetof(udp_endpoint, next),
443 &udp_endpoint_compare_func, &udp_endpoint_hash_func);
444 if(!endpoints)
445 return ERR_NO_MEMORY;
447 return 0;