1 // SPDX-License-Identifier: GPL-2.0-only
5 * Copyright (c) 2017-present, Facebook, Inc.
7 #include <linux/crypto.h>
8 #include <linux/init.h>
9 #include <linux/interrupt.h>
11 #include <linux/module.h>
12 #include <linux/net.h>
13 #include <linux/vmalloc.h>
14 #include <linux/zstd.h>
15 #include <crypto/internal/scompress.h>
18 #define ZSTD_DEF_LEVEL 3
/*
 * zstd_params() - build the zstd parameter set used by this module.
 *
 * Takes no arguments and returns whatever ZSTD_getParams() yields for the
 * compile-time level ZSTD_DEF_LEVEL, with the two remaining arguments
 * passed as 0.
 * NOTE(review): the two 0 arguments presumably mean "source size unknown"
 * and "no dictionary" - confirm against the zstd API in use.
 */
27 static ZSTD_parameters
zstd_params(void)
29 return ZSTD_getParams(ZSTD_DEF_LEVEL
, 0, 0);
/*
 * zstd_comp_init() - set up the compression side of a zstd context.
 * @ctx: context whose cwksp and cctx members are filled in.
 *
 * Sizes a workspace with ZSTD_CCtxWorkspaceBound() for the parameters from
 * zstd_params(), allocates it with vzalloc() into ctx->cwksp, and creates
 * the compression context inside that workspace with ZSTD_initCCtx().
 *
 * Returns an int status.
 * NOTE(review): the checks on the vzalloc()/ZSTD_initCCtx() results and the
 * exact return codes are not shown here - confirm before relying on them.
 */
32 static int zstd_comp_init(struct zstd_ctx
*ctx
)
35 const ZSTD_parameters params
= zstd_params();
/* Workspace size depends only on the compression parameters. */
36 const size_t wksp_size
= ZSTD_CCtxWorkspaceBound(params
.cParams
);
38 ctx
->cwksp
= vzalloc(wksp_size
);
/* The compression context is placed in the workspace allocated above. */
44 ctx
->cctx
= ZSTD_initCCtx(ctx
->cwksp
, wksp_size
);
/*
 * zstd_decomp_init() - set up the decompression side of a zstd context.
 * @ctx: context whose dwksp and dctx members are filled in.
 *
 * Sizes a workspace with ZSTD_DCtxWorkspaceBound() (no parameters needed),
 * allocates it with vzalloc() into ctx->dwksp, and creates the
 * decompression context inside that workspace with ZSTD_initDCtx().
 *
 * Returns an int status.
 * NOTE(review): the checks on the vzalloc()/ZSTD_initDCtx() results and the
 * exact return codes are not shown here - confirm before relying on them.
 */
56 static int zstd_decomp_init(struct zstd_ctx
*ctx
)
59 const size_t wksp_size
= ZSTD_DCtxWorkspaceBound();
61 ctx
->dwksp
= vzalloc(wksp_size
);
/* The decompression context is placed in the workspace allocated above. */
67 ctx
->dctx
= ZSTD_initDCtx(ctx
->dwksp
, wksp_size
);
/*
 * zstd_comp_exit() - tear down the compression side of a zstd context.
 * @ctx: context to clean up.
 *
 * NOTE(review): only the signature is shown here. Presumably this releases
 * the workspace that zstd_comp_init() stored in ctx->cwksp and clears
 * ctx->cctx - confirm against the full body.
 */
79 static void zstd_comp_exit(struct zstd_ctx
*ctx
)
/*
 * zstd_decomp_exit() - tear down the decompression side of a zstd context.
 * @ctx: context to clean up.
 *
 * NOTE(review): only the signature is shown here. Presumably this releases
 * the workspace that zstd_decomp_init() stored in ctx->dwksp and clears
 * ctx->dctx - confirm against the full body.
 */
86 static void zstd_decomp_exit(struct zstd_ctx
*ctx
)
/*
 * __zstd_init() - initialise both halves of a zstd context.
 * @ctx: untyped pointer handed on to zstd_comp_init() and
 *       zstd_decomp_init(), which both take a struct zstd_ctx pointer.
 *
 * Runs the compression setup first and the decompression setup second,
 * capturing each result in ret. Returns an int status.
 * NOTE(review): the handling of a failing step (for example undoing the
 * compression setup when decompression setup fails) is not shown here.
 */
93 static int __zstd_init(void *ctx
)
97 ret
= zstd_comp_init(ctx
);
100 ret
= zstd_decomp_init(ctx
);
/*
 * zstd_alloc_ctx() - allocate and initialise a context for the scomp
 * interface (installed as the .alloc_ctx callback of the scomp algorithm).
 * @tfm: scomp transform; not referenced by the statements shown here.
 *
 * Allocates a struct zstd_ctx with kzalloc(GFP_KERNEL) and initialises it
 * through __zstd_init(). Returns a void pointer; ERR_PTR(-ENOMEM) is one of
 * the possible results.
 * NOTE(review): the conditions guarding the ERR_PTR() return and the
 * handling of a non-zero __zstd_init() result are not shown here.
 */
106 static void *zstd_alloc_ctx(struct crypto_scomp
*tfm
)
109 struct zstd_ctx
*ctx
;
111 ctx
= kzalloc(sizeof(*ctx
), GFP_KERNEL
);
113 return ERR_PTR(-ENOMEM
);
115 ret
= __zstd_init(ctx
);
/*
 * zstd_init() - cra_init callback of the crypto_alg registration.
 * @tfm: transform whose private context is to be initialised.
 *
 * Looks up the struct zstd_ctx embedded in @tfm with crypto_tfm_ctx() and
 * returns the result of __zstd_init() on it unchanged.
 */
124 static int zstd_init(struct crypto_tfm
*tfm
)
126 struct zstd_ctx
*ctx
= crypto_tfm_ctx(tfm
);
128 return __zstd_init(ctx
);
/*
 * __zstd_exit() - release the resources held by a zstd context.
 * @ctx: untyped pointer to the context; passed to zstd_decomp_exit().
 *
 * NOTE(review): only the decompression teardown call is shown here; a
 * matching zstd_comp_exit() call is presumably part of the full body -
 * confirm.
 */
131 static void __zstd_exit(void *ctx
)
134 zstd_decomp_exit(ctx
);
/*
 * zstd_free_ctx() - .free_ctx callback of the scomp algorithm.
 * @tfm: scomp transform; not referenced by the statements shown here.
 * @ctx: context to release.
 *
 * Frees @ctx with kfree_sensitive().
 * NOTE(review): any teardown of the zstd workspaces before the free is not
 * shown here - confirm that __zstd_exit() is called first.
 */
137 static void zstd_free_ctx(struct crypto_scomp
*tfm
, void *ctx
)
140 kfree_sensitive(ctx
);
/*
 * zstd_exit() - cra_exit callback of the crypto_alg registration.
 * @tfm: transform being torn down.
 *
 * Fetches the struct zstd_ctx embedded in @tfm with crypto_tfm_ctx().
 * NOTE(review): what is done with the context afterwards is not shown
 * here; presumably it is handed to __zstd_exit() - confirm.
 */
143 static void zstd_exit(struct crypto_tfm
*tfm
)
145 struct zstd_ctx
*ctx
= crypto_tfm_ctx(tfm
);
/*
 * __zstd_compress() - compress one buffer with the context's CCtx.
 * @src:  input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: read here as the capacity of @dst.
 * @ctx:  untyped pointer to a struct zstd_ctx.
 *
 * Calls ZSTD_compressCCtx() with the parameters from zstd_params() and
 * tests the result with ZSTD_isError(). Returns an int status.
 * NOTE(review): the error code returned on failure and the write-back of
 * the compressed length to *dlen are not shown here - confirm.
 */
150 static int __zstd_compress(const u8
*src
, unsigned int slen
,
151 u8
*dst
, unsigned int *dlen
, void *ctx
)
154 struct zstd_ctx
*zctx
= ctx
;
155 const ZSTD_parameters params
= zstd_params();
/* *dlen is passed in as the destination capacity. */
157 out_len
= ZSTD_compressCCtx(zctx
->cctx
, dst
, *dlen
, src
, slen
, params
);
/* zstd reports failures in-band; ZSTD_isError() detects them. */
158 if (ZSTD_isError(out_len
))
/*
 * zstd_compress() - coa_compress callback of the crypto_alg registration.
 * @tfm:  transform providing the struct zstd_ctx via crypto_tfm_ctx().
 * @src:  input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: pointer to the output length, forwarded unchanged.
 *
 * Thin wrapper: resolves the context from @tfm and returns the result of
 * __zstd_compress().
 */
164 static int zstd_compress(struct crypto_tfm
*tfm
, const u8
*src
,
165 unsigned int slen
, u8
*dst
, unsigned int *dlen
)
167 struct zstd_ctx
*ctx
= crypto_tfm_ctx(tfm
);
169 return __zstd_compress(src
, slen
, dst
, dlen
, ctx
);
/*
 * zstd_scompress() - .compress callback of the scomp algorithm.
 * @tfm:  scomp transform; not referenced by the statements shown here.
 * @src:  input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: pointer to the output length, forwarded unchanged.
 *
 * Forwards every argument, together with ctx, to __zstd_compress() and
 * returns its result.
 * NOTE(review): the parameter that declares ctx is not shown here;
 * presumably it is a trailing "void *ctx" supplied by the scomp core.
 */
172 static int zstd_scompress(struct crypto_scomp
*tfm
, const u8
*src
,
173 unsigned int slen
, u8
*dst
, unsigned int *dlen
,
176 return __zstd_compress(src
, slen
, dst
, dlen
, ctx
);
/*
 * __zstd_decompress() - decompress one buffer with the context's DCtx.
 * @src:  compressed input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: read here as the capacity of @dst.
 * @ctx:  untyped pointer to a struct zstd_ctx.
 *
 * Calls ZSTD_decompressDCtx() and tests the result with ZSTD_isError().
 * Returns an int status.
 * NOTE(review): the error code returned on failure and the write-back of
 * the decompressed length to *dlen are not shown here - confirm.
 */
179 static int __zstd_decompress(const u8
*src
, unsigned int slen
,
180 u8
*dst
, unsigned int *dlen
, void *ctx
)
183 struct zstd_ctx
*zctx
= ctx
;
/* *dlen is passed in as the destination capacity. */
185 out_len
= ZSTD_decompressDCtx(zctx
->dctx
, dst
, *dlen
, src
, slen
);
/* zstd reports failures in-band; ZSTD_isError() detects them. */
186 if (ZSTD_isError(out_len
))
/*
 * zstd_decompress() - coa_decompress callback of the crypto_alg
 * registration.
 * @tfm:  transform providing the struct zstd_ctx via crypto_tfm_ctx().
 * @src:  compressed input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: pointer to the output length, forwarded unchanged.
 *
 * Thin wrapper: resolves the context from @tfm and returns the result of
 * __zstd_decompress().
 */
192 static int zstd_decompress(struct crypto_tfm
*tfm
, const u8
*src
,
193 unsigned int slen
, u8
*dst
, unsigned int *dlen
)
195 struct zstd_ctx
*ctx
= crypto_tfm_ctx(tfm
);
197 return __zstd_decompress(src
, slen
, dst
, dlen
, ctx
);
/*
 * zstd_sdecompress() - .decompress callback of the scomp algorithm.
 * @tfm:  scomp transform; not referenced by the statements shown here.
 * @src:  compressed input bytes.
 * @slen: number of input bytes.
 * @dst:  output buffer.
 * @dlen: pointer to the output length, forwarded unchanged.
 *
 * Forwards every argument, together with ctx, to __zstd_decompress() and
 * returns its result.
 * NOTE(review): the parameter that declares ctx is not shown here;
 * presumably it is a trailing "void *ctx" supplied by the scomp core.
 */
200 static int zstd_sdecompress(struct crypto_scomp
*tfm
, const u8
*src
,
201 unsigned int slen
, u8
*dst
, unsigned int *dlen
,
204 return __zstd_decompress(src
, slen
, dst
, dlen
, ctx
);
/*
 * Registration record for the crypto_alg compression interface, driver
 * name "zstd-generic". The per-transform context is a struct zstd_ctx
 * (cra_ctxsize), set up by zstd_init() and torn down by zstd_exit(), with
 * zstd_compress()/zstd_decompress() as the compress operations.
 * NOTE(review): the algorithm name field and the closing of the
 * initialiser are not shown here - confirm against the full definition.
 */
207 static struct crypto_alg alg
= {
209 .cra_driver_name
= "zstd-generic",
210 .cra_flags
= CRYPTO_ALG_TYPE_COMPRESS
,
/* Context storage sized for one struct zstd_ctx per transform. */
211 .cra_ctxsize
= sizeof(struct zstd_ctx
),
212 .cra_module
= THIS_MODULE
,
213 .cra_init
= zstd_init
,
214 .cra_exit
= zstd_exit
,
215 .cra_u
= { .compress
= {
216 .coa_compress
= zstd_compress
,
217 .coa_decompress
= zstd_decompress
} }
/*
 * Registration record for the scomp interface, driver name "zstd-scomp".
 * Contexts are created by zstd_alloc_ctx() and released by
 * zstd_free_ctx(); zstd_scompress()/zstd_sdecompress() do the work.
 * NOTE(review): the nesting of the cra_* members (presumably inside a
 * .base sub-structure), the algorithm name and the closing of the
 * initialiser are not shown here - confirm against the full definition.
 */
220 static struct scomp_alg scomp
= {
221 .alloc_ctx
= zstd_alloc_ctx
,
222 .free_ctx
= zstd_free_ctx
,
223 .compress
= zstd_scompress
,
224 .decompress
= zstd_sdecompress
,
227 .cra_driver_name
= "zstd-scomp",
228 .cra_module
= THIS_MODULE
,
/*
 * zstd_mod_init() - module entry point: register both algorithm records.
 *
 * Registers alg with crypto_register_alg() first, then scomp with
 * crypto_register_scomp(). A crypto_unregister_alg(&alg) call follows the
 * scomp registration. Returns an int status.
 * NOTE(review): the conditions around these calls are not shown here;
 * presumably the unregister is the rollback taken when the scomp
 * registration fails - confirm.
 */
232 static int __init
zstd_mod_init(void)
236 ret
= crypto_register_alg(&alg
);
240 ret
= crypto_register_scomp(&scomp
);
242 crypto_unregister_alg(&alg
);
/*
 * zstd_mod_fini() - module exit point.
 *
 * Unregisters the two records set up by zstd_mod_init(): alg via
 * crypto_unregister_alg(), then scomp via crypto_unregister_scomp().
 */
247 static void __exit
zstd_mod_fini(void)
249 crypto_unregister_alg(&alg
);
250 crypto_unregister_scomp(&scomp
);
/*
 * Module wiring: zstd_mod_init() is hooked up through subsys_initcall()
 * and zstd_mod_fini() through module_exit(). The remaining macros declare
 * the licence, a description and the crypto alias "zstd".
 */
253 subsys_initcall(zstd_mod_init
);
254 module_exit(zstd_mod_fini
);
256 MODULE_LICENSE("GPL");
257 MODULE_DESCRIPTION("Zstd Compression Algorithm");
258 MODULE_ALIAS_CRYPTO("zstd");