1 // SPDX-License-Identifier: GPL-2.0
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
23 #include "../kselftest_harness.h"
25 #define TLS_PAYLOAD_MAX_LEN 16384
36 struct tls12_crypto_info_aes_gcm_128 tls12
;
37 struct sockaddr_in addr
;
44 memset(&tls12
, 0, sizeof(tls12
));
45 tls12
.info
.version
= TLS_1_2_VERSION
;
46 tls12
.info
.cipher_type
= TLS_CIPHER_AES_GCM_128
;
48 addr
.sin_family
= AF_INET
;
49 addr
.sin_addr
.s_addr
= htonl(INADDR_ANY
);
52 self
->fd
= socket(AF_INET
, SOCK_STREAM
, 0);
53 sfd
= socket(AF_INET
, SOCK_STREAM
, 0);
55 ret
= bind(sfd
, &addr
, sizeof(addr
));
57 ret
= listen(sfd
, 10);
60 ret
= getsockname(sfd
, &addr
, &len
);
63 ret
= connect(self
->fd
, &addr
, sizeof(addr
));
66 ret
= setsockopt(self
->fd
, IPPROTO_TCP
, TCP_ULP
, "tls", sizeof("tls"));
69 printf("Failure setting TCP_ULP, testing without tls\n");
73 ret
= setsockopt(self
->fd
, SOL_TLS
, TLS_TX
, &tls12
,
78 self
->cfd
= accept(sfd
, &addr
, &len
);
79 ASSERT_GE(self
->cfd
, 0);
82 ret
= setsockopt(self
->cfd
, IPPROTO_TCP
, TCP_ULP
, "tls",
86 ret
= setsockopt(self
->cfd
, SOL_TLS
, TLS_RX
, &tls12
,
100 TEST_F(tls
, sendfile
)
102 int filefd
= open("/proc/self/exe", O_RDONLY
);
105 EXPECT_GE(filefd
, 0);
107 EXPECT_GE(sendfile(self
->fd
, filefd
, 0, st
.st_size
), 0);
110 TEST_F(tls
, send_then_sendfile
)
112 int filefd
= open("/proc/self/exe", O_RDONLY
);
113 char const *test_str
= "test_send";
114 int to_send
= strlen(test_str
) + 1;
119 EXPECT_GE(filefd
, 0);
121 buf
= (char *)malloc(st
.st_size
);
123 EXPECT_EQ(send(self
->fd
, test_str
, to_send
, 0), to_send
);
124 EXPECT_EQ(recv(self
->cfd
, recv_buf
, to_send
, 0), to_send
);
125 EXPECT_EQ(memcmp(test_str
, recv_buf
, to_send
), 0);
127 EXPECT_GE(sendfile(self
->fd
, filefd
, 0, st
.st_size
), 0);
128 EXPECT_EQ(recv(self
->cfd
, buf
, st
.st_size
, 0), st
.st_size
);
131 TEST_F(tls
, recv_max
)
133 unsigned int send_len
= TLS_PAYLOAD_MAX_LEN
;
134 char recv_mem
[TLS_PAYLOAD_MAX_LEN
];
135 char buf
[TLS_PAYLOAD_MAX_LEN
];
137 EXPECT_GE(send(self
->fd
, buf
, send_len
, 0), 0);
138 EXPECT_NE(recv(self
->cfd
, recv_mem
, send_len
, 0), -1);
139 EXPECT_EQ(memcmp(buf
, recv_mem
, send_len
), 0);
142 TEST_F(tls
, recv_small
)
144 char const *test_str
= "test_read";
148 send_len
= strlen(test_str
) + 1;
149 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
150 EXPECT_NE(recv(self
->cfd
, buf
, send_len
, 0), -1);
151 EXPECT_EQ(memcmp(buf
, test_str
, send_len
), 0);
154 TEST_F(tls
, msg_more
)
156 char const *test_str
= "test_read";
160 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, MSG_MORE
), send_len
);
161 EXPECT_EQ(recv(self
->cfd
, buf
, send_len
, MSG_DONTWAIT
), -1);
162 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
163 EXPECT_EQ(recv(self
->cfd
, buf
, send_len
* 2, MSG_DONTWAIT
),
165 EXPECT_EQ(memcmp(buf
, test_str
, send_len
), 0);
168 TEST_F(tls
, sendmsg_single
)
172 char const *test_str
= "test_sendmsg";
173 size_t send_len
= 13;
177 vec
.iov_base
= (char *)test_str
;
178 vec
.iov_len
= send_len
;
179 memset(&msg
, 0, sizeof(struct msghdr
));
182 EXPECT_EQ(sendmsg(self
->fd
, &msg
, 0), send_len
);
183 EXPECT_EQ(recv(self
->cfd
, buf
, send_len
, 0), send_len
);
184 EXPECT_EQ(memcmp(buf
, test_str
, send_len
), 0);
187 TEST_F(tls
, sendmsg_large
)
189 void *mem
= malloc(16384);
190 size_t send_len
= 16384;
196 memset(&msg
, 0, sizeof(struct msghdr
));
197 while (sent
++ < sends
) {
198 struct iovec vec
= { (void *)mem
, send_len
};
202 EXPECT_EQ(sendmsg(self
->cfd
, &msg
, 0), send_len
);
205 while (recvs
++ < sends
)
206 EXPECT_NE(recv(self
->fd
, mem
, send_len
, 0), -1);
211 TEST_F(tls
, sendmsg_multiple
)
213 char const *test_str
= "test_sendmsg_multiple";
223 memset(&msg
, 0, sizeof(struct msghdr
));
224 for (i
= 0; i
< iov_len
; i
++) {
225 test_strs
[i
] = (char *)malloc(strlen(test_str
) + 1);
226 snprintf(test_strs
[i
], strlen(test_str
) + 1, "%s", test_str
);
227 vec
[i
].iov_base
= (void *)test_strs
[i
];
228 vec
[i
].iov_len
= strlen(test_strs
[i
]) + 1;
229 total_len
+= vec
[i
].iov_len
;
232 msg
.msg_iovlen
= iov_len
;
234 EXPECT_EQ(sendmsg(self
->cfd
, &msg
, 0), total_len
);
235 buf
= malloc(total_len
);
236 EXPECT_NE(recv(self
->fd
, buf
, total_len
, 0), -1);
237 for (i
= 0; i
< iov_len
; i
++) {
238 EXPECT_EQ(memcmp(test_strs
[i
], buf
+ len_cmp
,
239 strlen(test_strs
[i
])),
241 len_cmp
+= strlen(buf
+ len_cmp
) + 1;
243 for (i
= 0; i
< iov_len
; i
++)
248 TEST_F(tls
, sendmsg_multiple_stress
)
250 char const *test_str
= "abcdefghijklmno";
251 struct iovec vec
[1024];
252 char *test_strs
[1024];
260 memset(&msg
, 0, sizeof(struct msghdr
));
261 for (i
= 0; i
< iov_len
; i
++) {
262 test_strs
[i
] = (char *)malloc(strlen(test_str
) + 1);
263 snprintf(test_strs
[i
], strlen(test_str
) + 1, "%s", test_str
);
264 vec
[i
].iov_base
= (void *)test_strs
[i
];
265 vec
[i
].iov_len
= strlen(test_strs
[i
]) + 1;
266 total_len
+= vec
[i
].iov_len
;
269 msg
.msg_iovlen
= iov_len
;
271 EXPECT_EQ(sendmsg(self
->fd
, &msg
, 0), total_len
);
272 EXPECT_NE(recv(self
->cfd
, buf
, total_len
, 0), -1);
274 for (i
= 0; i
< iov_len
; i
++)
275 len_cmp
+= strlen(buf
+ len_cmp
) + 1;
277 for (i
= 0; i
< iov_len
; i
++)
281 TEST_F(tls
, splice_from_pipe
)
283 int send_len
= TLS_PAYLOAD_MAX_LEN
;
284 char mem_send
[TLS_PAYLOAD_MAX_LEN
];
285 char mem_recv
[TLS_PAYLOAD_MAX_LEN
];
288 ASSERT_GE(pipe(p
), 0);
289 EXPECT_GE(write(p
[1], mem_send
, send_len
), 0);
290 EXPECT_GE(splice(p
[0], NULL
, self
->fd
, NULL
, send_len
, 0), 0);
291 EXPECT_EQ(recv(self
->cfd
, mem_recv
, send_len
, MSG_WAITALL
), send_len
);
292 EXPECT_EQ(memcmp(mem_send
, mem_recv
, send_len
), 0);
295 TEST_F(tls
, splice_from_pipe2
)
297 int send_len
= 16000;
298 char mem_send
[16000];
299 char mem_recv
[16000];
303 ASSERT_GE(pipe(p
), 0);
304 ASSERT_GE(pipe(p2
), 0);
305 EXPECT_GE(write(p
[1], mem_send
, 8000), 0);
306 EXPECT_GE(splice(p
[0], NULL
, self
->fd
, NULL
, 8000, 0), 0);
307 EXPECT_GE(write(p2
[1], mem_send
+ 8000, 8000), 0);
308 EXPECT_GE(splice(p2
[0], NULL
, self
->fd
, NULL
, 8000, 0), 0);
309 EXPECT_GE(recv(self
->cfd
, mem_recv
, send_len
, 0), 0);
310 EXPECT_EQ(memcmp(mem_send
, mem_recv
, send_len
), 0);
313 TEST_F(tls
, send_and_splice
)
315 int send_len
= TLS_PAYLOAD_MAX_LEN
;
316 char mem_send
[TLS_PAYLOAD_MAX_LEN
];
317 char mem_recv
[TLS_PAYLOAD_MAX_LEN
];
318 char const *test_str
= "test_read";
323 ASSERT_GE(pipe(p
), 0);
324 EXPECT_EQ(send(self
->fd
, test_str
, send_len2
, 0), send_len2
);
325 EXPECT_EQ(recv(self
->cfd
, buf
, send_len2
, MSG_WAITALL
), send_len2
);
326 EXPECT_EQ(memcmp(test_str
, buf
, send_len2
), 0);
328 EXPECT_GE(write(p
[1], mem_send
, send_len
), send_len
);
329 EXPECT_GE(splice(p
[0], NULL
, self
->fd
, NULL
, send_len
, 0), send_len
);
331 EXPECT_EQ(recv(self
->cfd
, mem_recv
, send_len
, MSG_WAITALL
), send_len
);
332 EXPECT_EQ(memcmp(mem_send
, mem_recv
, send_len
), 0);
335 TEST_F(tls
, splice_to_pipe
)
337 int send_len
= TLS_PAYLOAD_MAX_LEN
;
338 char mem_send
[TLS_PAYLOAD_MAX_LEN
];
339 char mem_recv
[TLS_PAYLOAD_MAX_LEN
];
342 ASSERT_GE(pipe(p
), 0);
343 EXPECT_GE(send(self
->fd
, mem_send
, send_len
, 0), 0);
344 EXPECT_GE(splice(self
->cfd
, NULL
, p
[1], NULL
, send_len
, 0), 0);
345 EXPECT_GE(read(p
[0], mem_recv
, send_len
), 0);
346 EXPECT_EQ(memcmp(mem_send
, mem_recv
, send_len
), 0);
349 TEST_F(tls
, recvmsg_single
)
351 char const *test_str
= "test_recvmsg_single";
352 int send_len
= strlen(test_str
) + 1;
357 memset(&hdr
, 0, sizeof(hdr
));
358 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
359 vec
.iov_base
= (char *)buf
;
360 vec
.iov_len
= send_len
;
363 EXPECT_NE(recvmsg(self
->cfd
, &hdr
, 0), -1);
364 EXPECT_EQ(memcmp(test_str
, buf
, send_len
), 0);
367 TEST_F(tls
, recvmsg_single_max
)
369 int send_len
= TLS_PAYLOAD_MAX_LEN
;
370 char send_mem
[TLS_PAYLOAD_MAX_LEN
];
371 char recv_mem
[TLS_PAYLOAD_MAX_LEN
];
375 EXPECT_EQ(send(self
->fd
, send_mem
, send_len
, 0), send_len
);
376 vec
.iov_base
= (char *)recv_mem
;
377 vec
.iov_len
= TLS_PAYLOAD_MAX_LEN
;
381 EXPECT_NE(recvmsg(self
->cfd
, &hdr
, 0), -1);
382 EXPECT_EQ(memcmp(send_mem
, recv_mem
, send_len
), 0);
385 TEST_F(tls
, recvmsg_multiple
)
387 unsigned int msg_iovlen
= 1024;
388 unsigned int len_compared
= 0;
389 struct iovec vec
[1024];
390 char *iov_base
[1024];
391 unsigned int iov_len
= 16;
392 int send_len
= 1 << 14;
397 EXPECT_EQ(send(self
->fd
, buf
, send_len
, 0), send_len
);
398 for (i
= 0; i
< msg_iovlen
; i
++) {
399 iov_base
[i
] = (char *)malloc(iov_len
);
400 vec
[i
].iov_base
= iov_base
[i
];
401 vec
[i
].iov_len
= iov_len
;
404 hdr
.msg_iovlen
= msg_iovlen
;
406 EXPECT_NE(recvmsg(self
->cfd
, &hdr
, 0), -1);
407 for (i
= 0; i
< msg_iovlen
; i
++)
408 len_compared
+= iov_len
;
410 for (i
= 0; i
< msg_iovlen
; i
++)
414 TEST_F(tls
, single_send_multiple_recv
)
416 unsigned int total_len
= TLS_PAYLOAD_MAX_LEN
* 2;
417 unsigned int send_len
= TLS_PAYLOAD_MAX_LEN
;
418 char send_mem
[TLS_PAYLOAD_MAX_LEN
* 2];
419 char recv_mem
[TLS_PAYLOAD_MAX_LEN
* 2];
421 EXPECT_GE(send(self
->fd
, send_mem
, total_len
, 0), 0);
422 memset(recv_mem
, 0, total_len
);
424 EXPECT_NE(recv(self
->cfd
, recv_mem
, send_len
, 0), -1);
425 EXPECT_NE(recv(self
->cfd
, recv_mem
+ send_len
, send_len
, 0), -1);
426 EXPECT_EQ(memcmp(send_mem
, recv_mem
, total_len
), 0);
429 TEST_F(tls
, multiple_send_single_recv
)
431 unsigned int total_len
= 2 * 10;
432 unsigned int send_len
= 10;
433 char recv_mem
[2 * 10];
436 EXPECT_GE(send(self
->fd
, send_mem
, send_len
, 0), 0);
437 EXPECT_GE(send(self
->fd
, send_mem
, send_len
, 0), 0);
438 memset(recv_mem
, 0, total_len
);
439 EXPECT_EQ(recv(self
->cfd
, recv_mem
, total_len
, 0), total_len
);
441 EXPECT_EQ(memcmp(send_mem
, recv_mem
, send_len
), 0);
442 EXPECT_EQ(memcmp(send_mem
, recv_mem
+ send_len
, send_len
), 0);
445 TEST_F(tls
, recv_partial
)
447 char const *test_str
= "test_read_partial";
448 char const *test_str_first
= "test_read";
449 char const *test_str_second
= "_partial";
450 int send_len
= strlen(test_str
) + 1;
453 memset(recv_mem
, 0, sizeof(recv_mem
));
454 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
455 EXPECT_NE(recv(self
->cfd
, recv_mem
, strlen(test_str_first
), 0), -1);
456 EXPECT_EQ(memcmp(test_str_first
, recv_mem
, strlen(test_str_first
)), 0);
457 memset(recv_mem
, 0, sizeof(recv_mem
));
458 EXPECT_NE(recv(self
->cfd
, recv_mem
, strlen(test_str_second
), 0), -1);
459 EXPECT_EQ(memcmp(test_str_second
, recv_mem
, strlen(test_str_second
)),
463 TEST_F(tls
, recv_nonblock
)
468 EXPECT_EQ(recv(self
->cfd
, buf
, sizeof(buf
), MSG_DONTWAIT
), -1);
469 err
= (errno
== EAGAIN
|| errno
== EWOULDBLOCK
);
470 EXPECT_EQ(err
, true);
473 TEST_F(tls
, recv_peek
)
475 char const *test_str
= "test_read_peek";
476 int send_len
= strlen(test_str
) + 1;
479 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
480 EXPECT_NE(recv(self
->cfd
, buf
, send_len
, MSG_PEEK
), -1);
481 EXPECT_EQ(memcmp(test_str
, buf
, send_len
), 0);
482 memset(buf
, 0, sizeof(buf
));
483 EXPECT_NE(recv(self
->cfd
, buf
, send_len
, 0), -1);
484 EXPECT_EQ(memcmp(test_str
, buf
, send_len
), 0);
487 TEST_F(tls
, recv_peek_multiple
)
489 char const *test_str
= "test_read_peek";
490 int send_len
= strlen(test_str
) + 1;
491 unsigned int num_peeks
= 100;
495 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
496 for (i
= 0; i
< num_peeks
; i
++) {
497 EXPECT_NE(recv(self
->cfd
, buf
, send_len
, MSG_PEEK
), -1);
498 EXPECT_EQ(memcmp(test_str
, buf
, send_len
), 0);
499 memset(buf
, 0, sizeof(buf
));
501 EXPECT_NE(recv(self
->cfd
, buf
, send_len
, 0), -1);
502 EXPECT_EQ(memcmp(test_str
, buf
, send_len
), 0);
505 TEST_F(tls
, recv_peek_multiple_records
)
507 char const *test_str
= "test_read_peek_mult_recs";
508 char const *test_str_first
= "test_read_peek";
509 char const *test_str_second
= "_mult_recs";
513 len
= strlen(test_str_first
);
514 EXPECT_EQ(send(self
->fd
, test_str_first
, len
, 0), len
);
516 len
= strlen(test_str_second
) + 1;
517 EXPECT_EQ(send(self
->fd
, test_str_second
, len
, 0), len
);
519 len
= strlen(test_str_first
);
521 EXPECT_EQ(recv(self
->cfd
, buf
, len
, MSG_PEEK
| MSG_WAITALL
), len
);
523 /* MSG_PEEK can only peek into the current record. */
524 len
= strlen(test_str_first
);
525 EXPECT_EQ(memcmp(test_str_first
, buf
, len
), 0);
527 len
= strlen(test_str
) + 1;
529 EXPECT_EQ(recv(self
->cfd
, buf
, len
, MSG_WAITALL
), len
);
531 /* Non-MSG_PEEK will advance strparser (and therefore record)
534 len
= strlen(test_str
) + 1;
535 EXPECT_EQ(memcmp(test_str
, buf
, len
), 0);
537 /* MSG_MORE will hold current record open, so later MSG_PEEK
538 * will see everything.
540 len
= strlen(test_str_first
);
541 EXPECT_EQ(send(self
->fd
, test_str_first
, len
, MSG_MORE
), len
);
543 len
= strlen(test_str_second
) + 1;
544 EXPECT_EQ(send(self
->fd
, test_str_second
, len
, 0), len
);
546 len
= strlen(test_str
) + 1;
548 EXPECT_EQ(recv(self
->cfd
, buf
, len
, MSG_PEEK
| MSG_WAITALL
), len
);
550 len
= strlen(test_str
) + 1;
551 EXPECT_EQ(memcmp(test_str
, buf
, len
), 0);
556 char const *test_str
= "test_poll";
557 struct pollfd fd
= { 0, 0, 0 };
561 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
565 EXPECT_EQ(poll(&fd
, 1, 20), 1);
566 EXPECT_EQ(fd
.revents
& POLLIN
, 1);
567 EXPECT_EQ(recv(self
->cfd
, buf
, send_len
, 0), send_len
);
568 /* Test timing out */
569 EXPECT_EQ(poll(&fd
, 1, 20), 0);
572 TEST_F(tls
, poll_wait
)
574 char const *test_str
= "test_poll_wait";
575 int send_len
= strlen(test_str
) + 1;
576 struct pollfd fd
= { 0, 0, 0 };
581 EXPECT_EQ(send(self
->fd
, test_str
, send_len
, 0), send_len
);
582 /* Set timeout to inf. secs */
583 EXPECT_EQ(poll(&fd
, 1, -1), 1);
584 EXPECT_EQ(fd
.revents
& POLLIN
, 1);
585 EXPECT_EQ(recv(self
->cfd
, recv_mem
, send_len
, 0), send_len
);
588 TEST_F(tls
, blocking
)
590 size_t data
= 100000;
603 int res
= send(self
->fd
, buf
,
604 left
> 16384 ? 16384 : left
, 0);
610 pid2
= wait(&status
);
611 EXPECT_EQ(status
, 0);
612 EXPECT_EQ(res
, pid2
);
619 int res
= recv(self
->cfd
, buf
,
620 left
> 16384 ? 16384 : left
, 0);
628 TEST_F(tls
, nonblocking
)
630 size_t data
= 100000;
635 flags
= fcntl(self
->fd
, F_GETFL
, 0);
636 fcntl(self
->fd
, F_SETFL
, flags
| O_NONBLOCK
);
637 fcntl(self
->cfd
, F_SETFL
, flags
| O_NONBLOCK
);
639 /* Ensure nonblocking behavior by imposing a small send
642 EXPECT_EQ(setsockopt(self
->fd
, SOL_SOCKET
, SO_SNDBUF
,
643 &sendbuf
, sizeof(sendbuf
)), 0);
657 int res
= send(self
->fd
, buf
,
658 left
> 16384 ? 16384 : left
, 0);
660 if (res
== -1 && errno
== EAGAIN
) {
670 pid2
= wait(&status
);
672 EXPECT_EQ(status
, 0);
673 EXPECT_EQ(res
, pid2
);
681 int res
= recv(self
->cfd
, buf
,
682 left
> 16384 ? 16384 : left
, 0);
684 if (res
== -1 && errno
== EAGAIN
) {
696 TEST_F(tls
, control_msg
)
701 char cbuf
[CMSG_SPACE(sizeof(char))];
702 char const *test_str
= "test_read";
703 int cmsg_len
= sizeof(char);
704 char record_type
= 100;
705 struct cmsghdr
*cmsg
;
711 vec
.iov_base
= (char *)test_str
;
713 memset(&msg
, 0, sizeof(struct msghdr
));
716 msg
.msg_control
= cbuf
;
717 msg
.msg_controllen
= sizeof(cbuf
);
718 cmsg
= CMSG_FIRSTHDR(&msg
);
719 cmsg
->cmsg_level
= SOL_TLS
;
720 /* test sending non-record types. */
721 cmsg
->cmsg_type
= TLS_SET_RECORD_TYPE
;
722 cmsg
->cmsg_len
= CMSG_LEN(cmsg_len
);
723 *CMSG_DATA(cmsg
) = record_type
;
724 msg
.msg_controllen
= cmsg
->cmsg_len
;
726 EXPECT_EQ(sendmsg(self
->fd
, &msg
, 0), send_len
);
727 /* Should fail because we didn't provide a control message */
728 EXPECT_EQ(recv(self
->cfd
, buf
, send_len
, 0), -1);
731 EXPECT_EQ(recvmsg(self
->cfd
, &msg
, 0), send_len
);
732 cmsg
= CMSG_FIRSTHDR(&msg
);
733 EXPECT_NE(cmsg
, NULL
);
734 EXPECT_EQ(cmsg
->cmsg_level
, SOL_TLS
);
735 EXPECT_EQ(cmsg
->cmsg_type
, TLS_GET_RECORD_TYPE
);
736 record_type
= *((unsigned char *)CMSG_DATA(cmsg
));
737 EXPECT_EQ(record_type
, 100);
738 EXPECT_EQ(memcmp(buf
, test_str
, send_len
), 0);