// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Authors:	Mina Almasry <almasrymina@google.com>
 *		Willem de Bruijn <willemdebruijn.kernel@gmail.com>
 *		Kaiyuan Zhang <kaiyuanz@google.com>
 */
#include <linux/dma-buf.h>
#include <linux/genalloc.h>
#include <linux/mm.h>
#include <linux/netdevice.h>
#include <linux/types.h>
#include <net/netdev_queues.h>
#include <net/netdev_rx_queue.h>
#include <net/page_pool/helpers.h>
#include <trace/events/page_pool.h>

#include "devmem.h"
#include "mp_dmabuf_devmem.h"
#include "page_pool_priv.h"
/* Device memory support */
/* Protected by rtnl_lock() */
static DEFINE_XARRAY_FLAGS(net_devmem_dmabuf_bindings, XA_FLAGS_ALLOC1);
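
/* gen_pool chunk callback: tear down the per-chunk owner struct and the
 * net_iov array it carries. Used both when a binding is freed and on the
 * bind error path.
 */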
static void net_devmem_dmabuf_free_chunk_owner(struct gen_pool *genpool,
					       struct gen_pool_chunk *chunk,
					       void *not_used)
{
	struct dmabuf_genpool_chunk_owner *owner = chunk->owner;

	kvfree(owner->niovs);
	kfree(owner);
}
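
/* DMA address of a net_iov: the owning chunk's base DMA address plus the
 * net_iov's page-sized index within that chunk.
 */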
static dma_addr_t net_devmem_get_dma_addr(const struct net_iov *niov)
{
	struct dmabuf_genpool_chunk_owner *owner = net_iov_owner(niov);

	return owner->base_dma_addr +
	       ((dma_addr_t)net_iov_idx(niov) << PAGE_SHIFT);
}
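
/* Final teardown of a binding: free every chunk owner, destroy the genpool
 * (only when all allocations have been returned to it), unmap and detach
 * the dma-buf, drop the dma-buf reference and destroy the rxq xarray.
 */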
void __net_devmem_dmabuf_binding_free(struct net_devmem_dmabuf_binding *binding)
{
	size_t size, avail;

	gen_pool_for_each_chunk(binding->chunk_pool,
				net_devmem_dmabuf_free_chunk_owner, NULL);

	size = gen_pool_size(binding->chunk_pool);
	avail = gen_pool_avail(binding->chunk_pool);

	if (!WARN(size != avail, "can't destroy genpool. size=%zu, avail=%zu",
		  size, avail))
		gen_pool_destroy(binding->chunk_pool);

	dma_buf_unmap_attachment_unlocked(binding->attachment, binding->sgt,
					  DMA_FROM_DEVICE);
	dma_buf_detach(binding->dmabuf, binding->attachment);
	dma_buf_put(binding->dmabuf);
	xa_destroy(&binding->bound_rxqs);
	kfree(binding);
}
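
/* Allocate one PAGE_SIZE region of device memory from the binding's genpool
 * and return the net_iov that covers it, with its page_pool refcount reset.
 * Returns NULL when the pool is exhausted.
 */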
struct net_iov *
net_devmem_alloc_dmabuf(struct net_devmem_dmabuf_binding *binding)
{
	struct dmabuf_genpool_chunk_owner *owner;
	unsigned long dma_addr;
	struct net_iov *niov;
	ssize_t offset;
	ssize_t index;

	dma_addr = gen_pool_alloc_owner(binding->chunk_pool, PAGE_SIZE,
					(void **)&owner);
	if (!dma_addr)
		return NULL;

	offset = dma_addr - owner->base_dma_addr;
	index = offset / PAGE_SIZE;
	niov = &owner->niovs[index];

	atomic_long_set(&niov->pp_ref_count, 0);

	return niov;
}
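
/* Return a net_iov's PAGE_SIZE region to the binding's genpool. The
 * gen_pool_has_addr() check catches attempts to free an address the pool
 * never handed out.
 */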
void net_devmem_free_dmabuf(struct net_iov *niov)
{
	struct net_devmem_dmabuf_binding *binding = net_iov_binding(niov);
	unsigned long dma_addr = net_devmem_get_dma_addr(niov);

	if (WARN_ON(!gen_pool_has_addr(binding->chunk_pool, dma_addr,
				       PAGE_SIZE)))
		return;

	gen_pool_free(binding->chunk_pool, dma_addr, PAGE_SIZE);
}
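
/* Undo a binding: unlink it from the list it may be queued on, detach the
 * memory provider from every RX queue recorded in bound_rxqs (restarting
 * each queue so the driver stops using it), then release the binding's ID
 * and its reference.
 */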
void net_devmem_unbind_dmabuf(struct net_devmem_dmabuf_binding *binding)
{
	struct netdev_rx_queue *rxq;
	unsigned long xa_idx;
	unsigned int rxq_idx;

	if (binding->list.next)
		list_del(&binding->list);

	xa_for_each(&binding->bound_rxqs, xa_idx, rxq) {
		WARN_ON(rxq->mp_params.mp_priv != binding);

		rxq->mp_params.mp_priv = NULL;

		rxq_idx = get_netdev_rx_queue_index(rxq);

		WARN_ON(netdev_rx_queue_restart(binding->dev, rxq_idx));
	}

	xa_erase(&net_devmem_dmabuf_bindings, binding->id);

	net_devmem_dmabuf_binding_put(binding);
}
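
/* Attach an existing binding to one RX queue: validate the queue index,
 * refuse queues that already have a memory provider (or an AF_XDP pool),
 * record the queue in bound_rxqs and restart it so the driver picks up the
 * new provider. On restart failure the queue is rolled back to unbound.
 */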
int net_devmem_bind_dmabuf_to_queue(struct net_device *dev, u32 rxq_idx,
				    struct net_devmem_dmabuf_binding *binding,
				    struct netlink_ext_ack *extack)
{
	struct netdev_rx_queue *rxq;
	u32 xa_idx;
	int err;

	if (rxq_idx >= dev->real_num_rx_queues) {
		NL_SET_ERR_MSG(extack, "rx queue index out of range");
		return -ERANGE;
	}

	rxq = __netif_get_rx_queue(dev, rxq_idx);
	if (rxq->mp_params.mp_priv) {
		NL_SET_ERR_MSG(extack, "designated queue already memory provider bound");
		return -EEXIST;
	}

#ifdef CONFIG_XDP_SOCKETS
	if (rxq->pool) {
		NL_SET_ERR_MSG(extack, "designated queue already in use by AF_XDP");
		return -EBUSY;
	}
#endif

	err = xa_alloc(&binding->bound_rxqs, &xa_idx, rxq, xa_limit_32b,
		       GFP_KERNEL);
	if (err)
		return err;

	rxq->mp_params.mp_priv = binding;

	err = netdev_rx_queue_restart(dev, rxq_idx);
	if (err)
		goto err_xa_erase;

	return 0;

err_xa_erase:
	rxq->mp_params.mp_priv = NULL;
	xa_erase(&binding->bound_rxqs, xa_idx);

	return err;
}
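
/* Create a binding from a dma-buf fd: attach and map the dma-buf for the
 * device, carve the resulting scatterlist into a genpool of PAGE_SIZE
 * chunks, and precompute a net_iov (with its DMA address) for every page
 * so the memory provider can hand them out without further setup.
 */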
struct net_devmem_dmabuf_binding *
net_devmem_bind_dmabuf(struct net_device *dev, unsigned int dmabuf_fd,
		       struct netlink_ext_ack *extack)
{
	struct net_devmem_dmabuf_binding *binding;
	static u32 id_alloc_next;
	struct scatterlist *sg;
	struct dma_buf *dmabuf;
	unsigned int sg_idx, i;
	unsigned long virtual;
	int err;

	dmabuf = dma_buf_get(dmabuf_fd);
	if (IS_ERR(dmabuf))
		return ERR_CAST(dmabuf);

	binding = kzalloc_node(sizeof(*binding), GFP_KERNEL,
			       dev_to_node(&dev->dev));
	if (!binding) {
		err = -ENOMEM;
		goto err_put_dmabuf;
	}

	binding->dev = dev;

	err = xa_alloc_cyclic(&net_devmem_dmabuf_bindings, &binding->id,
			      binding, xa_limit_32b, &id_alloc_next,
			      GFP_KERNEL);
	if (err < 0)
		goto err_free_binding;

	xa_init_flags(&binding->bound_rxqs, XA_FLAGS_ALLOC);

	refcount_set(&binding->ref, 1);

	binding->dmabuf = dmabuf;

	binding->attachment = dma_buf_attach(binding->dmabuf, dev->dev.parent);
	if (IS_ERR(binding->attachment)) {
		err = PTR_ERR(binding->attachment);
		NL_SET_ERR_MSG(extack, "Failed to bind dmabuf to device");
		goto err_free_id;
	}

	binding->sgt = dma_buf_map_attachment_unlocked(binding->attachment,
						       DMA_FROM_DEVICE);
	if (IS_ERR(binding->sgt)) {
		err = PTR_ERR(binding->sgt);
		NL_SET_ERR_MSG(extack, "Failed to map dmabuf attachment");
		goto err_detach;
	}

	/* For simplicity we expect to make PAGE_SIZE allocations, but the
	 * binding can be much more flexible than that. We may be able to
	 * allocate MTU sized chunks here. Leave that for future work...
	 */
	binding->chunk_pool =
		gen_pool_create(PAGE_SHIFT, dev_to_node(&dev->dev));
	if (!binding->chunk_pool) {
		err = -ENOMEM;
		goto err_unmap;
	}

	virtual = 0;
	for_each_sgtable_dma_sg(binding->sgt, sg, sg_idx) {
		dma_addr_t dma_addr = sg_dma_address(sg);
		struct dmabuf_genpool_chunk_owner *owner;
		size_t len = sg_dma_len(sg);
		struct net_iov *niov;

		owner = kzalloc_node(sizeof(*owner), GFP_KERNEL,
				     dev_to_node(&dev->dev));
		if (!owner) {
			err = -ENOMEM;
			goto err_free_chunks;
		}

		owner->base_virtual = virtual;
		owner->base_dma_addr = dma_addr;
		owner->num_niovs = len / PAGE_SIZE;
		owner->binding = binding;

		err = gen_pool_add_owner(binding->chunk_pool, dma_addr,
					 dma_addr, len, dev_to_node(&dev->dev),
					 owner);
		if (err) {
			kfree(owner);
			goto err_free_chunks;
		}

		owner->niovs = kvmalloc_array(owner->num_niovs,
					      sizeof(*owner->niovs),
					      GFP_KERNEL);
		if (!owner->niovs) {
			err = -ENOMEM;
			goto err_free_chunks;
		}

		for (i = 0; i < owner->num_niovs; i++) {
			niov = &owner->niovs[i];
			niov->owner = owner;
			page_pool_set_dma_addr_netmem(net_iov_to_netmem(niov),
						      net_devmem_get_dma_addr(niov));
		}

		virtual += len;
	}

	return binding;

err_free_chunks:
	gen_pool_for_each_chunk(binding->chunk_pool,
				net_devmem_dmabuf_free_chunk_owner, NULL);
	gen_pool_destroy(binding->chunk_pool);
err_unmap:
	dma_buf_unmap_attachment_unlocked(binding->attachment, binding->sgt,
					  DMA_FROM_DEVICE);
err_detach:
	dma_buf_detach(dmabuf, binding->attachment);
err_free_id:
	xa_erase(&net_devmem_dmabuf_bindings, binding->id);
err_free_binding:
	kfree(binding);
err_put_dmabuf:
	dma_buf_put(dmabuf);
	return ERR_PTR(err);
}
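
/* For every RX queue of @dev that still has a memory provider bound, drop
 * that queue from its binding's bound_rxqs xarray so the binding no longer
 * references the device's queues.
 */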
void dev_dmabuf_uninstall(struct net_device *dev)
{
	struct net_devmem_dmabuf_binding *binding;
	struct netdev_rx_queue *rxq;
	unsigned long xa_idx;
	unsigned int i;

	for (i = 0; i < dev->real_num_rx_queues; i++) {
		binding = dev->_rx[i].mp_params.mp_priv;
		if (!binding)
			continue;

		xa_for_each(&binding->bound_rxqs, xa_idx, rxq)
			if (rxq == &dev->_rx[i]) {
				xa_erase(&binding->bound_rxqs, xa_idx);
				break;
			}
	}
}
/*** "Dmabuf devmem memory provider" ***/
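
/* page_pool memory provider hooks. pool->mp_priv carries the dma-buf
 * binding; the provider hands out net_iovs backed by device memory instead
 * of pages. init pins the binding for the lifetime of the pool and rejects
 * higher-order pools, which PAGE_SIZE chunk allocations cannot satisfy.
 */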
int mp_dmabuf_devmem_init(struct page_pool *pool)
{
	struct net_devmem_dmabuf_binding *binding = pool->mp_priv;

	if (!binding)
		return -EINVAL;

	if (pool->p.order != 0)
		return -E2BIG;

	net_devmem_dmabuf_binding_get(binding);
	return 0;
}
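
/* Allocate a net_iov from the binding and hand it to the page_pool core as
 * a netmem_ref, accounted like any other allocation (pp info, hold count,
 * tracepoint). Returns 0 when the binding has no free device memory.
 */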
netmem_ref mp_dmabuf_devmem_alloc_netmems(struct page_pool *pool, gfp_t gfp)
{
	struct net_devmem_dmabuf_binding *binding = pool->mp_priv;
	struct net_iov *niov;
	netmem_ref netmem;

	niov = net_devmem_alloc_dmabuf(binding);
	if (!niov)
		return 0;

	netmem = net_iov_to_netmem(niov);

	page_pool_set_pp_info(pool, netmem);

	pool->pages_state_hold_cnt++;
	trace_page_pool_state_hold(pool, netmem, pool->pages_state_hold_cnt);
	return netmem;
}
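
/* Drop the reference on the binding taken in mp_dmabuf_devmem_init() once
 * the page_pool is torn down.
 */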
void mp_dmabuf_devmem_destroy(struct page_pool *pool)
{
	struct net_devmem_dmabuf_binding *binding = pool->mp_priv;

	net_devmem_dmabuf_binding_put(binding);
}
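
/* Release hook: verify the netmem really is a net_iov with exactly one
 * outstanding reference, clear its page_pool state and return the chunk to
 * the binding. Always reports false so the page_pool core never treats the
 * net_iov as a page it should put_page().
 */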
bool mp_dmabuf_devmem_release_page(struct page_pool *pool, netmem_ref netmem)
{
	long refcount = atomic_long_read(netmem_get_pp_ref_count_ref(netmem));

	if (WARN_ON_ONCE(!netmem_is_net_iov(netmem)))
		return false;

	if (WARN_ON_ONCE(refcount != 1))
		return false;

	page_pool_clear_pp_info(netmem);

	net_devmem_free_dmabuf(netmem_to_net_iov(netmem));

	/* We don't want the page pool put_page()ing our net_iovs. */
	return false;
}