diff --git a/drivers/infiniband/sw/rxe/rxe.c b/drivers/infiniband/sw/rxe/rxe.c index e891199cbdef..b0714f9abe3d 100644 --- a/drivers/infiniband/sw/rxe/rxe.c +++ b/drivers/infiniband/sw/rxe/rxe.c @@ -8,6 +8,8 @@ #include #include "rxe.h" #include "rxe_loc.h" +#include "rxe_net.h" +#include "rxe_ns.h" MODULE_AUTHOR("Bob Pearson, Frank Zago, John Groves, Kamal Heib"); MODULE_DESCRIPTION("Soft RDMA transport"); @@ -200,6 +202,8 @@ void rxe_set_mtu(struct rxe_dev *rxe, unsigned int ndev_mtu) port->mtu_cap = ib_mtu_enum_to_int(mtu); } +static struct rdma_link_ops rxe_link_ops; + /* called by ifc layer to create new rxe device. * The caller should allocate memory for rxe by calling ib_alloc_device. */ @@ -208,6 +212,7 @@ int rxe_add(struct rxe_dev *rxe, unsigned int mtu, const char *ibdev_name, { rxe_init(rxe, ndev); rxe_set_mtu(rxe, mtu); + rxe->ib_dev.link_ops = &rxe_link_ops; return rxe_register_device(rxe, ibdev_name, ndev); } @@ -231,6 +236,10 @@ static int rxe_newlink(const char *ibdev_name, struct net_device *ndev) goto err; } + err = rxe_net_init(ndev); + if (err) + return err; + err = rxe_net_add(ibdev_name, ndev); if (err) { rxe_err("failed to add %s\n", ndev->name); @@ -240,9 +249,17 @@ static int rxe_newlink(const char *ibdev_name, struct net_device *ndev) return err; } +static int rxe_dellink(struct ib_device *dev) +{ + rxe_net_del(dev); + + return 0; +} + static struct rdma_link_ops rxe_link_ops = { .type = "rxe", .newlink = rxe_newlink, + .dellink = rxe_dellink, }; static int __init rxe_module_init(void) @@ -253,15 +270,24 @@ static int __init rxe_module_init(void) if (err) return err; - err = rxe_net_init(); - if (err) { - rxe_destroy_wq(); - return err; - } + err = rxe_namespace_init(); + if (err) + goto err_destroy_wq; + + err = rxe_register_notifier(); + if (err) + goto err_namespace_exit; rdma_link_register(&rxe_link_ops); + pr_info("loaded\n"); return 0; + +err_namespace_exit: + rxe_namespace_exit(); +err_destroy_wq: + rxe_destroy_wq(); + return err; } static void __exit rxe_module_exit(void) @@ -271,6 +297,8 @@ static void __exit rxe_module_exit(void) rxe_net_exit(); rxe_destroy_wq(); + rxe_namespace_exit(); + pr_info("unloaded\n"); } diff --git a/drivers/infiniband/sw/rxe/rxe_net.c b/drivers/infiniband/sw/rxe/rxe_net.c index 0bd0902b11f7..211bd3000acc 100644 --- a/drivers/infiniband/sw/rxe/rxe_net.c +++ b/drivers/infiniband/sw/rxe/rxe_net.c @@ -17,8 +17,11 @@ #include "rxe.h" #include "rxe_net.h" #include "rxe_loc.h" +#include "rxe_ns.h" -static struct rxe_recv_sockets recv_sockets; +#ifndef SK_REF_FOR_TUNNEL +#define SK_REF_FOR_TUNNEL 2 +#endif #ifdef CONFIG_DEBUG_LOCK_ALLOC /* @@ -101,20 +104,20 @@ static inline void rxe_reclassify_recv_socket(struct socket *sock) } static struct dst_entry *rxe_find_route4(struct rxe_qp *qp, + struct net *net, struct net_device *ndev, struct in_addr *saddr, struct in_addr *daddr) { struct rtable *rt; - struct flowi4 fl = { { 0 } }; + struct flowi4 fl = {}; - memset(&fl, 0, sizeof(fl)); fl.flowi4_oif = ndev->ifindex; memcpy(&fl.saddr, saddr, sizeof(*saddr)); memcpy(&fl.daddr, daddr, sizeof(*daddr)); fl.flowi4_proto = IPPROTO_UDP; - rt = ip_route_output_key(&init_net, &fl); + rt = ip_route_output_key(net, &fl); if (IS_ERR(rt)) { rxe_dbg_qp(qp, "no route to %pI4\n", &daddr->s_addr); return NULL; @@ -125,21 +128,21 @@ static struct dst_entry *rxe_find_route4(struct rxe_qp *qp, #if IS_ENABLED(CONFIG_IPV6) static struct dst_entry *rxe_find_route6(struct rxe_qp *qp, + struct net *net, struct net_device *ndev, struct in6_addr *saddr, struct in6_addr *daddr) { struct dst_entry *ndst; - struct flowi6 fl6 = { { 0 } }; + struct flowi6 fl6 = {}; - memset(&fl6, 0, sizeof(fl6)); fl6.flowi6_oif = ndev->ifindex; memcpy(&fl6.saddr, saddr, sizeof(*saddr)); memcpy(&fl6.daddr, daddr, sizeof(*daddr)); fl6.flowi6_proto = IPPROTO_UDP; - ndst = ipv6_stub->ipv6_dst_lookup_flow(sock_net(recv_sockets.sk6->sk), - recv_sockets.sk6->sk, &fl6, + ndst = ipv6_stub->ipv6_dst_lookup_flow(net, + rxe_ns_pernet_sk6(net), &fl6, NULL); if (IS_ERR(ndst)) { rxe_dbg_qp(qp, "no route to %pI6\n", daddr); @@ -160,6 +163,7 @@ static struct dst_entry *rxe_find_route6(struct rxe_qp *qp, #else static struct dst_entry *rxe_find_route6(struct rxe_qp *qp, + struct net *net, struct net_device *ndev, struct in6_addr *saddr, struct in6_addr *daddr) @@ -174,6 +178,7 @@ static struct dst_entry *rxe_find_route(struct net_device *ndev, struct rxe_av *av) { struct dst_entry *dst = NULL; + struct net *net; if (qp_type(qp) == IB_QPT_RC) dst = sk_dst_get(qp->sk->sk); @@ -182,20 +187,22 @@ static struct dst_entry *rxe_find_route(struct net_device *ndev, if (dst) dst_release(dst); + net = dev_net(ndev); + if (av->network_type == RXE_NETWORK_TYPE_IPV4) { struct in_addr *saddr; struct in_addr *daddr; saddr = &av->sgid_addr._sockaddr_in.sin_addr; daddr = &av->dgid_addr._sockaddr_in.sin_addr; - dst = rxe_find_route4(qp, ndev, saddr, daddr); + dst = rxe_find_route4(qp, net, ndev, saddr, daddr); } else if (av->network_type == RXE_NETWORK_TYPE_IPV6) { struct in6_addr *saddr6; struct in6_addr *daddr6; saddr6 = &av->sgid_addr._sockaddr_in6.sin6_addr; daddr6 = &av->dgid_addr._sockaddr_in6.sin6_addr; - dst = rxe_find_route6(qp, ndev, saddr6, daddr6); + dst = rxe_find_route6(qp, net, ndev, saddr6, daddr6); #if IS_ENABLED(CONFIG_IPV6) if (dst) qp->dst_cookie = @@ -624,6 +631,43 @@ int rxe_net_add(const char *ibdev_name, struct net_device *ndev) return 0; } +static void rxe_sock_put(struct sock *sk, + void (*set_sk)(struct net *, struct sock *), + struct net *net) +{ + if (refcount_read(&sk->sk_refcnt) > SK_REF_FOR_TUNNEL) { + __sock_put(sk); + } else { + rxe_release_udp_tunnel(sk->sk_socket); + sk = NULL; + set_sk(net, sk); + } +} + +void rxe_net_del(struct ib_device *dev) +{ + struct rxe_dev *rxe = container_of(dev, struct rxe_dev, ib_dev); + struct net_device *ndev; + struct sock *sk; + struct net *net; + + ndev = rxe_ib_device_get_netdev(&rxe->ib_dev); + if (!ndev) + return; + + net = dev_net(ndev); + + sk = rxe_ns_pernet_sk4(net); + if (sk) + rxe_sock_put(sk, rxe_ns_pernet_set_sk4, net); + + sk = rxe_ns_pernet_sk6(net); + if (sk) + rxe_sock_put(sk, rxe_ns_pernet_set_sk6, net); + + dev_put(ndev); +} + static void rxe_port_event(struct rxe_dev *rxe, enum ib_event_type event) { @@ -680,6 +724,7 @@ static int rxe_notify(struct notifier_block *not_blk, switch (event) { case NETDEV_UNREGISTER: ib_unregister_device_queued(&rxe->ib_dev); + rxe_net_del(&rxe->ib_dev); break; case NETDEV_CHANGEMTU: rxe_dbg_dev(rxe, "%s changed mtu to %d\n", ndev->name, ndev->mtu); @@ -709,66 +754,97 @@ static struct notifier_block rxe_net_notifier = { .notifier_call = rxe_notify, }; -static int rxe_net_ipv4_init(void) +static int rxe_net_ipv4_init(struct net *net) { - recv_sockets.sk4 = rxe_setup_udp_tunnel(&init_net, - htons(ROCE_V2_UDP_DPORT), false); - if (IS_ERR(recv_sockets.sk4)) { - recv_sockets.sk4 = NULL; + struct sock *sk; + struct socket *sock; + + sk = rxe_ns_pernet_sk4(net); + if (sk) { + sock_hold(sk); + return 0; + } + + sock = rxe_setup_udp_tunnel(net, htons(ROCE_V2_UDP_DPORT), false); + if (IS_ERR(sock)) { pr_err("Failed to create IPv4 UDP tunnel\n"); return -1; } + rxe_ns_pernet_set_sk4(net, sock->sk); return 0; } -static int rxe_net_ipv6_init(void) +static int rxe_net_ipv6_init(struct net *net) { #if IS_ENABLED(CONFIG_IPV6) + struct sock *sk; + struct socket *sock; - recv_sockets.sk6 = rxe_setup_udp_tunnel(&init_net, - htons(ROCE_V2_UDP_DPORT), true); - if (PTR_ERR(recv_sockets.sk6) == -EAFNOSUPPORT) { - recv_sockets.sk6 = NULL; + sk = rxe_ns_pernet_sk6(net); + if (sk) { + sock_hold(sk); + return 0; + } + + sock = rxe_setup_udp_tunnel(net, htons(ROCE_V2_UDP_DPORT), true); + if (PTR_ERR(sock) == -EAFNOSUPPORT) { pr_warn("IPv6 is not supported, can not create a UDPv6 socket\n"); return 0; } - if (IS_ERR(recv_sockets.sk6)) { - recv_sockets.sk6 = NULL; + if (IS_ERR(sock)) { pr_err("Failed to create IPv6 UDP tunnel\n"); return -1; } + + rxe_ns_pernet_set_sk6(net, sock->sk); + #endif return 0; } +int rxe_register_notifier(void) +{ + int err; + + err = register_netdevice_notifier(&rxe_net_notifier); + if (err) { + pr_err("Failed to register netdev notifier\n"); + return -1; + } + + return 0; +} + void rxe_net_exit(void) { - rxe_release_udp_tunnel(recv_sockets.sk6); - rxe_release_udp_tunnel(recv_sockets.sk4); unregister_netdevice_notifier(&rxe_net_notifier); } -int rxe_net_init(void) +int rxe_net_init(struct net_device *ndev) { + struct net *net; + struct sock *sk; int err; - recv_sockets.sk6 = NULL; + net = dev_net(ndev); - err = rxe_net_ipv4_init(); + err = rxe_net_ipv4_init(net); if (err) return err; - err = rxe_net_ipv6_init(); + + err = rxe_net_ipv6_init(net); if (err) goto err_out; - err = register_netdevice_notifier(&rxe_net_notifier); - if (err) { - pr_err("Failed to register netdev notifier\n"); - goto err_out; - } + return 0; + err_out: - rxe_net_exit(); + /* If ipv6 error, release ipv4 resource */ + sk = rxe_ns_pernet_sk4(net); + if (sk) + rxe_sock_put(sk, rxe_ns_pernet_set_sk4, net); + return err; } diff --git a/drivers/infiniband/sw/rxe/rxe_net.h b/drivers/infiniband/sw/rxe/rxe_net.h index 45d80d00f86b..56249677d692 100644 --- a/drivers/infiniband/sw/rxe/rxe_net.h +++ b/drivers/infiniband/sw/rxe/rxe_net.h @@ -11,14 +11,11 @@ #include #include -struct rxe_recv_sockets { - struct socket *sk4; - struct socket *sk6; -}; - int rxe_net_add(const char *ibdev_name, struct net_device *ndev); +void rxe_net_del(struct ib_device *dev); -int rxe_net_init(void); +int rxe_register_notifier(void); +int rxe_net_init(struct net_device *ndev); void rxe_net_exit(void); #endif /* RXE_NET_H */