diff --git a/include/net/psp/functions.h b/include/net/psp/functions.h index 91ba06733321..ef7743664da3 100644 --- a/include/net/psp/functions.h +++ b/include/net/psp/functions.h @@ -124,19 +124,18 @@ psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); } -static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) +static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk) { - struct inet_timewait_sock *tw; struct psp_assoc *pas; int state; - state = 1 << READ_ONCE(sk->sk_state); - if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) + state = READ_ONCE(sk->sk_state); + if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV) return NULL; - tw = inet_twsk(sk); - pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : - rcu_dereference(sk->psp_assoc); + pas = state == TCP_TIME_WAIT ? + rcu_dereference(inet_twsk(sk)->psp_assoc) : + rcu_dereference(sk->psp_assoc); return pas; }