// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Authors: Enzo Matsumiya <ematsumiya@suse.de>
 *
 * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
 * See compress/ for implementation details of each algorithm.
 *
 * References:
 * MS-SMB2 "3.1.4.4 Compressing the Message"
 * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
 * MS-XCA - for details of the supported algorithms
 */
#include <linux/slab.h>
#include <linux/kernel.h>
#include <linux/uio.h>
#include <linux/sort.h>

#include "cifsglob.h"
#include "../common/smb2pdu.h"
#include "cifsproto.h"
#include "smb2proto.h"

#include "compress/lz77.h"
#include "compress.h"
/*
 * The heuristic_*() functions below try to determine data compressibility.
 *
 * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
 * unused parts.
 *
 * Read that file for a better and more detailed explanation of the calculations.
 *
 * The algorithms are run on a collected sample of the input (uncompressed) data.
 * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 4M.
 *
 * Parsing the sample goes from "low-hanging fruit" (fastest algorithms, likely compressible)
 * to "needs more analysis" (likely uncompressible).
 */
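/*
 * In is_compressible(), the checks run on the sample in this order:
 *
 *	1. has_repeated_data()      - identical halves      -> compressible
 *	2. is_mostly_ascii()        - small byte alphabet   -> compressible
 *	3. calc_byte_distribution() - GOOD or BAD           -> final answer
 *	4. has_low_entropy()        - decides the MAYBE case
 */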
/* Byte value histogram of the sample: one bucket per possible byte value. */
struct bucket {
	size_t count;
};

/**
 * has_low_entropy() - Compute Shannon entropy of the sampled data.
 * @bkt: Byte counts of the sample.
 * @slen: Size of the sample.
 *
 * Return: true if the level (percentage of number of bits that would be required to
 *	   compress the data) is below the minimum threshold.
 *
 * Note:
 * There _is_ an entropy level here that's > 65 (the minimum threshold) that would still
 * indicate a possibility of compression, but compressing, or even further analysing, such
 * data would waste so many resources that it's simply not worth it.
 *
 * Also, Shannon entropy is the last heuristic to be computed; if we got this far and still
 * ended up with uncertainty, just stay on the safe side and call the data uncompressible.
 */
static bool has_low_entropy(struct bucket *bkt, size_t slen)
{
	const size_t threshold = 65, max_entropy = 8 * ilog2(16);
	size_t i, p, p2, len, sum = 0;

	/* Logs are scaled by 4 (pow4) for extra integer precision; max_entropy (8 bits/byte) is scaled to match. */
#define pow4(n) (n * n * n * n)
	len = ilog2(pow4(slen));

	for (i = 0; i < 256 && bkt[i].count > 0; i++) {
		p = bkt[i].count;
		p2 = ilog2(pow4(p));
		sum += p * (len - p2);
	}

	sum /= slen;

	return ((sum * 100 / max_entropy) <= threshold);
}
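/*
 * Worked example for has_low_entropy(), with illustrative numbers (not from a real trace):
 * a 4096-byte sample spread evenly over 16 distinct byte values has bkt[i].count == 256 for
 * each of them, so len = ilog2(4096^4) = 48 and p2 = ilog2(256^4) = 32, giving
 * sum = (16 * 256 * (48 - 32)) / 4096 = 16, and 16 * 100 / 32 = 50 <= 65: low entropy,
 * worth compressing.  The same sample spread evenly over all 256 byte values scores 100
 * and is deemed uncompressible.
 */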
#define BYTE_DIST_BAD	0
#define BYTE_DIST_GOOD	1
#define BYTE_DIST_MAYBE	2
/**
 * calc_byte_distribution() - Compute byte distribution on the sampled data.
 * @bkt: Byte counts of the sample, sorted in descending order.
 * @slen: Size of the sample.
 *
 * Return:
 * BYTE_DIST_BAD: A "hard no" for compression -- a computed uniform distribution of
 *		  the bytes (e.g. random or encrypted data).
 * BYTE_DIST_GOOD: High probability (normal (Gaussian) distribution) of the data being
 *		   compressible.
 * BYTE_DIST_MAYBE: When the computed byte distribution landed on "low < n < high"
 *		    grounds; has_low_entropy() should be used for a final decision.
 */
static int calc_byte_distribution(struct bucket *bkt, size_t slen)
{
	const size_t low = 64, high = 200, threshold = slen * 90 / 100;
	size_t sum = 0;
	int i;

	/* @bkt is sorted in descending order, so this covers the most frequent byte values. */
	for (i = 0; i < low; i++)
		sum += bkt[i].count;

	if (sum > threshold)
		return BYTE_DIST_BAD;

	/* Keep widening the set of byte values until it covers 90% of the sample. */
	for (; i < high && bkt[i].count > 0; i++) {
		sum += bkt[i].count;
		if (sum > threshold)
			break;
	}

	if (i <= low)
		return BYTE_DIST_GOOD;

	if (i >= high)
		return BYTE_DIST_BAD;

	return BYTE_DIST_MAYBE;
}
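/*
 * Illustrative outcomes for calc_byte_distribution() (numbers assumed, not measured): on a
 * 4096-byte sample, threshold = 3686.  A uniformly random sample needs roughly 230 of the
 * 256 buckets to cover 90% of it, so the second loop stops at i == high and the result is
 * BYTE_DIST_BAD.  A sample that concentrates 90% of its bytes in ~100 values crosses the
 * threshold with low < i < high and yields BYTE_DIST_MAYBE, deferring to has_low_entropy().
 */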
static bool is_mostly_ascii(const struct bucket *bkt)
{
	size_t count = 0;
	int i;

	for (i = 0; i < 256; i++)
		if (bkt[i].count > 0)
			/* Too many distinct byte values for a plain (mostly ASCII) text sample. */
			if (++count > 64)
				return false;

	return true;
}
static bool has_repeated_data(const u8 *sample, size_t len)
{
	size_t s = len / 2;

	/* Identical halves (e.g. a long run of a repeated pattern) compress trivially. */
	return (!memcmp(&sample[0], &sample[s], s));
}
/* sort() comparator: order the buckets by descending byte count. */
static int cmp_bkt(const void *_a, const void *_b)
{
	const struct bucket *a = _a, *b = _b;

	if (a->count > b->count)
		return -1;

	return 1;
}
/*
 * TODO:
 * Support other iter types, if required.
 * Only ITER_XARRAY is supported for now.
 */
static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
{
	struct folio *folios[16], *folio;
	unsigned int nr, i, j, npages;
	loff_t start = iter->xarray_start + iter->iov_offset;
	pgoff_t last, index = start / PAGE_SIZE;
	size_t len, off, foff;
	void *p;
	int s = 0;

	last = (start + max - 1) / PAGE_SIZE;
	do {
		nr = xa_extract(iter->xarray, (void **)folios, index, last, ARRAY_SIZE(folios),
				XA_PRESENT);
		if (nr == 0)
			return -EIO;

		for (i = 0; i < nr; i++) {
			folio = folios[i];
			npages = folio_nr_pages(folio);
			foff = start - folio_pos(folio);
			off = foff % PAGE_SIZE;

			for (j = foff / PAGE_SIZE; j < npages; j++) {
				size_t len2;

				len = min_t(size_t, max, PAGE_SIZE - off);
				len2 = min_t(size_t, len, SZ_2K);

				/* Copy at most 2K from each page into the sample. */
				p = kmap_local_page(folio_page(folio, j));
				memcpy(&sample[s], p, len2);
				kunmap_local(p);

				s += len2;

				/* Short read, or sample buffer nearly full: we're done. */
				if (len2 < SZ_2K || s >= max - SZ_2K)
					return s;

				max -= len;
				if (max <= 0)
					return s;

				start += len;
				off = 0;
				index++;
			}
		}
	} while (nr == ARRAY_SIZE(folios));

	return s;
}
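/*
 * Sampling layout, as described at the top of this file: from each PAGE_SIZE interval of
 * the input, only the first 2K are copied into @sample, up to the 4M ceiling set by
 * is_compressible():
 *
 *	input:  |2K==....|2K==....|2K==....| ...
 *	sample: |2K==|2K==|2K==| ...
 */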
/**
 * is_compressible() - Determines if a chunk of data is compressible.
 * @data: Iterator containing uncompressed data.
 *
 * Return: true if @data is compressible, false otherwise.
 *
 * Tests show that this function is quite reliable in predicting data compressibility,
 * matching close to 1:1 with the behaviour of LZ77 compression successes and failures.
 */
static bool is_compressible(const struct iov_iter *data)
{
	const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
	struct bucket *bkt = NULL;
	size_t len;
	u8 *sample = NULL;
	bool ret = false;
	int i;

	/* Preventive double check -- already checked in should_compress(). */
	len = iov_iter_count(data);
	if (unlikely(len < read_size))
		return false;

	if (len - read_size > max)
		len = max;

	sample = kvzalloc(len, GFP_KERNEL);
	if (!sample)
		return false;

	/* Sample 2K bytes per page of the uncompressed data. */
	i = collect_sample(data, len, sample);
	if (i <= 0)
		goto out;

	len = i;
	/* Assume the data is compressible unless one of the checks below says otherwise. */
	ret = true;

	if (has_repeated_data(sample, len))
		goto out;

	bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
	if (!bkt) {
		ret = false;
		goto out;
	}

	for (i = 0; i < len; i++)
		bkt[sample[i]].count++;

	if (is_mostly_ascii(bkt))
		goto out;

	/* Sort in descending order */
	sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);

	i = calc_byte_distribution(bkt, len);
	if (i != BYTE_DIST_MAYBE) {
		ret = !!i;
		goto out;
	}

	ret = has_low_entropy(bkt, len);
out:
	kvfree(sample);
	kfree(bkt);

	return ret;
}
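/*
 * Illustrative walk-throughs of the cascade above (assumed inputs, not measured data):
 * an all-zero chunk exits early at has_repeated_data() as compressible; an encrypted blob
 * reaches calc_byte_distribution(), which reports BYTE_DIST_BAD; plain text typically has
 * more than 64 distinct byte values but a skewed distribution, so it falls through to
 * has_low_entropy() for the final call.
 */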
bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
{
	const struct smb2_hdr *shdr = rq->rq_iov->iov_base;

	if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
		return false;

	if (!tcon->ses->server->compression.enabled)
		return false;

	if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
		return false;

	if (shdr->Command == SMB2_WRITE) {
		const struct smb2_write_req *wreq = rq->rq_iov->iov_base;

		/* Compressing very small writes costs more than it saves. */
		if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
			return false;

		return is_compressible(&rq->rq_iter);
	}

	/* For reads, "compressing" means asking the server for a compressed response. */
	return (shdr->Command == SMB2_READ);
}
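/*
 * Usage sketch (hypothetical caller; the flag name below is assumed for illustration, and
 * the real call sites live in the read/write request paths):
 *
 *	if (should_compress(tcon, &rqst))
 *		flags |= CIFS_COMPRESS_REQ;
 *
 * i.e. the decision is made per request, and SMB2_WRITE payloads are additionally sampled
 * through is_compressible() before committing to compression.
 */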
int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
{
	struct iov_iter iter;
	u32 slen, dlen;
	void *src, *dst = NULL;
	int ret;

	if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
		return -EINVAL;

	if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
		return -EINVAL;

	slen = iov_iter_count(&rq->rq_iter);
	src = kvzalloc(slen, GFP_KERNEL);
	if (!src)
		return -ENOMEM;

	/* Keep the original iter intact. */
	iter = rq->rq_iter;

	if (!copy_from_iter_full(src, slen, &iter)) {
		ret = -EIO;
		goto err_free;
	}

	/*
	 * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
	 * of @src.  That guarantees the compressed data is at least 1/8 smaller than the
	 * original, e.g. a 64K write is sent compressed only if it shrinks below 56K;
	 * anything less isn't worth the effort.
	 */
	dlen = slen;
	dst = kvzalloc(dlen, GFP_KERNEL);
	if (!dst) {
		ret = -ENOMEM;
		goto err_free;
	}

	ret = lz77_compress(src, slen, dst, &dlen);
	if (!ret) {
		struct smb2_compression_hdr hdr = { 0 };
		struct smb_rqst comp_rq = { .rq_nvec = 3, };
		struct kvec iov[3];

		hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
		hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
		hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
		hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
		hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);

		iov[0].iov_base = &hdr;
		iov[0].iov_len = sizeof(hdr);
		iov[1] = rq->rq_iov[0];
		iov[2].iov_base = dst;
		iov[2].iov_len = dlen;

		comp_rq.rq_iov = iov;

		ret = send_fn(server, 1, &comp_rq);
	} else if (ret == -EMSGSIZE || dlen >= slen) {
		/* Compression failed or didn't pay off: send the original, uncompressed request. */
		ret = send_fn(server, 1, rq);
	}
err_free:
	kvfree(dst);
	kvfree(src);

	return ret;
}
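/*
 * Usage sketch for smb_compress() (illustrative only; "my_send" is a hypothetical callback,
 * not a function in this codebase):
 *
 *	static int my_send(struct TCP_Server_Info *server, int num_rqst,
 *			   struct smb_rqst *rqst)
 *	{
 *		// transmit @rqst over the socket
 *	}
 *
 *	ret = smb_compress(server, rq, my_send);
 *
 * smb_compress() then calls @send_fn either with the 3-vector compressed request
 * (transform header, original SMB2 WRITE header, compressed payload) or, when compression
 * isn't worth it, with the original @rq.
 */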