diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index 4abc359b3db0..ec9c310cf5d7 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -286,8 +286,7 @@ struct tnum tnum_bswap64(struct tnum a) */ u64 tnum_step(struct tnum t, u64 z) { - u64 tmax, j, p, q, r, s, v, u, w, res; - u8 k; + u64 tmax, d, carry_mask, filled, inc; tmax = t.value | t.mask; @@ -299,29 +298,22 @@ u64 tnum_step(struct tnum t, u64 z) if (z < t.value) return t.value; - /* keep t's known bits, and match all unknown bits to z */ - j = t.value | (z & t.mask); - - if (j > z) { - p = ~z & t.value & ~t.mask; - k = fls64(p); /* k is the most-significant 0-to-1 flip */ - q = U64_MAX << k; - r = q & z; /* positions > k matched to z */ - s = ~q & t.value; /* positions <= k matched to t.value */ - v = r | s; - res = v; - } else { - p = z & ~t.value & ~t.mask; - k = fls64(p); /* k is the most-significant 1-to-0 flip */ - q = U64_MAX << k; - r = q & t.mask & z; /* unknown positions > k, matched to z */ - s = q & ~t.mask; /* known positions > k, set to 1 */ - v = r | s; - /* add 1 to unknown positions > k to make value greater than z */ - u = v + (1ULL << k); - /* extract bits in unknown positions > k from u, rest from t.value */ - w = (u & t.mask) | t.value; - res = w; - } - return res; + /* + * Let r be the result tnum member, z = t.value + d. + * Every tnum member is t.value | s for some submask s of t.mask, + * and since t.value & t.mask == 0, t.value | s == t.value + s. + * So r > z becomes s > d where d = z - t.value. + * + * Find the smallest submask s of t.mask greater than d by + * "incrementing d within the mask": fill every non-mask + * position with 1 (`filled`) so +1 ripples through the gaps, + * then keep only mask bits. `carry_mask` additionally fills + * positions below the highest non-mask 1 in d, preventing + * it from trapping the carry. + */ + d = z - t.value; + carry_mask = (1ULL << fls64(d & ~t.mask)) - 1; + filled = d | carry_mask | ~t.mask; + inc = (filled + 1) & t.mask; + return t.value | inc; }