1 /* SPDX-License-Identifier: GPL-2.0 */
2 #ifndef MEAN_AND_VARIANCE_H_
3 #define MEAN_AND_VARIANCE_H_
5 #include <linux/types.h>
6 #include <linux/limits.h>
7 #include <linux/math.h>
8 #include <linux/math64.h>
10 #define SQRT_U64_MAX 4294967295ULL
13 * u128_u: u128 user mode, because not all architectures support a real int128
16 * We don't use this version in userspace, because in userspace we link with
17 * Rust and rustc has issues with u128.
20 #if defined(__SIZEOF_INT128__) && defined(__KERNEL__) && !defined(CONFIG_PARISC)
24 } __aligned(16) u128_u
;
26 static inline u128_u
u64_to_u128(u64 a
)
28 return (u128_u
) { .v
= a
};
31 static inline u64
u128_lo(u128_u a
)
36 static inline u64
u128_hi(u128_u a
)
41 static inline u128_u
u128_add(u128_u a
, u128_u b
)
47 static inline u128_u
u128_sub(u128_u a
, u128_u b
)
53 static inline u128_u
u128_shl(u128_u a
, s8 shift
)
59 static inline u128_u
u128_square(u64 a
)
61 u128_u b
= u64_to_u128(a
);
71 } __aligned(16) u128_u
;
75 static inline u128_u
u64_to_u128(u64 a
)
77 return (u128_u
) { .lo
= a
};
80 static inline u64
u128_lo(u128_u a
)
85 static inline u64
u128_hi(u128_u a
)
92 static inline u128_u
u128_add(u128_u a
, u128_u b
)
97 c
.hi
= a
.hi
+ b
.hi
+ (c
.lo
< a
.lo
);
101 static inline u128_u
u128_sub(u128_u a
, u128_u b
)
106 c
.hi
= a
.hi
- b
.hi
- (c
.lo
> a
.lo
);
110 static inline u128_u
u128_shl(u128_u i
, s8 shift
)
114 r
.lo
= i
.lo
<< (shift
& 63);
116 r
.hi
= (i
.hi
<< (shift
& 63)) | (i
.lo
>> (-shift
& 63));
118 r
.hi
= i
.lo
<< (-shift
& 63);
124 static inline u128_u
u128_square(u64 i
)
127 u64 h
= i
>> 32, l
= i
& U32_MAX
;
129 r
= u128_shl(u64_to_u128(h
*h
), 64);
130 r
= u128_add(r
, u128_shl(u64_to_u128(h
*l
), 32));
131 r
= u128_add(r
, u128_shl(u64_to_u128(l
*h
), 32));
132 r
= u128_add(r
, u64_to_u128(l
*l
));
138 static inline u128_u
u64s_to_u128(u64 hi
, u64 lo
)
140 u128_u c
= u64_to_u128(hi
);
143 c
= u128_add(c
, u64_to_u128(lo
));
147 u128_u
u128_div(u128_u n
, u64 d
);
149 struct mean_and_variance
{
155 /* exponentially weighted variant */
156 struct mean_and_variance_weighted
{
162 * fast_divpow2() - fast approximation for n / (1 << d)
164 * @d: the power of 2 denominator.
166 * note: this rounds towards 0.
168 static inline s64
fast_divpow2(s64 n
, u8 d
)
170 return (n
+ ((n
< 0) ? ((1 << d
) - 1) : 0)) >> d
;
174 * mean_and_variance_update() - update a mean_and_variance struct @s with a new sample @v
176 * @s: the mean_and_variance to update.
177 * @v: the new sample.
179 * see linked pdf equation 12.
182 mean_and_variance_update(struct mean_and_variance
*s
, s64 v
)
186 s
->sum_squares
= u128_add(s
->sum_squares
, u128_square(abs(v
)));
189 s64
mean_and_variance_get_mean(struct mean_and_variance s
);
190 u64
mean_and_variance_get_variance(struct mean_and_variance s1
);
191 u32
mean_and_variance_get_stddev(struct mean_and_variance s
);
193 void mean_and_variance_weighted_update(struct mean_and_variance_weighted
*s
,
194 s64 v
, bool initted
, u8 weight
);
196 s64
mean_and_variance_weighted_get_mean(struct mean_and_variance_weighted s
,
198 u64
mean_and_variance_weighted_get_variance(struct mean_and_variance_weighted s
,
200 u32
mean_and_variance_weighted_get_stddev(struct mean_and_variance_weighted s
,
203 #endif // MEAN_AND_VARIANCE_H_