1 // SPDX-License-Identifier: GPL-2.0-only
3 * Copyright (c) 2021, 2022 Oracle. All rights reserved.
5 * The AUTH_TLS credential is used only to probe a remote peer
6 * for RPC-over-TLS support.
9 #include <linux/types.h>
10 #include <linux/module.h>
11 #include <linux/sunrpc/clnt.h>
13 static const char *starttls_token
= "STARTTLS";
14 static const size_t starttls_len
= 8;
16 static struct rpc_auth tls_auth
;
17 static struct rpc_cred tls_cred
;
19 static void tls_encode_probe(struct rpc_rqst
*rqstp
, struct xdr_stream
*xdr
,
24 static int tls_decode_probe(struct rpc_rqst
*rqstp
, struct xdr_stream
*xdr
,
30 static const struct rpc_procinfo rpcproc_tls_probe
= {
31 .p_encode
= tls_encode_probe
,
32 .p_decode
= tls_decode_probe
,
35 static void rpc_tls_probe_call_prepare(struct rpc_task
*task
, void *data
)
37 task
->tk_flags
&= ~RPC_TASK_NO_RETRANS_TIMEOUT
;
41 static void rpc_tls_probe_call_done(struct rpc_task
*task
, void *data
)
45 static const struct rpc_call_ops rpc_tls_probe_ops
= {
46 .rpc_call_prepare
= rpc_tls_probe_call_prepare
,
47 .rpc_call_done
= rpc_tls_probe_call_done
,
50 static int tls_probe(struct rpc_clnt
*clnt
)
52 struct rpc_message msg
= {
53 .rpc_proc
= &rpcproc_tls_probe
,
55 struct rpc_task_setup task_setup_data
= {
58 .rpc_op_cred
= &tls_cred
,
59 .callback_ops
= &rpc_tls_probe_ops
,
60 .flags
= RPC_TASK_SOFT
| RPC_TASK_SOFTCONN
,
62 struct rpc_task
*task
;
65 task
= rpc_run_task(&task_setup_data
);
68 status
= task
->tk_status
;
73 static struct rpc_auth
*tls_create(const struct rpc_auth_create_args
*args
,
74 struct rpc_clnt
*clnt
)
76 refcount_inc(&tls_auth
.au_count
);
80 static void tls_destroy(struct rpc_auth
*auth
)
84 static struct rpc_cred
*tls_lookup_cred(struct rpc_auth
*auth
,
85 struct auth_cred
*acred
, int flags
)
87 return get_rpccred(&tls_cred
);
90 static void tls_destroy_cred(struct rpc_cred
*cred
)
94 static int tls_match(struct auth_cred
*acred
, struct rpc_cred
*cred
, int taskflags
)
99 static int tls_marshal(struct rpc_task
*task
, struct xdr_stream
*xdr
)
103 p
= xdr_reserve_space(xdr
, 4 * XDR_UNIT
);
110 *p
++ = rpc_auth_null
;
115 static int tls_refresh(struct rpc_task
*task
)
117 set_bit(RPCAUTH_CRED_UPTODATE
, &task
->tk_rqstp
->rq_cred
->cr_flags
);
121 static int tls_validate(struct rpc_task
*task
, struct xdr_stream
*xdr
)
126 p
= xdr_inline_decode(xdr
, XDR_UNIT
);
129 if (*p
!= rpc_auth_null
)
131 if (xdr_stream_decode_opaque_inline(xdr
, &str
, starttls_len
) != starttls_len
)
132 return -EPROTONOSUPPORT
;
133 if (memcmp(str
, starttls_token
, starttls_len
))
134 return -EPROTONOSUPPORT
;
138 const struct rpc_authops authtls_ops
= {
139 .owner
= THIS_MODULE
,
140 .au_flavor
= RPC_AUTH_TLS
,
142 .create
= tls_create
,
143 .destroy
= tls_destroy
,
144 .lookup_cred
= tls_lookup_cred
,
148 static struct rpc_auth tls_auth
= {
149 .au_cslack
= NUL_CALLSLACK
,
150 .au_rslack
= NUL_REPLYSLACK
,
151 .au_verfsize
= NUL_REPLYSLACK
,
152 .au_ralign
= NUL_REPLYSLACK
,
153 .au_ops
= &authtls_ops
,
154 .au_flavor
= RPC_AUTH_TLS
,
155 .au_count
= REFCOUNT_INIT(1),
158 static const struct rpc_credops tls_credops
= {
159 .cr_name
= "AUTH_TLS",
160 .crdestroy
= tls_destroy_cred
,
161 .crmatch
= tls_match
,
162 .crmarshal
= tls_marshal
,
163 .crwrap_req
= rpcauth_wrap_req_encode
,
164 .crrefresh
= tls_refresh
,
165 .crvalidate
= tls_validate
,
166 .crunwrap_resp
= rpcauth_unwrap_resp_decode
,
169 static struct rpc_cred tls_cred
= {
170 .cr_lru
= LIST_HEAD_INIT(tls_cred
.cr_lru
),
171 .cr_auth
= &tls_auth
,
172 .cr_ops
= &tls_credops
,
173 .cr_count
= REFCOUNT_INIT(2),
174 .cr_flags
= 1UL << RPCAUTH_CRED_UPTODATE
,