diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c index 9c2e26e01d72..08ffc5a5aea4 100644 --- a/arch/arm64/kernel/signal.c +++ b/arch/arm64/kernel/signal.c @@ -449,12 +449,28 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user) if (user->sve_size < SVE_SIG_CONTEXT_SIZE(vq)) return -EINVAL; + if (sm) { + sme_alloc(current, false); + if (!current->thread.sme_state) + return -ENOMEM; + } + sve_alloc(current, true); if (!current->thread.sve_state) { clear_thread_flag(TIF_SVE); return -ENOMEM; } + if (sm) { + current->thread.svcr |= SVCR_SM_MASK; + set_thread_flag(TIF_SME); + } else { + current->thread.svcr &= ~SVCR_SM_MASK; + set_thread_flag(TIF_SVE); + } + + current->thread.fp_type = FP_STATE_SVE; + err = __copy_from_user(current->thread.sve_state, (char __user const *)user->sve + SVE_SIG_REGS_OFFSET, @@ -462,12 +478,6 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user) if (err) return -EFAULT; - if (flags & SVE_SIG_FLAG_SM) - current->thread.svcr |= SVCR_SM_MASK; - else - set_thread_flag(TIF_SVE); - current->thread.fp_type = FP_STATE_SVE; - err = read_fpsimd_context(&fpsimd, user); if (err) return err;