Merge branch 'bpf-recognize-special-arithmetic-shift-in-the-verifier'

Puranjay Mohan says:

====================
bpf: Recognize special arithmetic shift in the verifier

v3: https://lore.kernel.org/all/20260103022310.935686-1-puranjay@kernel.org/
Changes in v3->v4:
- Fork verifier state while processing BPF_OR when src_reg has [-1,0]
  range and 2nd operand is a constant. This is to detect the following pattern:
	i32 X > -1 ? C1 : -1 --> (X >>s 31) | C1
- Add selftests for above.
- Remove __description("s>>=63") (Eduard in another patchset)

v2: https://lore.kernel.org/bpf/20251115022611.64898-1-alexei.starovoitov@gmail.com/
Changes in v2->v3:
- fork verifier state while processing BPF_AND when src_reg has [-1,0]
  range and 2nd operand is a constant.

v1->v2:
Use __mark_reg32_known() or __mark_reg_known() for zero too.
Add comment to selftest.

v1:
https://lore.kernel.org/bpf/20251114031039.63852-1-alexei.starovoitov@gmail.com/
====================

Link: https://patch.msgid.link/20260112201424.816836-1-puranjay@kernel.org
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
This commit is contained in:
Alexei Starovoitov 2026-01-13 09:33:38 -08:00
commit da4ab5dcc9
2 changed files with 124 additions and 0 deletions

View File

@ -15491,6 +15491,35 @@ static bool is_safe_to_compute_dst_reg_range(struct bpf_insn *insn,
}
}
static int maybe_fork_scalars(struct bpf_verifier_env *env, struct bpf_insn *insn,
struct bpf_reg_state *dst_reg)
{
struct bpf_verifier_state *branch;
struct bpf_reg_state *regs;
bool alu32;
if (dst_reg->smin_value == -1 && dst_reg->smax_value == 0)
alu32 = false;
else if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0)
alu32 = true;
else
return 0;
branch = push_stack(env, env->insn_idx + 1, env->insn_idx, false);
if (IS_ERR(branch))
return PTR_ERR(branch);
regs = branch->frame[branch->curframe]->regs;
if (alu32) {
__mark_reg32_known(&regs[insn->dst_reg], 0);
__mark_reg32_known(dst_reg, -1ull);
} else {
__mark_reg_known(&regs[insn->dst_reg], 0);
__mark_reg_known(dst_reg, -1ull);
}
return 0;
}
/* WARNING: This function does calculations on 64-bit values, but the actual
* execution may occur on 32-bit values. Therefore, things like bitshifts
* need extra checks in the 32-bit case.
@ -15553,11 +15582,21 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env,
scalar_min_max_mul(dst_reg, &src_reg);
break;
case BPF_AND:
if (tnum_is_const(src_reg.var_off)) {
ret = maybe_fork_scalars(env, insn, dst_reg);
if (ret)
return ret;
}
dst_reg->var_off = tnum_and(dst_reg->var_off, src_reg.var_off);
scalar32_min_max_and(dst_reg, &src_reg);
scalar_min_max_and(dst_reg, &src_reg);
break;
case BPF_OR:
if (tnum_is_const(src_reg.var_off)) {
ret = maybe_fork_scalars(env, insn, dst_reg);
if (ret)
return ret;
}
dst_reg->var_off = tnum_or(dst_reg->var_off, src_reg.var_off);
scalar32_min_max_or(dst_reg, &src_reg);
scalar_min_max_or(dst_reg, &src_reg);

View File

@ -738,4 +738,89 @@ __naked void ldx_w_zero_extend_check(void)
: __clobber_all);
}
SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_31_and(void)
{
/* Below is what LLVM generates in cilium's bpf_wiregard.o */
asm volatile (" \
call %[bpf_get_prandom_u32]; \
w2 = w0; \
w2 s>>= 31; \
w2 &= -134; /* w2 becomes 0 or -134 */ \
if w2 s> -1 goto +2; \
/* Branch always taken because w2 = -134 */ \
if w2 != -136 goto +1; \
w0 /= 0; \
w0 = 0; \
exit; \
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}
SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_63_and(void)
{
/* Copy of arsh_31 with s/w/r/ */
asm volatile (" \
call %[bpf_get_prandom_u32]; \
r2 = r0; \
r2 <<= 32; \
r2 s>>= 63; \
r2 &= -134; \
if r2 s> -1 goto +2; \
/* Branch always taken because w2 = -134 */ \
if r2 != -136 goto +1; \
r0 /= 0; \
r0 = 0; \
exit; \
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}
SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_31_or(void)
{
asm volatile (" \
call %[bpf_get_prandom_u32]; \
w2 = w0; \
w2 s>>= 31; \
w2 |= 134; /* w2 becomes -1 or 134 */ \
if w2 s> -1 goto +2; \
/* Branch always taken because w2 = -1 */ \
if w2 == -1 goto +1; \
w0 /= 0; \
w0 = 0; \
exit; \
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}
SEC("socket")
__success __success_unpriv __retval(0)
__naked void arsh_63_or(void)
{
/* Copy of arsh_31 with s/w/r/ */
asm volatile (" \
call %[bpf_get_prandom_u32]; \
r2 = r0; \
r2 <<= 32; \
r2 s>>= 63; \
r2 |= 134; /* r2 becomes -1 or 134 */ \
if r2 s> -1 goto +2; \
/* Branch always taken because w2 = -1 */ \
if r2 == -1 goto +1; \
r0 /= 0; \
r0 = 0; \
exit; \
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}
char _license[] SEC("license") = "GPL";