diff --git a/fs/smb/common/smbdirect/smbdirect_connection.c b/fs/smb/common/smbdirect/smbdirect_connection.c index 573dc278ca71..8290e45464e3 100644 --- a/fs/smb/common/smbdirect/smbdirect_connection.c +++ b/fs/smb/common/smbdirect/smbdirect_connection.c @@ -6,6 +6,113 @@ #include "smbdirect_internal.h" +static void smbdirect_connection_destroy_mem_pools(struct smbdirect_socket *sc); + +__maybe_unused /* this is temporary while this file is included in others */ +static int smbdirect_connection_create_mem_pools(struct smbdirect_socket *sc) +{ + const struct smbdirect_socket_parameters *sp = &sc->parameters; + char name[80]; + size_t i; + + /* + * We use sizeof(struct smbdirect_negotiate_resp) for the + * payload size as it is larger as + * sizeof(struct smbdirect_data_transfer). + * + * This will fit client and server usage for now. + */ + snprintf(name, sizeof(name), "smbdirect_send_io_cache_%p", sc); + struct kmem_cache_args send_io_args = { + .align = __alignof__(struct smbdirect_send_io), + }; + sc->send_io.mem.cache = kmem_cache_create(name, + sizeof(struct smbdirect_send_io) + + sizeof(struct smbdirect_negotiate_resp), + &send_io_args, + SLAB_HWCACHE_ALIGN); + if (!sc->send_io.mem.cache) + goto err; + + sc->send_io.mem.pool = mempool_create_slab_pool(sp->send_credit_target, + sc->send_io.mem.cache); + if (!sc->send_io.mem.pool) + goto err; + + /* + * A payload size of sp->max_recv_size should fit + * any message. + * + * For smbdirect_data_transfer messages the whole + * buffer might be exposed to userspace + * (currently on the client side...) + * The documentation says data_offset = 0 would be + * strange but valid. + */ + snprintf(name, sizeof(name), "smbdirect_recv_io_cache_%p", sc); + struct kmem_cache_args recv_io_args = { + .align = __alignof__(struct smbdirect_recv_io), + .useroffset = sizeof(struct smbdirect_recv_io), + .usersize = sp->max_recv_size, + }; + sc->recv_io.mem.cache = kmem_cache_create(name, + sizeof(struct smbdirect_recv_io) + + sp->max_recv_size, + &recv_io_args, + SLAB_HWCACHE_ALIGN); + if (!sc->recv_io.mem.cache) + goto err; + + sc->recv_io.mem.pool = mempool_create_slab_pool(sp->recv_credit_max, + sc->recv_io.mem.cache); + if (!sc->recv_io.mem.pool) + goto err; + + for (i = 0; i < sp->recv_credit_max; i++) { + struct smbdirect_recv_io *recv_io; + + recv_io = mempool_alloc(sc->recv_io.mem.pool, + sc->recv_io.mem.gfp_mask); + if (!recv_io) + goto err; + recv_io->socket = sc; + recv_io->sge.length = 0; + list_add_tail(&recv_io->list, &sc->recv_io.free.list); + } + + return 0; +err: + smbdirect_connection_destroy_mem_pools(sc); + return -ENOMEM; +} + +static void smbdirect_connection_destroy_mem_pools(struct smbdirect_socket *sc) +{ + struct smbdirect_recv_io *recv_io, *next_io; + + list_for_each_entry_safe(recv_io, next_io, &sc->recv_io.free.list, list) { + list_del(&recv_io->list); + mempool_free(recv_io, sc->recv_io.mem.pool); + } + + /* + * Note mempool_destroy() and kmem_cache_destroy() + * work fine with a NULL pointer + */ + + mempool_destroy(sc->recv_io.mem.pool); + sc->recv_io.mem.pool = NULL; + + kmem_cache_destroy(sc->recv_io.mem.cache); + sc->recv_io.mem.cache = NULL; + + mempool_destroy(sc->send_io.mem.pool); + sc->send_io.mem.pool = NULL; + + kmem_cache_destroy(sc->send_io.mem.cache); + sc->send_io.mem.cache = NULL; +} + __maybe_unused /* this is temporary while this file is included in others */ static struct smbdirect_send_io *smbdirect_connection_alloc_send_io(struct smbdirect_socket *sc) {