2 * Copyright (c) 2006, 2020 Oracle and/or its affiliates.
4 * This software is available to you under a choice of one of two
5 * licenses. You may choose to be licensed under the terms of the GNU
6 * General Public License (GPL) Version 2, available from the file
7 * COPYING in the main directory of this source tree, or the
8 * OpenIB.org BSD license below:
10 * Redistribution and use in source and binary forms, with or
11 * without modification, are permitted provided that the following
14 * - Redistributions of source code must retain the above
15 * copyright notice, this list of conditions and the following
18 * - Redistributions in binary form must reproduce the above
19 * copyright notice, this list of conditions and the following
20 * disclaimer in the documentation and/or other materials
21 * provided with the distribution.
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33 #include <linux/kernel.h>
34 #include <linux/slab.h>
35 #include <linux/export.h>
36 #include <linux/skbuff.h>
37 #include <linux/list.h>
38 #include <linux/errqueue.h>
42 static unsigned int rds_exthdr_size
[__RDS_EXTHDR_MAX
] = {
43 [RDS_EXTHDR_NONE
] = 0,
44 [RDS_EXTHDR_VERSION
] = sizeof(struct rds_ext_header_version
),
45 [RDS_EXTHDR_RDMA
] = sizeof(struct rds_ext_header_rdma
),
46 [RDS_EXTHDR_RDMA_DEST
] = sizeof(struct rds_ext_header_rdma_dest
),
47 [RDS_EXTHDR_NPATHS
] = sizeof(u16
),
48 [RDS_EXTHDR_GEN_NUM
] = sizeof(u32
),
51 void rds_message_addref(struct rds_message
*rm
)
53 rdsdebug("addref rm %p ref %d\n", rm
, refcount_read(&rm
->m_refcount
));
54 refcount_inc(&rm
->m_refcount
);
56 EXPORT_SYMBOL_GPL(rds_message_addref
);
58 static inline bool rds_zcookie_add(struct rds_msg_zcopy_info
*info
, u32 cookie
)
60 struct rds_zcopy_cookies
*ck
= &info
->zcookies
;
61 int ncookies
= ck
->num
;
63 if (ncookies
== RDS_MAX_ZCOOKIES
)
65 ck
->cookies
[ncookies
] = cookie
;
70 static struct rds_msg_zcopy_info
*rds_info_from_znotifier(struct rds_znotifier
*znotif
)
72 return container_of(znotif
, struct rds_msg_zcopy_info
, znotif
);
75 void rds_notify_msg_zcopy_purge(struct rds_msg_zcopy_queue
*q
)
79 struct rds_msg_zcopy_info
*info
, *tmp
;
81 spin_lock_irqsave(&q
->lock
, flags
);
82 list_splice(&q
->zcookie_head
, ©
);
83 INIT_LIST_HEAD(&q
->zcookie_head
);
84 spin_unlock_irqrestore(&q
->lock
, flags
);
86 list_for_each_entry_safe(info
, tmp
, ©
, rs_zcookie_next
) {
87 list_del(&info
->rs_zcookie_next
);
92 static void rds_rm_zerocopy_callback(struct rds_sock
*rs
,
93 struct rds_znotifier
*znotif
)
95 struct rds_msg_zcopy_info
*info
;
96 struct rds_msg_zcopy_queue
*q
;
97 u32 cookie
= znotif
->z_cookie
;
98 struct rds_zcopy_cookies
*ck
;
99 struct list_head
*head
;
102 mm_unaccount_pinned_pages(&znotif
->z_mmp
);
103 q
= &rs
->rs_zcookie_queue
;
104 spin_lock_irqsave(&q
->lock
, flags
);
105 head
= &q
->zcookie_head
;
106 if (!list_empty(head
)) {
107 info
= list_entry(head
, struct rds_msg_zcopy_info
,
109 if (info
&& rds_zcookie_add(info
, cookie
)) {
110 spin_unlock_irqrestore(&q
->lock
, flags
);
111 kfree(rds_info_from_znotifier(znotif
));
112 /* caller invokes rds_wake_sk_sleep() */
117 info
= rds_info_from_znotifier(znotif
);
118 ck
= &info
->zcookies
;
119 memset(ck
, 0, sizeof(*ck
));
120 WARN_ON(!rds_zcookie_add(info
, cookie
));
121 list_add_tail(&q
->zcookie_head
, &info
->rs_zcookie_next
);
123 spin_unlock_irqrestore(&q
->lock
, flags
);
124 /* caller invokes rds_wake_sk_sleep() */
128 * This relies on dma_map_sg() not touching sg[].page during merging.
130 static void rds_message_purge(struct rds_message
*rm
)
132 unsigned long i
, flags
;
135 if (unlikely(test_bit(RDS_MSG_PAGEVEC
, &rm
->m_flags
)))
138 spin_lock_irqsave(&rm
->m_rs_lock
, flags
);
140 struct rds_sock
*rs
= rm
->m_rs
;
142 if (rm
->data
.op_mmp_znotifier
) {
144 rds_rm_zerocopy_callback(rs
, rm
->data
.op_mmp_znotifier
);
145 rds_wake_sk_sleep(rs
);
146 rm
->data
.op_mmp_znotifier
= NULL
;
148 sock_put(rds_rs_to_sk(rs
));
151 spin_unlock_irqrestore(&rm
->m_rs_lock
, flags
);
153 for (i
= 0; i
< rm
->data
.op_nents
; i
++) {
154 /* XXX will have to put_page for page refs */
156 __free_page(sg_page(&rm
->data
.op_sg
[i
]));
158 put_page(sg_page(&rm
->data
.op_sg
[i
]));
160 rm
->data
.op_nents
= 0;
162 if (rm
->rdma
.op_active
)
163 rds_rdma_free_op(&rm
->rdma
);
164 if (rm
->rdma
.op_rdma_mr
)
165 kref_put(&rm
->rdma
.op_rdma_mr
->r_kref
, __rds_put_mr_final
);
167 if (rm
->atomic
.op_active
)
168 rds_atomic_free_op(&rm
->atomic
);
169 if (rm
->atomic
.op_rdma_mr
)
170 kref_put(&rm
->atomic
.op_rdma_mr
->r_kref
, __rds_put_mr_final
);
173 void rds_message_put(struct rds_message
*rm
)
175 rdsdebug("put rm %p ref %d\n", rm
, refcount_read(&rm
->m_refcount
));
176 WARN(!refcount_read(&rm
->m_refcount
), "danger refcount zero on %p\n", rm
);
177 if (refcount_dec_and_test(&rm
->m_refcount
)) {
178 BUG_ON(!list_empty(&rm
->m_sock_item
));
179 BUG_ON(!list_empty(&rm
->m_conn_item
));
180 rds_message_purge(rm
);
185 EXPORT_SYMBOL_GPL(rds_message_put
);
187 void rds_message_populate_header(struct rds_header
*hdr
, __be16 sport
,
188 __be16 dport
, u64 seq
)
191 hdr
->h_sport
= sport
;
192 hdr
->h_dport
= dport
;
193 hdr
->h_sequence
= cpu_to_be64(seq
);
194 hdr
->h_exthdr
[0] = RDS_EXTHDR_NONE
;
196 EXPORT_SYMBOL_GPL(rds_message_populate_header
);
198 int rds_message_add_extension(struct rds_header
*hdr
, unsigned int type
,
199 const void *data
, unsigned int len
)
201 unsigned int ext_len
= sizeof(u8
) + len
;
204 /* For now, refuse to add more than one extension header */
205 if (hdr
->h_exthdr
[0] != RDS_EXTHDR_NONE
)
208 if (type
>= __RDS_EXTHDR_MAX
|| len
!= rds_exthdr_size
[type
])
211 if (ext_len
>= RDS_HEADER_EXT_SPACE
)
216 memcpy(dst
, data
, len
);
218 dst
[len
] = RDS_EXTHDR_NONE
;
221 EXPORT_SYMBOL_GPL(rds_message_add_extension
);
224 * If a message has extension headers, retrieve them here.
227 * unsigned int pos = 0;
230 * buflen = sizeof(buffer);
231 * type = rds_message_next_extension(hdr, &pos, buffer, &buflen);
232 * if (type == RDS_EXTHDR_NONE)
237 int rds_message_next_extension(struct rds_header
*hdr
,
238 unsigned int *pos
, void *buf
, unsigned int *buflen
)
240 unsigned int offset
, ext_type
, ext_len
;
241 u8
*src
= hdr
->h_exthdr
;
244 if (offset
>= RDS_HEADER_EXT_SPACE
)
247 /* Get the extension type and length. For now, the
248 * length is implied by the extension type. */
249 ext_type
= src
[offset
++];
251 if (ext_type
== RDS_EXTHDR_NONE
|| ext_type
>= __RDS_EXTHDR_MAX
)
253 ext_len
= rds_exthdr_size
[ext_type
];
254 if (offset
+ ext_len
> RDS_HEADER_EXT_SPACE
)
257 *pos
= offset
+ ext_len
;
258 if (ext_len
< *buflen
)
260 memcpy(buf
, src
+ offset
, *buflen
);
264 *pos
= RDS_HEADER_EXT_SPACE
;
266 return RDS_EXTHDR_NONE
;
269 int rds_message_add_rdma_dest_extension(struct rds_header
*hdr
, u32 r_key
, u32 offset
)
271 struct rds_ext_header_rdma_dest ext_hdr
;
273 ext_hdr
.h_rdma_rkey
= cpu_to_be32(r_key
);
274 ext_hdr
.h_rdma_offset
= cpu_to_be32(offset
);
275 return rds_message_add_extension(hdr
, RDS_EXTHDR_RDMA_DEST
, &ext_hdr
, sizeof(ext_hdr
));
277 EXPORT_SYMBOL_GPL(rds_message_add_rdma_dest_extension
);
280 * Each rds_message is allocated with extra space for the scatterlist entries
281 * rds ops will need. This is to minimize memory allocation count. Then, each rds op
282 * can grab SGs when initializing its part of the rds_message.
284 struct rds_message
*rds_message_alloc(unsigned int extra_len
, gfp_t gfp
)
286 struct rds_message
*rm
;
288 if (extra_len
> KMALLOC_MAX_SIZE
- sizeof(struct rds_message
))
291 rm
= kzalloc(sizeof(struct rds_message
) + extra_len
, gfp
);
296 rm
->m_total_sgs
= extra_len
/ sizeof(struct scatterlist
);
298 refcount_set(&rm
->m_refcount
, 1);
299 INIT_LIST_HEAD(&rm
->m_sock_item
);
300 INIT_LIST_HEAD(&rm
->m_conn_item
);
301 spin_lock_init(&rm
->m_rs_lock
);
302 init_waitqueue_head(&rm
->m_flush_wait
);
309 * RDS ops use this to grab SG entries from the rm's sg pool.
311 struct scatterlist
*rds_message_alloc_sgs(struct rds_message
*rm
, int nents
)
313 struct scatterlist
*sg_first
= (struct scatterlist
*) &rm
[1];
314 struct scatterlist
*sg_ret
;
317 pr_warn("rds: alloc sgs failed! nents <= 0\n");
318 return ERR_PTR(-EINVAL
);
321 if (rm
->m_used_sgs
+ nents
> rm
->m_total_sgs
) {
322 pr_warn("rds: alloc sgs failed! total %d used %d nents %d\n",
323 rm
->m_total_sgs
, rm
->m_used_sgs
, nents
);
324 return ERR_PTR(-ENOMEM
);
327 sg_ret
= &sg_first
[rm
->m_used_sgs
];
328 sg_init_table(sg_ret
, nents
);
329 rm
->m_used_sgs
+= nents
;
334 struct rds_message
*rds_message_map_pages(unsigned long *page_addrs
, unsigned int total_len
)
336 struct rds_message
*rm
;
338 int num_sgs
= DIV_ROUND_UP(total_len
, PAGE_SIZE
);
339 int extra_bytes
= num_sgs
* sizeof(struct scatterlist
);
341 rm
= rds_message_alloc(extra_bytes
, GFP_NOWAIT
);
343 return ERR_PTR(-ENOMEM
);
345 set_bit(RDS_MSG_PAGEVEC
, &rm
->m_flags
);
346 rm
->m_inc
.i_hdr
.h_len
= cpu_to_be32(total_len
);
347 rm
->data
.op_nents
= DIV_ROUND_UP(total_len
, PAGE_SIZE
);
348 rm
->data
.op_sg
= rds_message_alloc_sgs(rm
, num_sgs
);
349 if (IS_ERR(rm
->data
.op_sg
)) {
351 return ERR_CAST(rm
->data
.op_sg
);
354 for (i
= 0; i
< rm
->data
.op_nents
; ++i
) {
355 sg_set_page(&rm
->data
.op_sg
[i
],
356 virt_to_page(page_addrs
[i
]),
363 static int rds_message_zcopy_from_user(struct rds_message
*rm
, struct iov_iter
*from
)
365 struct scatterlist
*sg
;
367 int length
= iov_iter_count(from
);
368 int total_copied
= 0;
369 struct rds_msg_zcopy_info
*info
;
371 rm
->m_inc
.i_hdr
.h_len
= cpu_to_be32(iov_iter_count(from
));
374 * now allocate and copy in the data payload.
378 info
= kzalloc(sizeof(*info
), GFP_KERNEL
);
381 INIT_LIST_HEAD(&info
->rs_zcookie_next
);
382 rm
->data
.op_mmp_znotifier
= &info
->znotif
;
383 if (mm_account_pinned_pages(&rm
->data
.op_mmp_znotifier
->z_mmp
,
388 while (iov_iter_count(from
)) {
393 copied
= iov_iter_get_pages(from
, &pages
, PAGE_SIZE
,
399 for (i
= 0; i
< rm
->data
.op_nents
; i
++)
400 put_page(sg_page(&rm
->data
.op_sg
[i
]));
401 mmp
= &rm
->data
.op_mmp_znotifier
->z_mmp
;
402 mm_unaccount_pinned_pages(mmp
);
406 total_copied
+= copied
;
407 iov_iter_advance(from
, copied
);
409 sg_set_page(sg
, pages
, copied
, start
);
413 WARN_ON_ONCE(length
!= 0);
417 rm
->data
.op_mmp_znotifier
= NULL
;
421 int rds_message_copy_from_user(struct rds_message
*rm
, struct iov_iter
*from
,
424 unsigned long to_copy
, nbytes
;
425 unsigned long sg_off
;
426 struct scatterlist
*sg
;
429 rm
->m_inc
.i_hdr
.h_len
= cpu_to_be32(iov_iter_count(from
));
431 /* now allocate and copy in the data payload. */
433 sg_off
= 0; /* Dear gcc, sg->page will be null from kzalloc. */
436 return rds_message_zcopy_from_user(rm
, from
);
438 while (iov_iter_count(from
)) {
440 ret
= rds_page_remainder_alloc(sg
, iov_iter_count(from
),
448 to_copy
= min_t(unsigned long, iov_iter_count(from
),
449 sg
->length
- sg_off
);
451 rds_stats_add(s_copy_from_user
, to_copy
);
452 nbytes
= copy_page_from_iter(sg_page(sg
), sg
->offset
+ sg_off
,
454 if (nbytes
!= to_copy
)
459 if (sg_off
== sg
->length
)
466 int rds_message_inc_copy_to_user(struct rds_incoming
*inc
, struct iov_iter
*to
)
468 struct rds_message
*rm
;
469 struct scatterlist
*sg
;
470 unsigned long to_copy
;
471 unsigned long vec_off
;
476 rm
= container_of(inc
, struct rds_message
, m_inc
);
477 len
= be32_to_cpu(rm
->m_inc
.i_hdr
.h_len
);
483 while (iov_iter_count(to
) && copied
< len
) {
484 to_copy
= min_t(unsigned long, iov_iter_count(to
),
485 sg
->length
- vec_off
);
486 to_copy
= min_t(unsigned long, to_copy
, len
- copied
);
488 rds_stats_add(s_copy_to_user
, to_copy
);
489 ret
= copy_page_to_iter(sg_page(sg
), sg
->offset
+ vec_off
,
497 if (vec_off
== sg
->length
) {
507 * If the message is still on the send queue, wait until the transport
508 * is done with it. This is particularly important for RDMA operations.
510 void rds_message_wait(struct rds_message
*rm
)
512 wait_event_interruptible(rm
->m_flush_wait
,
513 !test_bit(RDS_MSG_MAPPED
, &rm
->m_flags
));
516 void rds_message_unmapped(struct rds_message
*rm
)
518 clear_bit(RDS_MSG_MAPPED
, &rm
->m_flags
);
519 wake_up_interruptible(&rm
->m_flush_wait
);
521 EXPORT_SYMBOL_GPL(rds_message_unmapped
);