2 * Copyright 2006-2009, Haiku, Inc. All Rights Reserved.
3 * Distributed under the terms of the MIT License.
6 * Axel Dörfler, axeld@pinc-software.de
7 * Andrew Galante, haiku.galante@gmail.com
8 * Hugo Santos, hugosantos@gmail.com
12 #include "EndpointManager.h"
13 #include "TCPEndpoint.h"
15 #include <net_protocol.h>
18 #include <KernelExport.h>
19 #include <util/list.h>
21 #include <netinet/in.h>
22 #include <netinet/ip.h>
28 #include <util/AutoLock.h>
30 #include <NetBufferUtilities.h>
31 #include <NetUtilities.h>
36 # define TRACE(x) dprintf x
37 # define TRACE_BLOCK(x) dump_block x
40 # define TRACE_BLOCK(x)
44 typedef NetBufferField
<uint16
, offsetof(tcp_header
, checksum
)> TCPChecksumField
;
47 net_buffer_module_info
*gBufferModule
;
48 net_datalink_module_info
*gDatalinkModule
;
49 net_socket_module_info
*gSocketModule
;
50 net_stack_module_info
*gStackModule
;
53 static EndpointManager
* sEndpointManagers
[AF_MAX
];
54 static rw_lock sEndpointManagersLock
;
57 // The TCP header length is at most 64 bytes.
58 static const int kMaxOptionSize
= 64 - sizeof(tcp_header
);
61 /*! Returns an endpoint manager for the specified domain, if any.
62 You need to hold the sEndpointManagersLock when calling this function.
64 static inline EndpointManager
*
65 endpoint_manager_for_locked(int family
)
67 if (family
>= AF_MAX
|| family
< 0)
70 return sEndpointManagers
[family
];
74 /*! Returns an endpoint manager for the specified domain, if any */
75 static inline EndpointManager
*
76 endpoint_manager_for(net_domain
* domain
)
78 ReadLocker
_(sEndpointManagersLock
);
80 return endpoint_manager_for_locked(domain
->family
);
85 bump_option(tcp_option
*&option
, size_t &length
)
87 if (option
->kind
<= TCP_OPTION_NOP
) {
89 option
= (tcp_option
*)((uint8
*)option
+ 1);
91 length
+= option
->length
;
92 option
= (tcp_option
*)((uint8
*)option
+ option
->length
);
98 add_options(tcp_segment_header
&segment
, uint8
*buffer
, size_t bufferSize
)
100 tcp_option
*option
= (tcp_option
*)buffer
;
103 if (segment
.max_segment_size
> 0 && length
+ 8 <= bufferSize
) {
104 option
->kind
= TCP_OPTION_MAX_SEGMENT_SIZE
;
106 option
->max_segment_size
= htons(segment
.max_segment_size
);
107 bump_option(option
, length
);
110 if ((segment
.options
& TCP_HAS_TIMESTAMPS
) != 0
111 && length
+ 12 <= bufferSize
) {
112 // two NOPs so the timestamps get aligned to a 4 byte boundary
113 option
->kind
= TCP_OPTION_NOP
;
114 bump_option(option
, length
);
115 option
->kind
= TCP_OPTION_NOP
;
116 bump_option(option
, length
);
117 option
->kind
= TCP_OPTION_TIMESTAMP
;
119 option
->timestamp
.value
= htonl(segment
.timestamp_value
);
120 option
->timestamp
.reply
= htonl(segment
.timestamp_reply
);
121 bump_option(option
, length
);
124 if ((segment
.options
& TCP_HAS_WINDOW_SCALE
) != 0
125 && length
+ 4 <= bufferSize
) {
126 // insert one NOP so that the subsequent data is aligned on a 4 byte boundary
127 option
->kind
= TCP_OPTION_NOP
;
128 bump_option(option
, length
);
130 option
->kind
= TCP_OPTION_WINDOW_SHIFT
;
132 option
->window_shift
= segment
.window_shift
;
133 bump_option(option
, length
);
136 if ((segment
.options
& TCP_SACK_PERMITTED
) != 0
137 && length
+ 2 <= bufferSize
) {
138 option
->kind
= TCP_OPTION_SACK_PERMITTED
;
140 bump_option(option
, length
);
143 if (segment
.sack_count
> 0) {
144 int sackCount
= ((int)(bufferSize
- length
) - 4) / sizeof(tcp_sack
);
145 if (sackCount
> segment
.sack_count
)
146 sackCount
= segment
.sack_count
;
149 option
->kind
= TCP_OPTION_NOP
;
150 bump_option(option
, length
);
151 option
->kind
= TCP_OPTION_NOP
;
152 bump_option(option
, length
);
153 option
->kind
= TCP_OPTION_SACK
;
154 option
->length
= 2 + sackCount
* sizeof(tcp_sack
);
155 memcpy(option
->sack
, segment
.sacks
, sackCount
* sizeof(tcp_sack
));
156 bump_option(option
, length
);
160 if ((length
& 3) == 0) {
161 // options completely fill out the option space
165 option
->kind
= TCP_OPTION_END
;
166 return (length
+ 3) & ~3;
167 // bump to a multiple of 4 length
172 process_options(tcp_segment_header
&segment
, net_buffer
*buffer
, size_t size
)
179 uint8 optionsBuffer
[kMaxOptionSize
];
180 if (gBufferModule
->direct_access(buffer
, sizeof(tcp_header
), size
,
181 (void **)&option
) != B_OK
) {
182 if ((size_t)size
> sizeof(optionsBuffer
)) {
183 dprintf("Ignoring TCP options larger than expected.\n");
187 gBufferModule
->read(buffer
, sizeof(tcp_header
), optionsBuffer
, size
);
188 option
= (tcp_option
*)optionsBuffer
;
194 switch (option
->kind
) {
199 case TCP_OPTION_MAX_SEGMENT_SIZE
:
200 if (option
->length
== 4 && size
>= 4)
201 segment
.max_segment_size
= ntohs(option
->max_segment_size
);
203 case TCP_OPTION_WINDOW_SHIFT
:
204 if (option
->length
== 3 && size
>= 3) {
205 segment
.options
|= TCP_HAS_WINDOW_SCALE
;
206 segment
.window_shift
= option
->window_shift
;
209 case TCP_OPTION_TIMESTAMP
:
210 if (option
->length
== 10 && size
>= 10) {
211 segment
.options
|= TCP_HAS_TIMESTAMPS
;
212 segment
.timestamp_value
= ntohl(option
->timestamp
.value
);
213 segment
.timestamp_reply
=
214 ntohl(option
->timestamp
.reply
);
217 case TCP_OPTION_SACK_PERMITTED
:
218 if (option
->length
== 2 && size
>= 2)
219 segment
.options
|= TCP_SACK_PERMITTED
;
224 length
= option
->length
;
225 if (length
== 0 || length
> (ssize_t
)size
)
229 option
= (tcp_option
*)((uint8
*)option
+ length
);
237 dump_tcp_header(tcp_header
&header
)
239 dprintf(" source port: %u\n", ntohs(header
.source_port
));
240 dprintf(" dest port: %u\n", ntohs(header
.destination_port
));
241 dprintf(" sequence: %lu\n", header
.Sequence());
242 dprintf(" ack: %lu\n", header
.Acknowledge());
243 dprintf(" flags: %s%s%s%s%s%s\n", (header
.flags
& TCP_FLAG_FINISH
) ? "FIN " : "",
244 (header
.flags
& TCP_FLAG_SYNCHRONIZE
) ? "SYN " : "",
245 (header
.flags
& TCP_FLAG_RESET
) ? "RST " : "",
246 (header
.flags
& TCP_FLAG_PUSH
) ? "PUSH " : "",
247 (header
.flags
& TCP_FLAG_ACKNOWLEDGE
) ? "ACK " : "",
248 (header
.flags
& TCP_FLAG_URGENT
) ? "URG " : "");
249 dprintf(" window: %u\n", header
.AdvertisedWindow());
250 dprintf(" urgent offset: %u\n", header
.UrgentOffset());
256 dump_endpoints(int argc
, char** argv
)
258 for (int i
= 0; i
< AF_MAX
; i
++) {
259 EndpointManager
* manager
= sEndpointManagers
[i
];
269 dump_endpoint(int argc
, char** argv
)
272 kprintf("usage: tcp_endpoint [address]\n");
276 TCPEndpoint
* endpoint
= (TCPEndpoint
*)parse_expression(argv
[1]);
283 // #pragma mark - internal API
286 /*! Creates a new endpoint manager for the specified domain, or returns
287 an existing one for this domain.
290 get_endpoint_manager(net_domain
* domain
)
292 // See if there is one already
293 EndpointManager
* endpointManager
= endpoint_manager_for(domain
);
294 if (endpointManager
!= NULL
)
295 return endpointManager
;
297 WriteLocker
_(sEndpointManagersLock
);
299 endpointManager
= endpoint_manager_for_locked(domain
->family
);
300 if (endpointManager
!= NULL
)
301 return endpointManager
;
303 // There is no endpoint manager for this domain yet, so we need
306 endpointManager
= new(std::nothrow
) EndpointManager(domain
);
307 if (endpointManager
== NULL
)
310 if (endpointManager
->Init() != B_OK
) {
311 delete endpointManager
;
315 sEndpointManagers
[domain
->family
] = endpointManager
;
316 return endpointManager
;
321 put_endpoint_manager(EndpointManager
* endpointManager
)
323 // TODO: we may want to use reference counting instead of only discarding
324 // them on unload. But since there is likely only IPv4/v6 there is not much
330 name_for_state(tcp_state state
)
337 case SYNCHRONIZE_SENT
:
339 case SYNCHRONIZE_RECEIVED
:
340 return "syn-received";
342 return "established";
344 // peer closes the connection
345 case FINISH_RECEIVED
:
347 case WAIT_FOR_FINISH_ACKNOWLEDGE
:
350 // we close the connection
353 case FINISH_ACKNOWLEDGED
:
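366 /*! Constructs a TCP header on \a buffer from the values stored in \a segment
367 (flags, sequence, acknowledge, advertised window, urgent offset and options); the ports are taken from the buffer's source and destination addresses via \a addressModule.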
370 add_tcp_header(net_address_module_info
* addressModule
,
371 tcp_segment_header
& segment
, net_buffer
* buffer
)
373 buffer
->protocol
= IPPROTO_TCP
;
375 uint8 optionsBuffer
[kMaxOptionSize
];
376 uint32 optionsLength
= add_options(segment
, optionsBuffer
,
377 sizeof(optionsBuffer
));
379 NetBufferPrepend
<tcp_header
> bufferHeader(buffer
,
380 sizeof(tcp_header
) + optionsLength
);
381 if (bufferHeader
.Status() != B_OK
)
382 return bufferHeader
.Status();
384 tcp_header
& header
= bufferHeader
.Data();
386 header
.source_port
= addressModule
->get_port(buffer
->source
);
387 header
.destination_port
= addressModule
->get_port(buffer
->destination
);
388 header
.sequence
= htonl(segment
.sequence
);
389 header
.acknowledge
= (segment
.flags
& TCP_FLAG_ACKNOWLEDGE
)
390 ? htonl(segment
.acknowledge
) : 0;
392 header
.header_length
= (sizeof(tcp_header
) + optionsLength
) >> 2;
393 header
.flags
= segment
.flags
;
394 header
.advertised_window
= htons(segment
.advertised_window
);
396 header
.urgent_offset
= htons(segment
.urgent_offset
);
398 // we must detach before calculating the checksum as we may
399 // not have a contiguous buffer.
402 if (optionsLength
> 0) {
403 gBufferModule
->write(buffer
, sizeof(tcp_header
), optionsBuffer
,
407 TRACE(("add_tcp_header(): buffer %p, flags 0x%x, seq %lu, ack %lu, up %u, "
408 "win %u\n", buffer
, segment
.flags
, segment
.sequence
,
409 segment
.acknowledge
, segment
.urgent_offset
, segment
.advertised_window
));
411 *TCPChecksumField(buffer
) = Checksum::PseudoHeader(addressModule
,
412 gBufferModule
, buffer
, IPPROTO_TCP
);
419 tcp_options_length(tcp_segment_header
& segment
)
423 if (segment
.max_segment_size
> 0)
426 if (segment
.options
& TCP_HAS_TIMESTAMPS
)
429 if (segment
.options
& TCP_HAS_WINDOW_SCALE
)
432 if (segment
.options
& TCP_SACK_PERMITTED
)
435 if (segment
.sack_count
> 0) {
436 int sackCount
= min_c((int)((kMaxOptionSize
- length
- 4)
437 / sizeof(tcp_sack
)), segment
.sack_count
);
439 length
+= 4 + sackCount
* sizeof(tcp_sack
);
442 if ((length
& 3) == 0)
445 return (length
+ 3) & ~3;
449 // #pragma mark - protocol API
453 tcp_init_protocol(net_socket
* socket
)
455 socket
->send
.buffer_size
= 32768;
456 // override net_socket default
458 TCPEndpoint
* protocol
= new (std::nothrow
) TCPEndpoint(socket
);
459 if (protocol
== NULL
)
462 if (protocol
->InitCheck() != B_OK
) {
467 TRACE(("Creating new TCPEndpoint: %p\n", protocol
));
468 socket
->protocol
= IPPROTO_TCP
;
474 tcp_uninit_protocol(net_protocol
* protocol
)
476 TRACE(("Deleting TCPEndpoint: %p\n", protocol
));
477 delete (TCPEndpoint
*)protocol
;
483 tcp_open(net_protocol
* protocol
)
485 return ((TCPEndpoint
*)protocol
)->Open();
490 tcp_close(net_protocol
* protocol
)
492 return ((TCPEndpoint
*)protocol
)->Close();
497 tcp_free(net_protocol
* protocol
)
499 ((TCPEndpoint
*)protocol
)->Free();
505 tcp_connect(net_protocol
* protocol
, const struct sockaddr
* address
)
507 return ((TCPEndpoint
*)protocol
)->Connect(address
);
512 tcp_accept(net_protocol
* protocol
, struct net_socket
** _acceptedSocket
)
514 return ((TCPEndpoint
*)protocol
)->Accept(_acceptedSocket
);
519 tcp_control(net_protocol
* _protocol
, int level
, int option
, void* value
,
522 TCPEndpoint
* protocol
= (TCPEndpoint
*)_protocol
;
524 if ((level
& LEVEL_MASK
) == IPPROTO_TCP
) {
525 if (option
== NET_STAT_SOCKET
)
526 return protocol
->FillStat((net_stat
*)value
);
529 return protocol
->next
->module
->control(protocol
->next
, level
, option
,
535 tcp_getsockopt(net_protocol
* _protocol
, int level
, int option
, void* value
,
538 TCPEndpoint
* protocol
= (TCPEndpoint
*)_protocol
;
540 if (level
== IPPROTO_TCP
)
541 return protocol
->GetOption(option
, value
, _length
);
543 return protocol
->next
->module
->getsockopt(protocol
->next
, level
, option
,
549 tcp_setsockopt(net_protocol
* _protocol
, int level
, int option
,
550 const void* _value
, int length
)
552 TCPEndpoint
* protocol
= (TCPEndpoint
*)_protocol
;
554 if (level
== SOL_SOCKET
) {
555 if (option
== SO_SNDBUF
|| option
== SO_RCVBUF
) {
556 if (length
!= sizeof(int))
560 const int* value
= (const int*)_value
;
562 if (option
== SO_SNDBUF
)
563 status
= protocol
->SetSendBufferSize(*value
);
565 status
= protocol
->SetReceiveBufferSize(*value
);
570 } else if (level
== IPPROTO_TCP
)
571 return protocol
->SetOption(option
, _value
, length
);
573 return protocol
->next
->module
->setsockopt(protocol
->next
, level
, option
,
579 tcp_bind(net_protocol
* protocol
, const struct sockaddr
* address
)
581 return ((TCPEndpoint
*)protocol
)->Bind(address
);
586 tcp_unbind(net_protocol
* protocol
, struct sockaddr
* address
)
588 return ((TCPEndpoint
*)protocol
)->Unbind(address
);
593 tcp_listen(net_protocol
* protocol
, int count
)
595 return ((TCPEndpoint
*)protocol
)->Listen(count
);
600 tcp_shutdown(net_protocol
* protocol
, int direction
)
602 return ((TCPEndpoint
*)protocol
)->Shutdown(direction
);
607 tcp_send_data(net_protocol
* protocol
, net_buffer
* buffer
)
609 return ((TCPEndpoint
*)protocol
)->SendData(buffer
);
614 tcp_send_routed_data(net_protocol
* protocol
, struct net_route
* route
,
617 // TCP never sends routed data
623 tcp_send_avail(net_protocol
* protocol
)
625 return ((TCPEndpoint
*)protocol
)->SendAvailable();
630 tcp_read_data(net_protocol
* protocol
, size_t numBytes
, uint32 flags
,
631 net_buffer
** _buffer
)
633 return ((TCPEndpoint
*)protocol
)->ReadData(numBytes
, flags
, _buffer
);
638 tcp_read_avail(net_protocol
* protocol
)
640 return ((TCPEndpoint
*)protocol
)->ReadAvailable();
645 tcp_get_domain(net_protocol
* protocol
)
647 return protocol
->next
->module
->get_domain(protocol
->next
);
652 tcp_get_mtu(net_protocol
* protocol
, const struct sockaddr
* address
)
654 return protocol
->next
->module
->get_mtu(protocol
->next
, address
);
659 tcp_receive_data(net_buffer
* buffer
)
661 TRACE(("TCP: Received buffer %p\n", buffer
));
663 if (buffer
->interface_address
== NULL
664 || buffer
->interface_address
->domain
== NULL
)
667 net_domain
* domain
= buffer
->interface_address
->domain
;
668 net_address_module_info
* addressModule
= domain
->address_module
;
670 NetBufferHeaderReader
<tcp_header
> bufferHeader(buffer
);
671 if (bufferHeader
.Status() < B_OK
)
672 return bufferHeader
.Status();
674 tcp_header
& header
= bufferHeader
.Data();
676 uint16 headerLength
= header
.HeaderLength();
677 if (headerLength
< sizeof(tcp_header
))
680 if (Checksum::PseudoHeader(addressModule
, gBufferModule
, buffer
,
684 addressModule
->set_port(buffer
->source
, header
.source_port
);
685 addressModule
->set_port(buffer
->destination
, header
.destination_port
);
687 TRACE((" Looking for: peer %s, local %s\n",
688 AddressString(domain
, buffer
->source
, true).Data(),
689 AddressString(domain
, buffer
->destination
, true).Data()));
690 //dump_tcp_header(header);
691 //gBufferModule->dump(buffer);
693 tcp_segment_header
segment(header
.flags
);
694 segment
.sequence
= header
.Sequence();
695 segment
.acknowledge
= header
.Acknowledge();
696 segment
.advertised_window
= header
.AdvertisedWindow();
697 segment
.urgent_offset
= header
.UrgentOffset();
698 process_options(segment
, buffer
, headerLength
- sizeof(tcp_header
));
700 bufferHeader
.Remove(headerLength
);
701 // we no longer need to keep the header around
703 EndpointManager
* endpointManager
= endpoint_manager_for(domain
);
704 if (endpointManager
== NULL
) {
705 TRACE((" No endpoint manager!\n"));
709 int32 segmentAction
= DROP
;
711 TCPEndpoint
* endpoint
= endpointManager
->FindConnection(
712 buffer
->destination
, buffer
->source
);
713 if (endpoint
!= NULL
) {
714 segmentAction
= endpoint
->SegmentReceived(segment
, buffer
);
715 gSocketModule
->release_socket(endpoint
->socket
);
716 } else if ((segment
.flags
& TCP_FLAG_RESET
) == 0)
717 segmentAction
= DROP
| RESET
;
719 if ((segmentAction
& RESET
) != 0) {
721 endpointManager
->ReplyWithReset(segment
, buffer
);
723 if ((segmentAction
& DROP
) != 0)
724 gBufferModule
->free(buffer
);
731 tcp_error_received(net_error error
, net_buffer
* data
)
738 tcp_error_reply(net_protocol
* protocol
, net_buffer
* cause
, net_error error
,
739 net_error_data
* errorData
)
751 rw_lock_init(&sEndpointManagersLock
, "endpoint managers");
753 status_t status
= gStackModule
->register_domain_protocols(AF_INET
,
755 "network/protocols/tcp/v1",
756 "network/protocols/ipv4/v1",
760 status
= gStackModule
->register_domain_protocols(AF_INET6
,
762 "network/protocols/tcp/v1",
763 "network/protocols/ipv6/v1",
768 status
= gStackModule
->register_domain_protocols(AF_INET
, SOCK_STREAM
,
770 "network/protocols/tcp/v1",
771 "network/protocols/ipv4/v1",
775 status
= gStackModule
->register_domain_protocols(AF_INET6
, SOCK_STREAM
,
777 "network/protocols/tcp/v1",
778 "network/protocols/ipv6/v1",
783 status
= gStackModule
->register_domain_receiving_protocol(AF_INET
,
784 IPPROTO_TCP
, "network/protocols/tcp/v1");
787 status
= gStackModule
->register_domain_receiving_protocol(AF_INET6
,
788 IPPROTO_TCP
, "network/protocols/tcp/v1");
792 add_debugger_command("tcp_endpoints", dump_endpoints
,
793 "lists all open TCP endpoints");
794 add_debugger_command("tcp_endpoint", dump_endpoint
,
795 "dumps a TCP endpoint internal state");
804 remove_debugger_command("tcp_endpoint", dump_endpoint
);
805 remove_debugger_command("tcp_endpoints", dump_endpoints
);
807 rw_lock_destroy(&sEndpointManagersLock
);
809 for (int i
= 0; i
< AF_MAX
; i
++) {
810 delete sEndpointManagers
[i
];
818 tcp_std_ops(int32 op
, ...)
824 case B_MODULE_UNINIT
:
833 net_protocol_module_info sTCPModule
= {
835 "network/protocols/tcp/v1",
856 tcp_send_routed_data
,
863 NULL
, // deliver_data()
866 NULL
, // add_ancillary_data()
867 NULL
, // process_ancillary_data()
868 NULL
, // process_ancillary_data_no_container()
869 NULL
, // send_data_no_buffer()
870 NULL
// read_data_no_buffer()
873 module_dependency module_dependencies
[] = {
874 {NET_STACK_MODULE_NAME
, (module_info
**)&gStackModule
},
875 {NET_BUFFER_MODULE_NAME
, (module_info
**)&gBufferModule
},
876 {NET_DATALINK_MODULE_NAME
, (module_info
**)&gDatalinkModule
},
877 {NET_SOCKET_MODULE_NAME
, (module_info
**)&gSocketModule
},
881 module_info
*modules
[] = {
882 (module_info
*)&sTCPModule
,