diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 2da986136d26..da6a00dd313f 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -1746,6 +1746,8 @@ struct bpf_prog_aux { struct bpf_map __rcu *st_ops_assoc; }; +#define BPF_NR_CONTEXTS 4 /* normal, softirq, hardirq, NMI */ + struct bpf_prog { u16 pages; /* Number of allocated pages */ u16 jited:1, /* Is our filter JIT'ed? */ @@ -1772,7 +1774,7 @@ struct bpf_prog { u8 tag[BPF_TAG_SIZE]; }; struct bpf_prog_stats __percpu *stats; - int __percpu *active; + u8 __percpu *active; /* u8[BPF_NR_CONTEXTS] for recursion protection */ unsigned int (*bpf_func)(const void *ctx, const struct bpf_insn *insn); struct bpf_prog_aux *aux; /* Auxiliary fields */ @@ -2006,12 +2008,36 @@ struct bpf_struct_ops_common_value { static inline bool bpf_prog_get_recursion_context(struct bpf_prog *prog) { - return this_cpu_inc_return(*(prog->active)) == 1; +#ifdef CONFIG_ARM64 + u8 rctx = interrupt_context_level(); + u8 *active = this_cpu_ptr(prog->active); + u32 val; + + preempt_disable(); + active[rctx]++; + val = le32_to_cpu(*(__le32 *)active); + preempt_enable(); + if (val != BIT(rctx * 8)) + return false; + + return true; +#else + return this_cpu_inc_return(*(int __percpu *)(prog->active)) == 1; +#endif } static inline void bpf_prog_put_recursion_context(struct bpf_prog *prog) { - this_cpu_dec(*(prog->active)); +#ifdef CONFIG_ARM64 + u8 rctx = interrupt_context_level(); + u8 *active = this_cpu_ptr(prog->active); + + preempt_disable(); + active[rctx]--; + preempt_enable(); +#else + this_cpu_dec(*(int __percpu *)(prog->active)); +#endif } #if defined(CONFIG_BPF_JIT) && defined(CONFIG_BPF_SYSCALL) diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c index c66316e32563..e0b8a8a5aaa9 100644 --- a/kernel/bpf/core.c +++ b/kernel/bpf/core.c @@ -112,7 +112,8 @@ struct bpf_prog *bpf_prog_alloc_no_stats(unsigned int size, gfp_t gfp_extra_flag vfree(fp); return NULL; } - fp->active = alloc_percpu_gfp(int, bpf_memcg_flags(GFP_KERNEL | gfp_extra_flags)); + fp->active = __alloc_percpu_gfp(sizeof(u8[BPF_NR_CONTEXTS]), 4, + bpf_memcg_flags(GFP_KERNEL | gfp_extra_flags)); if (!fp->active) { vfree(fp); kfree(aux);