diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index ad882127079f1f62760517f187705f0d2ae24acf..505cf19ace800f516e5742e5f7df27bae61c6d77 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -687,6 +687,7 @@ struct kvm_vcpu_arch {
 	struct kvm_cpuid_entry2 cpuid_entries[KVM_MAX_CPUID_ENTRIES];
 
 	int maxphyaddr;
+	int tdp_level;
 
 	/* emulate context */
 
diff --git a/arch/x86/kvm/cpuid.c b/arch/x86/kvm/cpuid.c
index 6828be99b9083180c033ca2fe3d92fdaa7f446c7..44dfaefdad0ec0e0e202ae8af9092e9c92581624 100644
--- a/arch/x86/kvm/cpuid.c
+++ b/arch/x86/kvm/cpuid.c
@@ -124,8 +124,9 @@ int kvm_update_cpuid(struct kvm_vcpu *vcpu)
 					   MSR_IA32_MISC_ENABLE_MWAIT);
 	}
 
-	/* Update physical-address width */
+	/* Note, maxphyaddr must be updated before tdp_level. */
 	vcpu->arch.maxphyaddr = cpuid_query_maxphyaddr(vcpu);
+	vcpu->arch.tdp_level = kvm_x86_ops.get_tdp_level(vcpu);
 	kvm_mmu_reset_context(vcpu);
 
 	kvm_pmu_refresh(vcpu);
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index e618472c572bf37d8f1a8417e028a8c8e7c61717..10cb8db54cd0478b398289709514cf6ec5244592 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -4894,7 +4894,7 @@ kvm_calc_tdp_mmu_root_page_role(struct kvm_vcpu *vcpu, bool base_only)
 	union kvm_mmu_role role = kvm_calc_mmu_role_common(vcpu, base_only);
 
 	role.base.ad_disabled = (shadow_accessed_mask == 0);
-	role.base.level = kvm_x86_ops.get_tdp_level(vcpu);
+	role.base.level = vcpu->arch.tdp_level;
 	role.base.direct = true;
 	role.base.gpte_is_8_bytes = true;
 
@@ -4915,7 +4915,7 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu)
 	context->sync_page = nonpaging_sync_page;
 	context->invlpg = NULL;
 	context->update_pte = nonpaging_update_pte;
-	context->shadow_root_level = kvm_x86_ops.get_tdp_level(vcpu);
+	context->shadow_root_level = vcpu->arch.tdp_level;
 	context->direct_map = true;
 	context->get_guest_pgd = get_cr3;
 	context->get_pdptr = kvm_pdptr_read;
@@ -5680,7 +5680,7 @@ static int alloc_mmu_pages(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
 	 * SVM's 32-bit NPT support, TDP paging doesn't use PAE paging and can
 	 * skip allocating the PDP table.
 	 */
-	if (tdp_enabled && kvm_x86_ops.get_tdp_level(vcpu) > PT32E_ROOT_LEVEL)
+	if (tdp_enabled && vcpu->arch.tdp_level > PT32E_ROOT_LEVEL)
 		return 0;
 
 	page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_DMA32);
diff --git a/arch/x86/kvm/svm/nested.c b/arch/x86/kvm/svm/nested.c
index 712db507c81923f5e1a6d9a31b439abc73c89b8a..a89a166d1cb80d3e18cee5941b6d14688adef342 100644
--- a/arch/x86/kvm/svm/nested.c
+++ b/arch/x86/kvm/svm/nested.c
@@ -86,7 +86,7 @@ static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 	vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
 	vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
 	vcpu->arch.mmu->inject_page_fault = nested_svm_inject_npf_exit;
-	vcpu->arch.mmu->shadow_root_level = kvm_x86_ops.get_tdp_level(vcpu);
+	vcpu->arch.mmu->shadow_root_level = vcpu->arch.tdp_level;
 	reset_shadow_zero_bits_mask(vcpu, vcpu->arch.mmu);
 	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
 }
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 7a43fbe05e2d5ad2256b70d0957169b802dc82ba..93b2a708b1da47615899618436f66b980c6204c9 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -3025,8 +3025,6 @@ void vmx_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0)
 
 static int vmx_get_tdp_level(struct kvm_vcpu *vcpu)
 {
-	WARN_ON(is_guest_mode(vcpu) && nested_cpu_has_ept(get_vmcs12(vcpu)));
-
 	if (cpu_has_vmx_ept_5levels() && (cpuid_maxphyaddr(vcpu) > 48))
 		return 5;
 	return 4;