vsock: keep poll shutdown state consistent

vsock_poll() reads vsk->peer_shutdown to set EPOLLHUP and EPOLLRDHUP
before taking the socket lock, then reads it again after taking the
lock to report EOF readability. A shutdown packet can update
peer_shutdown while poll is waiting for the lock, so one poll invocation
can report EOF readability without the corresponding HUP/RDHUP bits.

For connectible sockets, take one peer_shutdown snapshot after
lock_sock() and use it for all peer-shutdown-derived poll bits. For
datagram sockets, which do not take lock_sock() in poll(), take one
lockless READ_ONCE() snapshot and pair it with WRITE_ONCE() on the
writer side.

This keeps the peer-shutdown-derived bits internally consistent for each
poll pass.

Fixes: d021c34405 ("VSOCK: Introduce VM Sockets")
Signed-off-by: Ziyu Zhang <ziyuzhang201@gmail.com>
Link: https://patch.msgid.link/20260519165636.62542-1-ziyuzhang201@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Ziyu Zhang 2026-05-20 00:56:36 +08:00 committed by Jakub Kicinski
parent aa8963fdce
commit aae9d8a552
4 changed files with 52 additions and 28 deletions

View File

@ -642,7 +642,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
*/
sock_reset_flag(sk, SOCK_DONE);
sk->sk_state = TCP_CLOSE;
vsk->peer_shutdown = 0;
WRITE_ONCE(vsk->peer_shutdown, 0);
}
if (sk->sk_type == SOCK_SEQPACKET) {
@ -933,7 +933,7 @@ static struct sock *__vsock_create(struct net *net,
vsk->rejected = false;
vsk->sent_request = false;
vsk->ignore_connecting_rst = false;
vsk->peer_shutdown = 0;
WRITE_ONCE(vsk->peer_shutdown, 0);
INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout);
INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work);
@ -1241,6 +1241,25 @@ static int vsock_shutdown(struct socket *sock, int mode)
return err;
}
/* Compute the shutdown-derived poll bits for a vsock socket.
 *
 * @sk: socket whose local sk_shutdown flags are examined.
 * @peer_shutdown: caller-provided snapshot of the peer's shutdown flags.
 *
 * Returns a mask containing at most EPOLLHUP and EPOLLRDHUP.
 *
 * sk->sk_shutdown is sampled exactly once so that both bits are derived
 * from the same local shutdown state. The datagram poll path calls this
 * helper without holding the socket lock, so repeated plain reads could
 * otherwise observe different values within a single poll pass.
 */
static __poll_t vsock_poll_shutdown(struct sock *sk, u32 peer_shutdown)
{
	u32 shutdown = READ_ONCE(sk->sk_shutdown);
	__poll_t mask = 0;

	/* Like INET sockets, report EPOLLHUP when the socket is fully shut
	 * down locally, or when both the local side and the peer have shut
	 * down sending.
	 */
	if (shutdown == SHUTDOWN_MASK ||
	    ((shutdown & SEND_SHUTDOWN) && (peer_shutdown & SEND_SHUTDOWN)))
		mask |= EPOLLHUP;

	/* No more data will arrive once we stopped receiving or the peer
	 * stopped sending.
	 */
	if ((shutdown & RCV_SHUTDOWN) || (peer_shutdown & SEND_SHUTDOWN))
		mask |= EPOLLRDHUP;

	return mask;
}
static __poll_t vsock_poll(struct file *file, struct socket *sock,
poll_table *wait)
{
@ -1258,24 +1277,17 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
/* Signify that there has been an error on this socket. */
mask |= EPOLLERR;
/* INET sockets treat local write shutdown and peer write shutdown as a
* case of EPOLLHUP set.
*/
if ((sk->sk_shutdown == SHUTDOWN_MASK) ||
((sk->sk_shutdown & SEND_SHUTDOWN) &&
(vsk->peer_shutdown & SEND_SHUTDOWN))) {
mask |= EPOLLHUP;
}
if (sk->sk_shutdown & RCV_SHUTDOWN ||
vsk->peer_shutdown & SEND_SHUTDOWN) {
mask |= EPOLLRDHUP;
}
if (sk_is_readable(sk))
mask |= EPOLLIN | EPOLLRDNORM;
if (sock->type == SOCK_DGRAM) {
u32 peer_shutdown = READ_ONCE(vsk->peer_shutdown);
/* DGRAM sockets do not take lock_sock() in poll(), so use one
* lockless snapshot for all shutdown-derived mask bits.
*/
mask |= vsock_poll_shutdown(sk, peer_shutdown);
/* For datagram sockets we can read if there is something in
* the queue and write as long as the socket isn't shutdown for
* sending.
@ -1290,6 +1302,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
} else if (sock_type_connectible(sk->sk_type)) {
const struct vsock_transport *transport;
u32 peer_shutdown;
lock_sock(sk);
@ -1322,8 +1335,10 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
* terminated should also be considered read, and we check the
* shutdown flag for that.
*/
peer_shutdown = READ_ONCE(vsk->peer_shutdown);
mask |= vsock_poll_shutdown(sk, peer_shutdown);
if (sk->sk_shutdown & RCV_SHUTDOWN ||
vsk->peer_shutdown & SEND_SHUTDOWN) {
peer_shutdown & SEND_SHUTDOWN) {
mask |= EPOLLIN | EPOLLRDNORM;
}

View File

@ -264,7 +264,7 @@ static void hvs_do_close_lock_held(struct vsock_sock *vsk,
struct sock *sk = sk_vsock(vsk);
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
if (vsock_stream_has_data(vsk) <= 0)
sk->sk_state = TCP_CLOSING;
sk->sk_state_change(sk);
@ -593,7 +593,9 @@ static int hvs_update_recv_data(struct hvsock *hvs)
return -EIO;
if (payload_len == 0)
hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
WRITE_ONCE(hvs->vsk->peer_shutdown,
READ_ONCE(hvs->vsk->peer_shutdown) |
SEND_SHUTDOWN);
hvs->recv_data_len = payload_len;
hvs->recv_data_off = 0;
@ -736,7 +738,8 @@ static s64 hvs_stream_has_data(struct vsock_sock *vsk)
return ret;
return hvs->recv_data_len;
case 0:
vsk->peer_shutdown |= SEND_SHUTDOWN;
WRITE_ONCE(vsk->peer_shutdown,
READ_ONCE(vsk->peer_shutdown) | SEND_SHUTDOWN);
ret = 0;
break;
default: /* -1 */

View File

@ -1228,7 +1228,7 @@ static void virtio_transport_do_close(struct vsock_sock *vsk,
struct sock *sk = sk_vsock(vsk);
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
if (vsock_stream_has_data(vsk) <= 0)
sk->sk_state = TCP_CLOSING;
sk->sk_state_change(sk);
@ -1431,12 +1431,15 @@ virtio_transport_recv_connected(struct sock *sk,
case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
sk->sk_write_space(sk);
break;
case VIRTIO_VSOCK_OP_SHUTDOWN:
case VIRTIO_VSOCK_OP_SHUTDOWN: {
u32 peer_shutdown = READ_ONCE(vsk->peer_shutdown);
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
vsk->peer_shutdown |= RCV_SHUTDOWN;
peer_shutdown |= RCV_SHUTDOWN;
if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
vsk->peer_shutdown |= SEND_SHUTDOWN;
if (vsk->peer_shutdown == SHUTDOWN_MASK) {
peer_shutdown |= SEND_SHUTDOWN;
WRITE_ONCE(vsk->peer_shutdown, peer_shutdown);
if (peer_shutdown == SHUTDOWN_MASK) {
if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) {
(void)virtio_transport_reset(vsk, NULL);
virtio_transport_do_close(vsk, true);
@ -1451,6 +1454,7 @@ virtio_transport_recv_connected(struct sock *sk,
if (le32_to_cpu(virtio_vsock_hdr(skb)->flags))
sk->sk_state_change(sk);
break;
}
case VIRTIO_VSOCK_OP_RST:
virtio_transport_do_close(vsk, true);
break;

View File

@ -819,7 +819,7 @@ static void vmci_transport_handle_detach(struct sock *sk)
/* On a detach the peer will not be sending or receiving
* anymore.
*/
vsk->peer_shutdown = SHUTDOWN_MASK;
WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
/* We should not be sending anymore since the peer won't be
* there to receive, but we can still receive if there is data
@ -1542,7 +1542,9 @@ static int vmci_transport_recv_connected(struct sock *sk,
if (pkt->u.mode) {
vsk = vsock_sk(sk);
vsk->peer_shutdown |= pkt->u.mode;
WRITE_ONCE(vsk->peer_shutdown,
READ_ONCE(vsk->peer_shutdown) |
pkt->u.mode);
sk->sk_state_change(sk);
}
break;
@ -1559,7 +1561,7 @@ static int vmci_transport_recv_connected(struct sock *sk,
* a clean shutdown.
*/
sock_set_flag(sk, SOCK_DONE);
vsk->peer_shutdown = SHUTDOWN_MASK;
WRITE_ONCE(vsk->peer_shutdown, SHUTDOWN_MASK);
if (vsock_stream_has_data(vsk) <= 0)
sk->sk_state = TCP_CLOSING;