diff --git a/arch/x86/kvm/emulate.c b/arch/x86/kvm/emulate.c
index a6b2828532530725d8c713360bb1cf830e046223..f526acee2eed6f696c92a0a84358872027b624fc 100644
--- a/arch/x86/kvm/emulate.c
+++ b/arch/x86/kvm/emulate.c
@@ -2571,6 +2571,12 @@ static int em_rsm(struct x86_emulate_ctxt *ctxt)
 	if (ret != X86EMUL_CONTINUE)
 		return X86EMUL_UNHANDLEABLE;
 
+	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_INSIDE_NMI_MASK) == 0)
+		ctxt->ops->set_nmi_mask(ctxt, false);
+
+	ctxt->ops->set_hflags(ctxt, ctxt->ops->get_hflags(ctxt) &
+		~(X86EMUL_SMM_INSIDE_NMI_MASK | X86EMUL_SMM_MASK));
+
 	/*
 	 * Get back to real mode, to prepare a safe state in which to load
 	 * CR0/CR3/CR4/EFER.  It's all a bit more complicated if the vCPU
@@ -2624,12 +2630,6 @@ static int em_rsm(struct x86_emulate_ctxt *ctxt)
 		return X86EMUL_UNHANDLEABLE;
 	}
 
-	if ((ctxt->ops->get_hflags(ctxt) & X86EMUL_SMM_INSIDE_NMI_MASK) == 0)
-		ctxt->ops->set_nmi_mask(ctxt, false);
-
-	ctxt->ops->set_hflags(ctxt, ctxt->ops->get_hflags(ctxt) &
-		~(X86EMUL_SMM_INSIDE_NMI_MASK | X86EMUL_SMM_MASK));
-
 	ctxt->ops->post_leave_smm(ctxt);
 
 	return X86EMUL_CONTINUE;
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index 6b1cd73e4053d56d030062c0856c87af4186dbec..406b558abfef7379eb46bd2de18e5d6890079eb9 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -6239,21 +6239,17 @@ static int svm_pre_leave_smm(struct kvm_vcpu *vcpu, const char *smstate)
 	struct page *page;
 	u64 guest;
 	u64 vmcb;
-	int ret;
 
 	guest = GET_SMSTATE(u64, smstate, 0x7ed8);
 	vmcb = GET_SMSTATE(u64, smstate, 0x7ee0);
 
 	if (guest) {
-		vcpu->arch.hflags &= ~HF_SMM_MASK;
 		nested_vmcb = nested_svm_map(svm, vmcb, &page);
-		if (nested_vmcb)
-			enter_svm_guest_mode(svm, vmcb, nested_vmcb, page);
-		else
-			ret = 1;
-		vcpu->arch.hflags |= HF_SMM_MASK;
+		if (!nested_vmcb)
+			return 1;
+		enter_svm_guest_mode(svm, vmcb, nested_vmcb, page);
 	}
-	return ret;
+	return 0;
 }
 
 static int enable_smi_window(struct kvm_vcpu *vcpu)
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 14ea25eadde8bb4e404e7290833ae741a59515e7..b4e7d645275a2153c42fa252cce8a8cbb930b59e 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -7409,9 +7409,7 @@ static int vmx_pre_leave_smm(struct kvm_vcpu *vcpu, const char *smstate)
 	}
 
 	if (vmx->nested.smm.guest_mode) {
-		vcpu->arch.hflags &= ~HF_SMM_MASK;
 		ret = nested_vmx_enter_non_root_mode(vcpu, false);
-		vcpu->arch.hflags |= HF_SMM_MASK;
 		if (ret)
 			return ret;