diff --git a/drivers/net/vrf.c b/drivers/net/vrf.c index 662e26117353..ccf677015d5b 100644 --- a/drivers/net/vrf.c +++ b/drivers/net/vrf.c @@ -35,6 +35,7 @@ #include #include #include +#include #define DRV_NAME "vrf" #define DRV_VERSION "1.1" @@ -424,12 +425,26 @@ static int vrf_local_xmit(struct sk_buff *skb, struct net_device *dev, return NETDEV_TX_OK; } +static void vrf_nf_set_untracked(struct sk_buff *skb) +{ + if (skb_get_nfct(skb) == 0) + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); +} + +static void vrf_nf_reset_ct(struct sk_buff *skb) +{ + if (skb_get_nfct(skb) == IP_CT_UNTRACKED) + nf_reset_ct(skb); +} + #if IS_ENABLED(CONFIG_IPV6) static int vrf_ip6_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) { int err; + vrf_nf_reset_ct(skb); + err = nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb, NULL, skb_dst(skb)->dev, dst_output); @@ -508,6 +523,8 @@ static int vrf_ip_local_out(struct net *net, struct sock *sk, { int err; + vrf_nf_reset_ct(skb); + err = nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, net, sk, skb, NULL, skb_dst(skb)->dev, dst_output); if (likely(err == 1)) @@ -626,8 +643,7 @@ static void vrf_finish_direct(struct sk_buff *skb) skb_pull(skb, ETH_HLEN); } - /* reset skb device */ - nf_reset_ct(skb); + vrf_nf_reset_ct(skb); } #if IS_ENABLED(CONFIG_IPV6) @@ -641,7 +657,7 @@ static int vrf_finish_output6(struct net *net, struct sock *sk, struct neighbour *neigh; int ret; - nf_reset_ct(skb); + vrf_nf_reset_ct(skb); skb->protocol = htons(ETH_P_IPV6); skb->dev = dev; @@ -752,6 +768,8 @@ static struct sk_buff *vrf_ip6_out_direct(struct net_device *vrf_dev, skb->dev = vrf_dev; + vrf_nf_set_untracked(skb); + err = nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb, NULL, vrf_dev, vrf_ip6_out_direct_finish); @@ -858,7 +876,7 @@ static int vrf_finish_output(struct net *net, struct sock *sk, struct sk_buff *s struct neighbour *neigh; bool is_v6gw = false; - nf_reset_ct(skb); + vrf_nf_reset_ct(skb); /* Be paranoid, rather than too clever. */ if (unlikely(skb_headroom(skb) < hh_len && dev->header_ops)) { @@ -980,6 +998,8 @@ static struct sk_buff *vrf_ip_out_direct(struct net_device *vrf_dev, skb->dev = vrf_dev; + vrf_nf_set_untracked(skb); + err = nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, net, sk, skb, NULL, vrf_dev, vrf_ip_out_direct_finish); diff --git a/net/netfilter/nf_conntrack_proto.c b/net/netfilter/nf_conntrack_proto.c index 8f7a9837349c..d1f2d3c8d2b1 100644 --- a/net/netfilter/nf_conntrack_proto.c +++ b/net/netfilter/nf_conntrack_proto.c @@ -155,6 +155,16 @@ unsigned int nf_confirm(struct sk_buff *skb, unsigned int protoff, } EXPORT_SYMBOL_GPL(nf_confirm); +static bool in_vrf_postrouting(const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (state->hook == NF_INET_POST_ROUTING && + netif_is_l3_master(state->out)) + return true; +#endif + return false; +} + static unsigned int ipv4_confirm(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) @@ -166,6 +176,9 @@ static unsigned int ipv4_confirm(void *priv, if (!ct || ctinfo == IP_CT_RELATED_REPLY) return nf_conntrack_confirm(skb); + if (in_vrf_postrouting(state)) + return NF_ACCEPT; + return nf_confirm(skb, skb_network_offset(skb) + ip_hdrlen(skb), ct, ctinfo); @@ -374,6 +387,9 @@ static unsigned int ipv6_confirm(void *priv, if (!ct || ctinfo == IP_CT_RELATED_REPLY) return nf_conntrack_confirm(skb); + if (in_vrf_postrouting(state)) + return NF_ACCEPT; + protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, &frag_off); if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { diff --git a/net/netfilter/nf_nat_core.c b/net/netfilter/nf_nat_core.c index 273117683922..4d50d51db796 100644 --- a/net/netfilter/nf_nat_core.c +++ b/net/netfilter/nf_nat_core.c @@ -699,6 +699,16 @@ unsigned int nf_nat_packet(struct nf_conn *ct, } EXPORT_SYMBOL_GPL(nf_nat_packet); +static bool in_vrf_postrouting(const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (state->hook == NF_INET_POST_ROUTING && + netif_is_l3_master(state->out)) + return true; +#endif + return false; +} + unsigned int nf_nat_inet_fn(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) @@ -715,7 +725,7 @@ nf_nat_inet_fn(void *priv, struct sk_buff *skb, * packet filter it out, or implement conntrack/NAT for that * protocol. 8) --RR */ - if (!ct) + if (!ct || in_vrf_postrouting(state)) return NF_ACCEPT; nat = nfct_nat(ct);