vfs: check userland buffers before reading them.
[haiku.git] / src / add-ons / kernel / network / stack / routes.cpp
bloba43cbef18d7beb73fea8588d19483241cc1c4beb
1 /*
2 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
5 * Authors:
6 * Axel Dörfler, axeld@pinc-software.de
7 */
10 #include "domains.h"
11 #include "interfaces.h"
12 #include "routes.h"
13 #include "stack_private.h"
14 #include "utility.h"
16 #include <net_device.h>
17 #include <NetUtilities.h>
19 #include <lock.h>
20 #include <util/AutoLock.h>
22 #include <KernelExport.h>
24 #include <net/if_dl.h>
25 #include <net/route.h>
26 #include <new>
27 #include <stdlib.h>
28 #include <string.h>
29 #include <sys/sockio.h>
// Enable to get verbose routing table tracing via dprintf().
//#define TRACE_ROUTES

#if defined(TRACE_ROUTES)
#	define TRACE(x...) dprintf(STACK_DEBUG_PREFIX x)
#else
#	define TRACE(x...) ;
#endif
40 net_route_private::net_route_private()
42 destination = mask = gateway = NULL;
46 net_route_private::~net_route_private()
48 free(destination);
49 free(mask);
50 free(gateway);
54 // #pragma mark - private functions
57 static status_t
58 user_copy_address(const sockaddr* from, sockaddr** to)
60 if (from == NULL) {
61 *to = NULL;
62 return B_OK;
65 sockaddr address;
66 if (user_memcpy(&address, from, sizeof(struct sockaddr)) < B_OK)
67 return B_BAD_ADDRESS;
69 *to = (sockaddr*)malloc(address.sa_len);
70 if (*to == NULL)
71 return B_NO_MEMORY;
73 if (address.sa_len > sizeof(struct sockaddr)) {
74 if (user_memcpy(*to, from, address.sa_len) < B_OK)
75 return B_BAD_ADDRESS;
76 } else
77 memcpy(*to, &address, address.sa_len);
79 return B_OK;
83 static status_t
84 user_copy_address(const sockaddr* from, sockaddr_storage* to)
86 if (from == NULL)
87 return B_BAD_ADDRESS;
89 if (user_memcpy(to, from, sizeof(sockaddr)) < B_OK)
90 return B_BAD_ADDRESS;
92 if (to->ss_len > sizeof(sockaddr)) {
93 if (to->ss_len > sizeof(sockaddr_storage))
94 return B_BAD_VALUE;
95 if (user_memcpy(to, from, to->ss_len) < B_OK)
96 return B_BAD_ADDRESS;
99 return B_OK;
103 static net_route_private*
104 find_route(struct net_domain* _domain, const net_route* description)
106 struct net_domain_private* domain = (net_domain_private*)_domain;
107 RouteList::Iterator iterator = domain->routes.GetIterator();
109 while (iterator.HasNext()) {
110 net_route_private* route = iterator.Next();
112 if ((route->flags & RTF_DEFAULT) != 0
113 && (description->flags & RTF_DEFAULT) != 0) {
114 // there can only be one default route per interface address family
115 // TODO: check this better
116 if (route->interface_address == description->interface_address)
117 return route;
119 continue;
122 if ((route->flags & (RTF_GATEWAY | RTF_HOST | RTF_LOCAL | RTF_DEFAULT))
123 == (description->flags
124 & (RTF_GATEWAY | RTF_HOST | RTF_LOCAL | RTF_DEFAULT))
125 && domain->address_module->equal_masked_addresses(
126 route->destination, description->destination, description->mask)
127 && domain->address_module->equal_addresses(route->mask,
128 description->mask)
129 && domain->address_module->equal_addresses(route->gateway,
130 description->gateway)
131 && (description->interface_address == NULL
132 || description->interface_address == route->interface_address))
133 return route;
136 return NULL;
140 static net_route_private*
141 find_route(net_domain* _domain, const sockaddr* address)
143 net_domain_private* domain = (net_domain_private*)_domain;
145 // find last matching route
147 RouteList::Iterator iterator = domain->routes.GetIterator();
148 net_route_private* candidate = NULL;
150 TRACE("test address %s for routes...\n",
151 AddressString(domain, address).Data());
153 // TODO: alternate equal default routes
155 while (iterator.HasNext()) {
156 net_route_private* route = iterator.Next();
158 if (route->mask) {
159 sockaddr maskedAddress;
160 domain->address_module->mask_address(address, route->mask,
161 &maskedAddress);
162 if (!domain->address_module->equal_addresses(&maskedAddress,
163 route->destination))
164 continue;
165 } else if (!domain->address_module->equal_addresses(address,
166 route->destination))
167 continue;
169 // neglect routes that point to devices that have no link
170 if ((route->interface_address->interface->device->flags & IFF_LINK)
171 == 0) {
172 if (candidate == NULL) {
173 TRACE(" found candidate: %s, flags %lx\n", AddressString(
174 domain, route->destination).Data(), route->flags);
175 candidate = route;
177 continue;
180 TRACE(" found route: %s, flags %lx\n",
181 AddressString(domain, route->destination).Data(), route->flags);
183 return route;
186 return candidate;
190 static void
191 put_route_internal(struct net_domain_private* domain, net_route* _route)
193 ASSERT_LOCKED_RECURSIVE(&domain->lock);
195 net_route_private* route = (net_route_private*)_route;
196 if (route == NULL || atomic_add(&route->ref_count, -1) != 1)
197 return;
199 // delete route - it must already have been removed at this point
200 if (route->interface_address != NULL)
201 ((InterfaceAddress*)route->interface_address)->ReleaseReference();
203 delete route;
207 static struct net_route*
208 get_route_internal(struct net_domain_private* domain,
209 const struct sockaddr* address)
211 ASSERT_LOCKED_RECURSIVE(&domain->lock);
212 net_route_private* route = NULL;
214 if (address->sa_family == AF_LINK) {
215 // special address to find an interface directly
216 RouteList::Iterator iterator = domain->routes.GetIterator();
217 const sockaddr_dl* link = (const sockaddr_dl*)address;
219 while (iterator.HasNext()) {
220 route = iterator.Next();
222 net_device* device = route->interface_address->interface->device;
224 if ((link->sdl_nlen > 0
225 && !strncmp(device->name, (const char*)link->sdl_data,
226 IF_NAMESIZE))
227 || (link->sdl_nlen == 0 && link->sdl_alen > 0
228 && !memcmp(LLADDR(link), device->address.data,
229 device->address.length)))
230 break;
232 } else
233 route = find_route(domain, address);
235 if (route != NULL && atomic_add(&route->ref_count, 1) == 0) {
236 // route has been deleted already
237 route = NULL;
240 return route;
244 static void
245 update_route_infos(struct net_domain_private* domain)
247 ASSERT_LOCKED_RECURSIVE(&domain->lock);
248 RouteInfoList::Iterator iterator = domain->route_infos.GetIterator();
250 while (iterator.HasNext()) {
251 net_route_info* info = iterator.Next();
253 put_route_internal(domain, info->route);
254 info->route = get_route_internal(domain, &info->address);
259 static sockaddr*
260 copy_address(UserBuffer& buffer, sockaddr* address)
262 if (address == NULL)
263 return NULL;
265 return (sockaddr*)buffer.Push(address, address->sa_len);
269 static status_t
270 fill_route_entry(route_entry* target, void* _buffer, size_t bufferSize,
271 net_route* route)
273 UserBuffer buffer(((uint8*)_buffer) + sizeof(route_entry),
274 bufferSize - sizeof(route_entry));
276 target->destination = copy_address(buffer, route->destination);
277 target->mask = copy_address(buffer, route->mask);
278 target->gateway = copy_address(buffer, route->gateway);
279 target->source = copy_address(buffer, route->interface_address->local);
280 target->flags = route->flags;
281 target->mtu = route->mtu;
283 return buffer.Status();
287 // #pragma mark - exported functions
290 /*! Determines the size of a buffer large enough to contain the whole
291 routing table.
293 uint32
294 route_table_size(net_domain_private* domain)
296 RecursiveLocker locker(domain->lock);
297 uint32 size = 0;
299 RouteList::Iterator iterator = domain->routes.GetIterator();
300 while (iterator.HasNext()) {
301 net_route_private* route = iterator.Next();
302 size += IF_NAMESIZE + sizeof(route_entry);
304 if (route->destination)
305 size += route->destination->sa_len;
306 if (route->mask)
307 size += route->mask->sa_len;
308 if (route->gateway)
309 size += route->gateway->sa_len;
312 return size;
316 /*! Dumps a list of all routes into the supplied userland buffer.
317 If the routes don't fit into the buffer, an error (\c ENOBUFS) is
318 returned.
320 status_t
321 list_routes(net_domain_private* domain, void* buffer, size_t size)
323 RecursiveLocker _(domain->lock);
325 RouteList::Iterator iterator = domain->routes.GetIterator();
326 const size_t kBaseSize = IF_NAMESIZE + sizeof(route_entry);
327 size_t spaceLeft = size;
329 sockaddr zeros;
330 memset(&zeros, 0, sizeof(sockaddr));
331 zeros.sa_family = domain->family;
332 zeros.sa_len = sizeof(sockaddr);
334 while (iterator.HasNext()) {
335 net_route* route = iterator.Next();
337 size = kBaseSize;
339 sockaddr* destination = NULL;
340 sockaddr* mask = NULL;
341 sockaddr* gateway = NULL;
342 uint8* next = (uint8*)buffer + size;
344 if (route->destination != NULL) {
345 destination = (sockaddr*)next;
346 next += route->destination->sa_len;
347 size += route->destination->sa_len;
349 if (route->mask != NULL) {
350 mask = (sockaddr*)next;
351 next += route->mask->sa_len;
352 size += route->mask->sa_len;
354 if (route->gateway != NULL) {
355 gateway = (sockaddr*)next;
356 next += route->gateway->sa_len;
357 size += route->gateway->sa_len;
360 if (spaceLeft < size)
361 return ENOBUFS;
363 ifreq request;
364 memset(&request, 0, sizeof(request));
366 strlcpy(request.ifr_name, route->interface_address->interface->name,
367 IF_NAMESIZE);
368 request.ifr_route.destination = destination;
369 request.ifr_route.mask = mask;
370 request.ifr_route.gateway = gateway;
371 request.ifr_route.mtu = route->mtu;
372 request.ifr_route.flags = route->flags;
374 // copy data into userland buffer
375 if (user_memcpy(buffer, &request, kBaseSize) < B_OK
376 || (route->destination != NULL
377 && user_memcpy(request.ifr_route.destination,
378 route->destination, route->destination->sa_len) < B_OK)
379 || (route->mask != NULL && user_memcpy(request.ifr_route.mask,
380 route->mask, route->mask->sa_len) < B_OK)
381 || (route->gateway != NULL && user_memcpy(request.ifr_route.gateway,
382 route->gateway, route->gateway->sa_len) < B_OK))
383 return B_BAD_ADDRESS;
385 buffer = (void*)next;
386 spaceLeft -= size;
389 return B_OK;
393 status_t
394 control_routes(struct net_interface* _interface, net_domain* domain,
395 int32 option, void* argument, size_t length)
397 TRACE("control_routes(interface %p, domain %p, option %" B_PRId32 ")\n",
398 _interface, domain, option);
399 Interface* interface = (Interface*)_interface;
401 switch (option) {
402 case SIOCADDRT:
403 case SIOCDELRT:
405 // add or remove a route
406 if (length != sizeof(struct ifreq))
407 return B_BAD_VALUE;
409 route_entry entry;
410 if (user_memcpy(&entry, &((ifreq*)argument)->ifr_route,
411 sizeof(route_entry)) != B_OK)
412 return B_BAD_ADDRESS;
414 net_route_private route;
415 status_t status;
416 if ((status = user_copy_address(entry.destination,
417 &route.destination)) != B_OK
418 || (status = user_copy_address(entry.mask, &route.mask)) != B_OK
419 || (status = user_copy_address(entry.gateway, &route.gateway))
420 != B_OK)
421 return status;
423 InterfaceAddress* address
424 = interface->FirstForFamily(domain->family);
426 route.mtu = entry.mtu;
427 route.flags = entry.flags;
428 route.interface_address = address;
430 if (option == SIOCADDRT)
431 status = add_route(domain, &route);
432 else
433 status = remove_route(domain, &route);
435 if (address != NULL)
436 address->ReleaseReference();
437 return status;
440 return B_BAD_VALUE;
444 status_t
445 add_route(struct net_domain* _domain, const struct net_route* newRoute)
447 struct net_domain_private* domain = (net_domain_private*)_domain;
449 TRACE("add route to domain %s: dest %s, mask %s, gw %s, flags %lx\n",
450 domain->name,
451 AddressString(domain, newRoute->destination
452 ? newRoute->destination : NULL).Data(),
453 AddressString(domain, newRoute->mask ? newRoute->mask : NULL).Data(),
454 AddressString(domain, newRoute->gateway
455 ? newRoute->gateway : NULL).Data(),
456 newRoute->flags);
458 if (domain == NULL || newRoute == NULL
459 || newRoute->interface_address == NULL
460 || ((newRoute->flags & RTF_HOST) != 0 && newRoute->mask != NULL)
461 || ((newRoute->flags & RTF_DEFAULT) == 0
462 && newRoute->destination == NULL)
463 || ((newRoute->flags & RTF_GATEWAY) != 0 && newRoute->gateway == NULL)
464 || !domain->address_module->check_mask(newRoute->mask))
465 return B_BAD_VALUE;
467 RecursiveLocker _(domain->lock);
469 net_route_private* route = find_route(domain, newRoute);
470 if (route != NULL)
471 return B_FILE_EXISTS;
473 route = new (std::nothrow) net_route_private;
474 if (route == NULL)
475 return B_NO_MEMORY;
477 if (domain->address_module->copy_address(newRoute->destination,
478 &route->destination, (newRoute->flags & RTF_DEFAULT) != 0,
479 newRoute->mask) != B_OK
480 || domain->address_module->copy_address(newRoute->mask, &route->mask,
481 (newRoute->flags & RTF_DEFAULT) != 0, NULL) != B_OK
482 || domain->address_module->copy_address(newRoute->gateway,
483 &route->gateway, false, NULL) != B_OK) {
484 delete route;
485 return B_NO_MEMORY;
488 route->flags = newRoute->flags;
489 route->interface_address = newRoute->interface_address;
490 ((InterfaceAddress*)route->interface_address)->AcquireReference();
491 route->mtu = 0;
492 route->ref_count = 1;
494 // Insert the route sorted by completeness of its mask
496 RouteList::Iterator iterator = domain->routes.GetIterator();
497 net_route_private* before = NULL;
499 while ((before = iterator.Next()) != NULL) {
500 // if the before mask is less specific than the one of the route,
501 // we can insert it before that route.
502 if (domain->address_module->first_mask_bit(before->mask)
503 > domain->address_module->first_mask_bit(route->mask))
504 break;
506 if ((route->flags & RTF_DEFAULT) != 0
507 && (before->flags & RTF_DEFAULT) != 0) {
508 // both routes are equal - let the link speed decide the
509 // order
510 if (before->interface_address->interface->device->link_speed
511 < route->interface_address->interface->device->link_speed)
512 break;
516 domain->routes.Insert(before, route);
517 update_route_infos(domain);
519 return B_OK;
523 status_t
524 remove_route(struct net_domain* _domain, const struct net_route* removeRoute)
526 struct net_domain_private* domain = (net_domain_private*)_domain;
528 TRACE("remove route from domain %s: dest %s, mask %s, gw %s, flags %lx\n",
529 domain->name,
530 AddressString(domain, removeRoute->destination
531 ? removeRoute->destination : NULL).Data(),
532 AddressString(domain, removeRoute->mask
533 ? removeRoute->mask : NULL).Data(),
534 AddressString(domain, removeRoute->gateway
535 ? removeRoute->gateway : NULL).Data(),
536 removeRoute->flags);
538 RecursiveLocker locker(domain->lock);
540 net_route_private* route = find_route(domain, removeRoute);
541 if (route == NULL)
542 return B_ENTRY_NOT_FOUND;
544 domain->routes.Remove(route);
546 put_route_internal(domain, route);
547 update_route_infos(domain);
549 return B_OK;
553 status_t
554 get_route_information(struct net_domain* _domain, void* value, size_t length)
556 struct net_domain_private* domain = (net_domain_private*)_domain;
558 if (length < sizeof(route_entry))
559 return B_BAD_VALUE;
561 route_entry entry;
562 if (user_memcpy(&entry, value, sizeof(route_entry)) < B_OK)
563 return B_BAD_ADDRESS;
565 sockaddr_storage destination;
566 status_t status = user_copy_address(entry.destination, &destination);
567 if (status != B_OK)
568 return status;
570 RecursiveLocker locker(domain->lock);
572 net_route_private* route = find_route(domain, (sockaddr*)&destination);
573 if (route == NULL)
574 return B_ENTRY_NOT_FOUND;
576 status = fill_route_entry(&entry, value, length, route);
577 if (status != B_OK)
578 return status;
580 return user_memcpy(value, &entry, sizeof(route_entry));
584 void
585 invalidate_routes(net_domain* _domain, net_interface* interface)
587 net_domain_private* domain = (net_domain_private*)_domain;
588 RecursiveLocker locker(domain->lock);
590 TRACE("invalidate_routes(%i, %s)\n", domain->family, interface->name);
592 RouteList::Iterator iterator = domain->routes.GetIterator();
593 while (iterator.HasNext()) {
594 net_route* route = iterator.Next();
596 if (route->interface_address->interface == interface)
597 remove_route(domain, route);
602 void
603 invalidate_routes(InterfaceAddress* address)
605 net_domain_private* domain = (net_domain_private*)address->domain;
607 TRACE("invalidate_routes(%s)\n",
608 AddressString(domain, address->local).Data());
610 RecursiveLocker locker(domain->lock);
612 RouteList::Iterator iterator = domain->routes.GetIterator();
613 while (iterator.HasNext()) {
614 net_route* route = iterator.Next();
616 if (route->interface_address == address)
617 remove_route(domain, route);
622 struct net_route*
623 get_route(struct net_domain* _domain, const struct sockaddr* address)
625 struct net_domain_private* domain = (net_domain_private*)_domain;
626 RecursiveLocker locker(domain->lock);
628 return get_route_internal(domain, address);
632 status_t
633 get_device_route(struct net_domain* domain, uint32 index, net_route** _route)
635 Interface* interface = get_interface_for_device(domain, index);
636 if (interface == NULL)
637 return ENETUNREACH;
639 net_route_private* route
640 = &interface->DomainDatalink(domain->family)->direct_route;
642 atomic_add(&route->ref_count, 1);
643 *_route = route;
645 interface->ReleaseReference();
646 return B_OK;
650 status_t
651 get_buffer_route(net_domain* _domain, net_buffer* buffer, net_route** _route)
653 net_domain_private* domain = (net_domain_private*)_domain;
655 RecursiveLocker _(domain->lock);
657 net_route* route = get_route_internal(domain, buffer->destination);
658 if (route == NULL)
659 return ENETUNREACH;
661 status_t status = B_OK;
662 sockaddr* source = buffer->source;
664 // TODO: we are quite relaxed in the address checking here
665 // as we might proceed with source = INADDR_ANY.
667 if (route->interface_address != NULL
668 && route->interface_address->local != NULL) {
669 status = domain->address_module->update_to(source,
670 route->interface_address->local);
673 if (status != B_OK)
674 put_route_internal(domain, route);
675 else
676 *_route = route;
678 return status;
682 void
683 put_route(struct net_domain* _domain, net_route* route)
685 struct net_domain_private* domain = (net_domain_private*)_domain;
686 if (domain == NULL || route == NULL)
687 return;
689 RecursiveLocker locker(domain->lock);
691 put_route_internal(domain, (net_route*)route);
695 status_t
696 register_route_info(struct net_domain* _domain, struct net_route_info* info)
698 struct net_domain_private* domain = (net_domain_private*)_domain;
699 RecursiveLocker locker(domain->lock);
701 domain->route_infos.Add(info);
702 info->route = get_route_internal(domain, &info->address);
704 return B_OK;
708 status_t
709 unregister_route_info(struct net_domain* _domain, struct net_route_info* info)
711 struct net_domain_private* domain = (net_domain_private*)_domain;
712 RecursiveLocker locker(domain->lock);
714 domain->route_infos.Remove(info);
715 if (info->route != NULL)
716 put_route_internal(domain, info->route);
718 return B_OK;
722 status_t
723 update_route_info(struct net_domain* _domain, struct net_route_info* info)
725 struct net_domain_private* domain = (net_domain_private*)_domain;
726 RecursiveLocker locker(domain->lock);
728 put_route_internal(domain, info->route);
729 info->route = get_route_internal(domain, &info->address);
730 return B_OK;