diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h index b6ec08072533..ac1c75975908 100644 --- a/include/net/inet_sock.h +++ b/include/net/inet_sock.h @@ -355,14 +355,6 @@ static inline struct sock *skb_to_full_sk(const struct sk_buff *skb) #define inet_sk(ptr) container_of_const(ptr, struct inet_sock, sk) -static inline void __inet_sk_copy_descendant(struct sock *sk_to, - const struct sock *sk_from, - const int ancestor_size) -{ - memcpy(inet_sk(sk_to) + 1, inet_sk(sk_from) + 1, - sk_from->sk_prot->obj_size - ancestor_size); -} - int inet_sk_rebuild_header(struct sock *sk); /** diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h index e96d1bd087f6..bb4b80c12541 100644 --- a/include/net/sctp/sctp.h +++ b/include/net/sctp/sctp.h @@ -94,8 +94,7 @@ void sctp_data_ready(struct sock *sk); __poll_t sctp_poll(struct file *file, struct socket *sock, poll_table *wait); void sctp_sock_rfree(struct sk_buff *skb); -void sctp_copy_sock(struct sock *newsk, struct sock *sk, - struct sctp_association *asoc); + extern struct percpu_counter sctp_sockets_allocated; int sctp_asconf_mgmt(struct sctp_sock *, struct sctp_sockaddr_entry *); struct sk_buff *sctp_skb_recv_datagram(struct sock *, int, int *); diff --git a/include/net/sctp/structs.h b/include/net/sctp/structs.h index 2ae390219efd..3dd304e411d0 100644 --- a/include/net/sctp/structs.h +++ b/include/net/sctp/structs.h @@ -497,9 +497,6 @@ struct sctp_pf { int (*bind_verify) (struct sctp_sock *, union sctp_addr *); int (*send_verify) (struct sctp_sock *, union sctp_addr *); int (*supported_addrs)(const struct sctp_sock *, __be16 *); - struct sock *(*create_accept_sk) (struct sock *sk, - struct sctp_association *asoc, - bool kern); int (*addr_to_user)(struct sctp_sock *sk, union sctp_addr *addr); void (*to_sk_saddr)(union sctp_addr *, struct sock *sk); void (*to_sk_daddr)(union sctp_addr *, struct sock *sk); diff --git a/include/net/sock.h b/include/net/sock.h index 01ce231603db..c7e58b8e8a90 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1822,7 +1822,12 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority, void sk_free(struct sock *sk); void sk_net_refcnt_upgrade(struct sock *sk); void sk_destruct(struct sock *sk); -struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority); +struct sock *sk_clone(const struct sock *sk, const gfp_t priority, bool lock); + +static inline struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) +{ + return sk_clone(sk, priority, true); +} struct sk_buff *sock_wmalloc(struct sock *sk, unsigned long size, int force, gfp_t priority); diff --git a/net/core/sock.c b/net/core/sock.c index a99132cc0965..7a9bbc2afcf0 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -2462,13 +2462,16 @@ static void sk_init_common(struct sock *sk) } /** - * sk_clone_lock - clone a socket, and lock its clone - * @sk: the socket to clone - * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * sk_clone - clone a socket + * @sk: the socket to clone + * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * @lock: if true, lock the cloned sk * - * Caller must unlock socket even in error path (bh_unlock_sock(newsk)) + * If @lock is true, the clone is locked by bh_lock_sock(), and + * caller must unlock socket even in error path by bh_unlock_sock(). */ -struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) +struct sock *sk_clone(const struct sock *sk, const gfp_t priority, + bool lock) { struct proto *prot = READ_ONCE(sk->sk_prot); struct sk_filter *filter; @@ -2497,9 +2500,13 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) __netns_tracker_alloc(sock_net(newsk), &newsk->ns_tracker, false, priority); } + sk_node_init(&newsk->sk_node); sock_lock_init(newsk); - bh_lock_sock(newsk); + + if (lock) + bh_lock_sock(newsk); + newsk->sk_backlog.head = newsk->sk_backlog.tail = NULL; newsk->sk_backlog.len = 0; @@ -2590,12 +2597,13 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) * destructor and make plain sk_free() */ newsk->sk_destruct = NULL; - bh_unlock_sock(newsk); + if (lock) + bh_unlock_sock(newsk); sk_free(newsk); newsk = NULL; goto out; } -EXPORT_SYMBOL_GPL(sk_clone_lock); +EXPORT_SYMBOL_GPL(sk_clone); static u32 sk_dst_gso_max_size(struct sock *sk, const struct net_device *dev) { diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index e8771faa5bbf..0784e2a873a1 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -755,9 +755,7 @@ EXPORT_SYMBOL(inet_stream_connect); void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *newsk) { - /* TODO: use sk_clone_lock() in SCTP and remove protocol checks */ - if (mem_cgroup_sockets_enabled && - (!IS_ENABLED(CONFIG_IP_SCTP) || sk_is_tcp(newsk))) { + if (mem_cgroup_sockets_enabled) { gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL; mem_cgroup_sk_alloc(newsk); @@ -790,6 +788,7 @@ void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *new newsock->state = SS_CONNECTED; } +EXPORT_SYMBOL_GPL(__inet_accept); /* * Accept a pending connection. The TCP layer now gives BSD semantics. diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c index d725b2158758..069b7e45d8bd 100644 --- a/net/sctp/ipv6.c +++ b/net/sctp/ipv6.c @@ -777,56 +777,6 @@ static enum sctp_scope sctp_v6_scope(union sctp_addr *addr) return retval; } -/* Create and initialize a new sk for the socket to be returned by accept(). */ -static struct sock *sctp_v6_create_accept_sk(struct sock *sk, - struct sctp_association *asoc, - bool kern) -{ - struct ipv6_pinfo *newnp, *np = inet6_sk(sk); - struct sctp6_sock *newsctp6sk; - struct inet_sock *newinet; - struct sock *newsk; - - newsk = sk_alloc(sock_net(sk), PF_INET6, GFP_KERNEL, sk->sk_prot, kern); - if (!newsk) - goto out; - - sock_init_data(NULL, newsk); - - sctp_copy_sock(newsk, sk, asoc); - sock_reset_flag(sk, SOCK_ZAPPED); - - newsctp6sk = (struct sctp6_sock *)newsk; - newinet = inet_sk(newsk); - newinet->pinet6 = &newsctp6sk->inet6; - newinet->ipv6_fl_list = NULL; - - sctp_sk(newsk)->v4mapped = sctp_sk(sk)->v4mapped; - - newnp = inet6_sk(newsk); - - memcpy(newnp, np, sizeof(struct ipv6_pinfo)); - newnp->ipv6_mc_list = NULL; - newnp->ipv6_ac_list = NULL; - - sctp_v6_copy_ip_options(sk, newsk); - - /* Initialize sk's sport, dport, rcv_saddr and daddr for getsockname() - * and getpeername(). - */ - sctp_v6_to_sk_daddr(&asoc->peer.primary_addr, newsk); - - newsk->sk_v6_rcv_saddr = sk->sk_v6_rcv_saddr; - - if (newsk->sk_prot->init(newsk)) { - sk_common_release(newsk); - newsk = NULL; - } - -out: - return newsk; -} - /* Format a sockaddr for return to user space. This makes sure the return is * AF_INET or AF_INET6 depending on the SCTP_I_WANT_MAPPED_V4_ADDR option. */ @@ -1173,7 +1123,6 @@ static struct sctp_pf sctp_pf_inet6 = { .bind_verify = sctp_inet6_bind_verify, .send_verify = sctp_inet6_send_verify, .supported_addrs = sctp_inet6_supported_addrs, - .create_accept_sk = sctp_v6_create_accept_sk, .addr_to_user = sctp_v6_addr_to_user, .to_sk_saddr = sctp_v6_to_sk_saddr, .to_sk_daddr = sctp_v6_to_sk_daddr, diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c index 9dbc24af749b..2c3398f75d76 100644 --- a/net/sctp/protocol.c +++ b/net/sctp/protocol.c @@ -580,38 +580,6 @@ static int sctp_v4_is_ce(const struct sk_buff *skb) return INET_ECN_is_ce(ip_hdr(skb)->tos); } -/* Create and initialize a new sk for the socket returned by accept(). */ -static struct sock *sctp_v4_create_accept_sk(struct sock *sk, - struct sctp_association *asoc, - bool kern) -{ - struct sock *newsk = sk_alloc(sock_net(sk), PF_INET, GFP_KERNEL, - sk->sk_prot, kern); - struct inet_sock *newinet; - - if (!newsk) - goto out; - - sock_init_data(NULL, newsk); - - sctp_copy_sock(newsk, sk, asoc); - sock_reset_flag(newsk, SOCK_ZAPPED); - - sctp_v4_copy_ip_options(sk, newsk); - - newinet = inet_sk(newsk); - - newinet->inet_daddr = asoc->peer.primary_addr.v4.sin_addr.s_addr; - - if (newsk->sk_prot->init(newsk)) { - sk_common_release(newsk); - newsk = NULL; - } - -out: - return newsk; -} - static int sctp_v4_addr_to_user(struct sctp_sock *sp, union sctp_addr *addr) { /* No address mapping for V4 sockets */ @@ -1119,7 +1087,6 @@ static struct sctp_pf sctp_pf_inet = { .bind_verify = sctp_inet_bind_verify, .send_verify = sctp_inet_send_verify, .supported_addrs = sctp_inet_supported_addrs, - .create_accept_sk = sctp_v4_create_accept_sk, .addr_to_user = sctp_v4_addr_to_user, .to_sk_saddr = sctp_v4_to_sk_saddr, .to_sk_daddr = sctp_v4_to_sk_daddr, diff --git a/net/sctp/socket.c b/net/sctp/socket.c index ed8293a34240..ac737e60829b 100644 --- a/net/sctp/socket.c +++ b/net/sctp/socket.c @@ -1553,8 +1553,6 @@ static void sctp_close(struct sock *sk, long timeout) spin_unlock_bh(&net->sctp.addr_wq_lock); sock_put(sk); - - SCTP_DBG_OBJCNT_DEC(sock); } /* Handle EPIPE error. */ @@ -4844,6 +4842,74 @@ static int sctp_disconnect(struct sock *sk, int flags) return 0; } +static struct sock *sctp_clone_sock(struct sock *sk, + struct sctp_association *asoc, + enum sctp_socket_type type) +{ + struct sock *newsk = sk_clone(sk, GFP_KERNEL, false); + struct inet_sock *newinet; + struct sctp_sock *newsp; + int err = -ENOMEM; + + if (!newsk) + return ERR_PTR(err); + + /* sk_clone() sets refcnt to 2 */ + sock_put(newsk); + + newinet = inet_sk(newsk); + newsp = sctp_sk(newsk); + + newsp->pf->to_sk_daddr(&asoc->peer.primary_addr, newsk); + newinet->inet_dport = htons(asoc->peer.port); + + newsp->pf->copy_ip_options(sk, newsk); + atomic_set(&newinet->inet_id, get_random_u16()); + + inet_set_bit(MC_LOOP, newsk); + newinet->mc_ttl = 1; + newinet->mc_index = 0; + newinet->mc_list = NULL; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + struct ipv6_pinfo *newnp = inet6_sk(newsk); + + newinet->pinet6 = &((struct sctp6_sock *)newsk)->inet6; + newinet->ipv6_fl_list = NULL; + + memcpy(newnp, inet6_sk(sk), sizeof(struct ipv6_pinfo)); + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + } +#endif + + skb_queue_head_init(&newsp->pd_lobby); + + newsp->ep = sctp_endpoint_new(newsk, GFP_KERNEL); + if (!newsp->ep) + goto out_release; + + SCTP_DBG_OBJCNT_INC(sock); + sk_sockets_allocated_inc(newsk); + sock_prot_inuse_add(sock_net(sk), newsk->sk_prot, 1); + + err = sctp_sock_migrate(sk, newsk, asoc, type); + if (err) + goto out_release; + + /* Set newsk security attributes from original sk and connection + * security attribute from asoc. + */ + security_sctp_sk_clone(asoc, sk, newsk); + + return newsk; + +out_release: + sk_common_release(newsk); + return ERR_PTR(err); +} + /* 4.1.4 accept() - TCP Style Syntax * * Applications use accept() call to remove an established SCTP @@ -4853,18 +4919,13 @@ static int sctp_disconnect(struct sock *sk, int flags) */ static struct sock *sctp_accept(struct sock *sk, struct proto_accept_arg *arg) { - struct sctp_sock *sp; - struct sctp_endpoint *ep; - struct sock *newsk = NULL; struct sctp_association *asoc; - long timeo; + struct sock *newsk = NULL; int error = 0; + long timeo; lock_sock(sk); - sp = sctp_sk(sk); - ep = sp->ep; - if (!sctp_style(sk, TCP)) { error = -EOPNOTSUPP; goto out; @@ -4885,20 +4946,12 @@ static struct sock *sctp_accept(struct sock *sk, struct proto_accept_arg *arg) /* We treat the list of associations on the endpoint as the accept * queue and pick the first association on the list. */ - asoc = list_entry(ep->asocs.next, struct sctp_association, asocs); + asoc = list_entry(sctp_sk(sk)->ep->asocs.next, + struct sctp_association, asocs); - newsk = sp->pf->create_accept_sk(sk, asoc, arg->kern); - if (!newsk) { - error = -ENOMEM; - goto out; - } - - /* Populate the fields of the newsk from the oldsk and migrate the - * asoc to the newsk. - */ - error = sctp_sock_migrate(sk, newsk, asoc, SCTP_SOCKET_TCP); - if (error) { - sk_common_release(newsk); + newsk = sctp_clone_sock(sk, asoc, SCTP_SOCKET_TCP); + if (IS_ERR(newsk)) { + error = PTR_ERR(newsk); newsk = NULL; } @@ -5109,9 +5162,12 @@ static void sctp_destroy_sock(struct sock *sk) sp->do_auto_asconf = 0; list_del(&sp->auto_asconf_list); } + sctp_endpoint_free(sp->ep); + sk_sockets_allocated_dec(sk); sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + SCTP_DBG_OBJCNT_DEC(sock); } static void sctp_destruct_sock(struct sock *sk) @@ -5615,11 +5671,11 @@ static int sctp_getsockopt_autoclose(struct sock *sk, int len, char __user *optv /* Helper routine to branch off an association to a new socket. */ static int sctp_do_peeloff(struct sock *sk, sctp_assoc_t id, - struct socket **sockp) + struct socket **sockp) { struct sctp_association *asoc = sctp_id2assoc(sk, id); - struct sctp_sock *sp = sctp_sk(sk); struct socket *sock; + struct sock *newsk; int err = 0; /* Do not peel off from one netns to another one. */ @@ -5635,30 +5691,24 @@ static int sctp_do_peeloff(struct sock *sk, sctp_assoc_t id, if (!sctp_style(sk, UDP)) return -EINVAL; - /* Create a new socket. */ - err = sock_create(sk->sk_family, SOCK_SEQPACKET, IPPROTO_SCTP, &sock); - if (err < 0) + err = sock_create_lite(sk->sk_family, SOCK_SEQPACKET, IPPROTO_SCTP, &sock); + if (err) return err; - sctp_copy_sock(sock->sk, sk, asoc); - - /* Make peeled-off sockets more like 1-1 accepted sockets. - * Set the daddr and initialize id to something more random and also - * copy over any ip options. - */ - sp->pf->to_sk_daddr(&asoc->peer.primary_addr, sock->sk); - sp->pf->copy_ip_options(sk, sock->sk); - - /* Populate the fields of the newsk from the oldsk and migrate the - * asoc to the newsk. - */ - err = sctp_sock_migrate(sk, sock->sk, asoc, - SCTP_SOCKET_UDP_HIGH_BANDWIDTH); - if (err) { + newsk = sctp_clone_sock(sk, asoc, SCTP_SOCKET_UDP_HIGH_BANDWIDTH); + if (IS_ERR(newsk)) { sock_release(sock); - sock = NULL; + *sockp = NULL; + return PTR_ERR(newsk); } + lock_sock_nested(newsk, SINGLE_DEPTH_NESTING); + __inet_accept(sk->sk_socket, sock, newsk); + release_sock(newsk); + + sock->ops = sk->sk_socket->ops; + __module_get(sock->ops->owner); + *sockp = sock; return err; @@ -9441,71 +9491,6 @@ static void sctp_skb_set_owner_r_frag(struct sk_buff *skb, struct sock *sk) sctp_skb_set_owner_r(skb, sk); } -void sctp_copy_sock(struct sock *newsk, struct sock *sk, - struct sctp_association *asoc) -{ - struct inet_sock *inet = inet_sk(sk); - struct inet_sock *newinet; - struct sctp_sock *sp = sctp_sk(sk); - - newsk->sk_type = sk->sk_type; - newsk->sk_bound_dev_if = sk->sk_bound_dev_if; - newsk->sk_flags = sk->sk_flags; - newsk->sk_tsflags = sk->sk_tsflags; - newsk->sk_no_check_tx = sk->sk_no_check_tx; - newsk->sk_no_check_rx = sk->sk_no_check_rx; - newsk->sk_reuse = sk->sk_reuse; - sctp_sk(newsk)->reuse = sp->reuse; - - newsk->sk_shutdown = sk->sk_shutdown; - newsk->sk_destruct = sk->sk_destruct; - newsk->sk_family = sk->sk_family; - newsk->sk_protocol = IPPROTO_SCTP; - newsk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; - newsk->sk_sndbuf = sk->sk_sndbuf; - newsk->sk_rcvbuf = sk->sk_rcvbuf; - newsk->sk_lingertime = sk->sk_lingertime; - newsk->sk_rcvtimeo = READ_ONCE(sk->sk_rcvtimeo); - newsk->sk_sndtimeo = READ_ONCE(sk->sk_sndtimeo); - newsk->sk_rxhash = sk->sk_rxhash; - - newinet = inet_sk(newsk); - - /* Initialize sk's sport, dport, rcv_saddr and daddr for - * getsockname() and getpeername() - */ - newinet->inet_sport = inet->inet_sport; - newinet->inet_saddr = inet->inet_saddr; - newinet->inet_rcv_saddr = inet->inet_rcv_saddr; - newinet->inet_dport = htons(asoc->peer.port); - newinet->pmtudisc = inet->pmtudisc; - atomic_set(&newinet->inet_id, get_random_u16()); - - newinet->uc_ttl = inet->uc_ttl; - inet_set_bit(MC_LOOP, newsk); - newinet->mc_ttl = 1; - newinet->mc_index = 0; - newinet->mc_list = NULL; - - if (newsk->sk_flags & SK_FLAGS_TIMESTAMP) - net_enable_timestamp(); - - /* Set newsk security attributes from original sk and connection - * security attribute from asoc. - */ - security_sctp_sk_clone(asoc, sk, newsk); -} - -static inline void sctp_copy_descendant(struct sock *sk_to, - const struct sock *sk_from) -{ - size_t ancestor_size = sizeof(struct inet_sock); - - ancestor_size += sk_from->sk_prot->obj_size; - ancestor_size -= offsetof(struct sctp_sock, pd_lobby); - __inet_sk_copy_descendant(sk_to, sk_from, ancestor_size); -} - /* Populate the fields of the newsk from the oldsk and migrate the assoc * and its messages to the newsk. */ @@ -9522,14 +9507,6 @@ static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, struct sctp_bind_hashbucket *head; int err; - /* Migrate socket buffer sizes and all the socket level options to the - * new socket. - */ - newsk->sk_sndbuf = oldsk->sk_sndbuf; - newsk->sk_rcvbuf = oldsk->sk_rcvbuf; - /* Brute force copy old sctp opt. */ - sctp_copy_descendant(newsk, oldsk); - /* Restore the ep value that was overwritten with the above structure * copy. */