2 ** Copyright 2001, Travis Geiselbrecht. All rights reserved.
3 ** Distributed under the terms of the NewOS License.
5 #include <kernel/kernel.h>
6 #include <kernel/cbuf.h>
7 #include <kernel/lock.h>
8 #include <kernel/debug.h>
9 #include <kernel/heap.h>
10 #include <kernel/khash.h>
11 #include <kernel/sem.h>
12 #include <kernel/arch/cpu.h>
13 #include <kernel/net/udp.h>
14 #include <kernel/net/ipv4.h>
15 #include <kernel/net/misc.h>
18 typedef struct udp_header
{
25 typedef struct udp_pseudo_header
{
26 ipv4_addr source_addr
;
31 } _PACKED udp_pseudo_header
;
33 typedef struct udp_queue_elem
{
34 struct udp_queue_elem
*next
;
35 struct udp_queue_elem
*prev
;
36 ipv4_addr src_address
;
37 ipv4_addr target_address
;
44 typedef struct udp_queue
{
50 typedef struct udp_endpoint
{
51 struct udp_endpoint
*next
;
59 static udp_endpoint
*endpoints
;
60 static mutex endpoints_lock
;
61 static int next_ephemeral_port
;
63 static int udp_endpoint_compare_func(void *_e
, const void *_key
)
66 const uint16
*port
= _key
;
74 static unsigned int udp_endpoint_hash_func(void *_e
, const void *_key
, unsigned int range
)
77 const uint16
*port
= _key
;
80 return e
->port
% range
;
85 static void udp_init_queue(udp_queue
*q
)
88 q
->next
= q
->prev
= (udp_queue_elem
*)q
;
91 static udp_queue_elem
*udp_queue_pop(udp_queue
*q
)
93 if(q
->next
!= (udp_queue_elem
*)q
) {
94 udp_queue_elem
*e
= q
->next
;
97 e
->next
->prev
= (udp_queue_elem
*)q
;
99 e
->next
= e
->prev
= NULL
;
106 static void udp_queue_push(udp_queue
*q
, udp_queue_elem
*e
)
109 e
->next
= (udp_queue_elem
*)q
;
115 static void udp_endpoint_acquire_ref(udp_endpoint
*e
)
117 atomic_add(&e
->ref_count
, 1);
120 static void udp_endpoint_release_ref(udp_endpoint
*e
)
122 if(atomic_add(&e
->ref_count
, -1) == 1) {
125 mutex_destroy(&e
->lock
);
126 sem_delete(e
->blocking_sem
);
128 // clear out the queue of packets
129 for(qe
= udp_queue_pop(&e
->q
); qe
; qe
= udp_queue_pop(&e
->q
)) {
131 cbuf_free_chain(qe
->buf
);
137 static int udp_allocate_ephemeral_port(void)
139 return atomic_add(&next_ephemeral_port
, 1) % 0x10000;
142 int udp_input(cbuf
*buf
, ifnet
*i
, ipv4_addr source_address
, ipv4_addr target_address
)
150 header
= cbuf_get_ptr(buf
, 0);
153 dprintf("udp_input: src port %d, dest port %d, len %d, buf len %d, checksum 0x%x\n",
154 ntohs(header
->source_port
), ntohs(header
->dest_port
), ntohs(header
->length
), (int)cbuf_get_len(buf
), ntohs(header
->checksum
));
156 if(ntohs(header
->length
) > (uint16
)cbuf_get_len(buf
)) {
157 err
= ERR_NET_BAD_PACKET
;
161 // deal with the checksum check
162 if(header
->checksum
) {
163 udp_pseudo_header pheader
;
166 // set up the pseudo header for checksum purposes
167 pheader
.source_addr
= htonl(source_address
);
168 pheader
.dest_addr
= htonl(target_address
);
170 pheader
.protocol
= IP_PROT_UDP
;
171 pheader
.udp_length
= header
->length
;
173 checksum
= cbuf_ones_cksum16_2(buf
, 0, ntohs(header
->length
), &pheader
, sizeof(pheader
));
176 dprintf("udp_receive: packet failed checksum\n");
178 err
= ERR_NET_BAD_PACKET
;
183 // see if we have an endpoint
184 port
= ntohs(header
->dest_port
);
185 mutex_lock(&endpoints_lock
);
186 e
= hash_lookup(endpoints
, &port
);
188 udp_endpoint_acquire_ref(e
);
189 mutex_unlock(&endpoints_lock
);
196 // okay, we have an endpoint, lets queue our stuff up and move on
197 qe
= kmalloc(sizeof(udp_queue_elem
));
199 udp_endpoint_release_ref(e
);
203 qe
->src_port
= ntohs(header
->source_port
);
204 qe
->target_port
= port
;
205 qe
->src_address
= source_address
;
206 qe
->target_address
= target_address
;
207 qe
->len
= ntohs(header
->length
) - sizeof(udp_header
);
209 // trim off the udp header
210 buf
= cbuf_truncate_head(buf
, sizeof(udp_header
), true);
213 mutex_lock(&e
->lock
);
214 udp_queue_push(&e
->q
, qe
);
215 mutex_unlock(&e
->lock
);
217 sem_release(e
->blocking_sem
, 1);
219 udp_endpoint_release_ref(e
);
225 cbuf_free_chain(buf
);
230 int udp_open(void **prot_data
)
234 e
= kmalloc(sizeof(udp_endpoint
));
236 return ERR_NO_MEMORY
;
238 mutex_init(&e
->lock
, "udp endpoint lock");
239 e
->blocking_sem
= sem_create(0, "udp endpoint sem");
242 udp_init_queue(&e
->q
);
244 mutex_lock(&endpoints_lock
);
245 hash_insert(endpoints
, e
);
246 mutex_unlock(&endpoints_lock
);
253 static int _udp_bind(udp_endpoint
*e
, int port
)
257 mutex_lock(&e
->lock
);
261 // make up a port number if one isn't passed in
263 port
= udp_allocate_ephemeral_port();
265 dprintf("_udp_bind: setting endprint %p to port %d\n", e
, port
);
267 if(port
!= e
->port
) {
269 // XXX search to make sure this port isn't used already
271 // remove it from the hashtable, stick it back with the new port
272 mutex_lock(&endpoints_lock
);
273 hash_remove(endpoints
, e
);
275 hash_insert(endpoints
, e
);
276 mutex_unlock(&endpoints_lock
);
280 err
= ERR_NET_SOCKET_ALREADY_BOUND
;
283 mutex_unlock(&e
->lock
);
288 int udp_bind(void *prot_data
, sockaddr
*addr
)
290 udp_endpoint
*e
= prot_data
;
292 // XXX does not support binding src ip address
293 return _udp_bind(e
, addr
->port
);
296 int udp_connect(void *prot_data
, sockaddr
*addr
)
298 return ERR_NOT_ALLOWED
;
301 int udp_listen(void *prot_data
)
303 return ERR_NOT_ALLOWED
;
306 int udp_accept(void *prot_data
, sockaddr
*addr
, void **new_socket
)
308 return ERR_NOT_ALLOWED
;
311 int udp_close(void *prot_data
)
313 udp_endpoint
*e
= prot_data
;
315 mutex_lock(&endpoints_lock
);
316 hash_remove(endpoints
, e
);
317 mutex_unlock(&endpoints_lock
);
319 udp_endpoint_release_ref(e
);
324 ssize_t
udp_recvfrom(void *prot_data
, void *buf
, ssize_t len
, sockaddr
*saddr
, int flags
, bigtime_t timeout
)
326 udp_endpoint
*e
= prot_data
;
332 if(flags
& SOCK_FLAG_TIMEOUT
)
333 err
= sem_acquire_etc(e
->blocking_sem
, 1, SEM_FLAG_TIMEOUT
, timeout
, NULL
);
335 err
= sem_acquire(e
->blocking_sem
, 1);
339 // pop an item off the list, if there are any
340 mutex_lock(&e
->lock
);
341 qe
= udp_queue_pop(&e
->q
);
342 mutex_unlock(&e
->lock
);
347 // we have the data, copy it out
348 err
= cbuf_user_memcpy_from_chain(buf
, qe
->buf
, 0, min(qe
->len
, len
));
355 // copy the address out
358 saddr
->addr
.type
= ADDR_TYPE_IP
;
359 NETADDR_TO_IPV4(saddr
->addr
) = qe
->src_address
;
360 saddr
->port
= qe
->src_port
;
364 // free this queue entry
365 cbuf_free_chain(qe
->buf
);
371 ssize_t
udp_sendto(void *prot_data
, const void *inbuf
, ssize_t len
, sockaddr
*toaddr
)
373 udp_endpoint
*e
= prot_data
;
377 udp_pseudo_header pheader
;
381 // make sure the args make sense
382 if(len
< 0 || len
+ sizeof(udp_header
) > 0xffff)
383 return ERR_INVALID_ARGS
;
384 if(toaddr
->port
< 0 || toaddr
->port
> 0xffff)
385 return ERR_INVALID_ARGS
;
387 // find us a local port if no one has already
391 // allocate a buffer to hold the data + header
392 total_len
= len
+ sizeof(udp_header
);
393 buf
= cbuf_get_chain(total_len
);
395 return ERR_NO_MEMORY
;
397 // copy the data to this new buffer
398 err
= cbuf_user_memcpy_to_chain(buf
, sizeof(udp_header
), inbuf
, len
);
400 cbuf_free_chain(buf
);
401 return ERR_VM_BAD_USER_MEMORY
;
404 // set up the udp pseudo header
405 if(ipv4_lookup_srcaddr_for_dest(NETADDR_TO_IPV4(toaddr
->addr
), &srcaddr
) < 0) {
406 cbuf_free_chain(buf
);
407 return ERR_NET_NO_ROUTE
;
409 pheader
.source_addr
= htonl(srcaddr
);
410 pheader
.dest_addr
= htonl(NETADDR_TO_IPV4(toaddr
->addr
));
412 pheader
.protocol
= IP_PROT_UDP
;
413 pheader
.udp_length
= htons(total_len
);
415 // start setting up the header
416 header
= cbuf_get_ptr(buf
, 0);
417 header
->source_port
= htons(e
->port
);
418 header
->dest_port
= htons(toaddr
->port
);
419 header
->length
= htons(total_len
);
420 header
->checksum
= 0;
421 header
->checksum
= cbuf_ones_cksum16_2(buf
, 0, total_len
, &pheader
, sizeof(pheader
));
422 if(header
->checksum
== 0)
423 header
->checksum
= 0xffff;
426 err
= ipv4_output(buf
, NETADDR_TO_IPV4(toaddr
->addr
), IP_PROT_UDP
);
428 // if it returns ARP_QUEUED, then it's actually okay
429 if(err
== ERR_NET_ARP_QUEUED
) {
438 mutex_init(&endpoints_lock
, "udp_endpoints lock");
440 next_ephemeral_port
= rand() % 32000 + 1024;
442 endpoints
= hash_init(256, offsetof(udp_endpoint
, next
),
443 &udp_endpoint_compare_func
, &udp_endpoint_hash_func
);
445 return ERR_NO_MEMORY
;