diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 54bb706e40e8d6227fb309c338c4c0a2b634fb99..364b2a737d94cc756ab17830f506f06c60e6ca9b 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -332,30 +332,59 @@ static inline bool is_access_track_spte(u64 spte)
 }
 
 /*
- * the low bit of the generation number is always presumed to be zero.
- * This disables mmio caching during memslot updates.  The concept is
- * similar to a seqcount but instead of retrying the access we just punt
- * and ignore the cache.
+ * Due to limited space in PTEs, the MMIO generation is a 19-bit subset of
+ * the memslots generation and is derived as follows:
  *
- * spte bits 3-11 are used as bits 1-9 of the generation number,
- * the bits 52-61 are used as bits 10-19 of the generation number.
+ * Bits 1-9 of the memslots generation are propagated to spte bits 3-11
+ * Bits 10-19 of the memslots generation are propagated to spte bits 52-61
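+ *
+ * E.g. bit 1 of the memslots generation lands in spte bit 3, and bit 10
+ * of the memslots generation lands in spte bit 52.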
+ *
+ * The MMIO generation starts at bit 1 of the memslots generation in order to
+ * skip over bit 0, the KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS flag.  Including
+ * the flag would require stealing a bit from the "real" generation number and
+ * thus effectively halve the maximum number of MMIO generations that can be
+ * handled before encountering a wrap (which requires a full MMU zap).  The
+ * flag is instead explicitly queried when checking for MMIO spte cache hits.
  */
-#define MMIO_SPTE_GEN_LOW_SHIFT		2
-#define MMIO_SPTE_GEN_HIGH_SHIFT	52
-
-#define MMIO_GEN_SHIFT			20
-#define MMIO_GEN_LOW_SHIFT		10
-#define MMIO_GEN_LOW_MASK		((1 << MMIO_GEN_LOW_SHIFT) - 2)
-#define MMIO_GEN_MASK			((1 << MMIO_GEN_SHIFT) - 1)
-
+#define MMIO_SPTE_GEN_MASK		GENMASK_ULL(19, 1)
+#define MMIO_SPTE_GEN_SHIFT		1
+
+#define MMIO_SPTE_GEN_LOW_START		3
+#define MMIO_SPTE_GEN_LOW_END		11
+#define MMIO_SPTE_GEN_LOW_MASK		GENMASK_ULL(MMIO_SPTE_GEN_LOW_END, \
+						    MMIO_SPTE_GEN_LOW_START)
+
+#define MMIO_SPTE_GEN_HIGH_START	52
+#define MMIO_SPTE_GEN_HIGH_END		61
+#define MMIO_SPTE_GEN_HIGH_MASK		GENMASK_ULL(MMIO_SPTE_GEN_HIGH_END, \
+						    MMIO_SPTE_GEN_HIGH_START)
+
+/*
+ * The low spte field holds 9 generation bits; shift the remaining bits
+ * down so they fill the high field instead of overlapping the low field.
+ */
+#define MMIO_SPTE_GEN_LOW_BITS		(MMIO_SPTE_GEN_LOW_END - \
+					 MMIO_SPTE_GEN_LOW_START + 1)
+#define MMIO_SPTE_GEN_HIGH_SHIFT	(MMIO_SPTE_GEN_HIGH_START - \
+					 MMIO_SPTE_GEN_LOW_BITS)
+
 static u64 generation_mmio_spte_mask(u64 gen)
 {
 	u64 mask;
 
-	WARN_ON(gen & ~MMIO_GEN_MASK);
+	WARN_ON(gen & ~MMIO_SPTE_GEN_MASK);
 
-	mask = (gen & MMIO_GEN_LOW_MASK) << MMIO_SPTE_GEN_LOW_SHIFT;
-	mask |= (gen >> MMIO_GEN_LOW_SHIFT) << MMIO_SPTE_GEN_HIGH_SHIFT;
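+	/* Drop the hole left at bit 0 by the update in-progress flag. */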
+	gen >>= MMIO_SPTE_GEN_SHIFT;
+
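+	/*
+	 * Bits 0-8 of the shifted generation go in spte bits 3-11; bits
+	 * 9-18 go in spte bits 52-61.
+	 */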
+	mask = (gen << MMIO_SPTE_GEN_LOW_START) & MMIO_SPTE_GEN_LOW_MASK;
+	mask |= (gen << MMIO_SPTE_GEN_HIGH_SHIFT) & MMIO_SPTE_GEN_HIGH_MASK;
 	return mask;
 }
 
@@ -365,20 +394,17 @@ static u64 get_mmio_spte_generation(u64 spte)
 
 	spte &= ~shadow_mmio_mask;
 
-	gen = (spte >> MMIO_SPTE_GEN_LOW_SHIFT) & MMIO_GEN_LOW_MASK;
-	gen |= (spte >> MMIO_SPTE_GEN_HIGH_SHIFT) << MMIO_GEN_LOW_SHIFT;
-	return gen;
-}
-
-static u64 kvm_current_mmio_generation(struct kvm_vcpu *vcpu)
-{
-	return kvm_vcpu_memslots(vcpu)->generation & MMIO_GEN_MASK;
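+	/* Reassemble the 19-bit value and restore it to bit 1 and above. */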
+	gen = (spte & MMIO_SPTE_GEN_LOW_MASK) >> MMIO_SPTE_GEN_LOW_START;
+	gen |= (spte & MMIO_SPTE_GEN_HIGH_MASK) >> MMIO_SPTE_GEN_HIGH_SHIFT;
+	return gen << MMIO_SPTE_GEN_SHIFT;
 }
 
 static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
 			   unsigned access)
 {
-	u64 gen = kvm_current_mmio_generation(vcpu);
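+	/* Bit 0, the update in-progress flag, is deliberately masked off. */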
+	u64 gen = kvm_vcpu_memslots(vcpu)->generation & MMIO_SPTE_GEN_MASK;
 	u64 mask = generation_mmio_spte_mask(gen);
 	u64 gpa = gfn << PAGE_SHIFT;
 
@@ -409,7 +435,8 @@ static gfn_t get_mmio_spte_gfn(u64 spte)
 
 static unsigned get_mmio_spte_access(u64 spte)
 {
-	u64 mask = generation_mmio_spte_mask(MMIO_GEN_MASK) | shadow_mmio_mask;
+	u64 mask = generation_mmio_spte_mask(MMIO_SPTE_GEN_MASK) | shadow_mmio_mask;
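+	/* Access bits live in the page-offset bits not used by the mask. */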
 	return (spte & ~mask) & ~PAGE_MASK;
 }
 
@@ -426,9 +453,17 @@ static bool set_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
 
 static bool check_mmio_spte(struct kvm_vcpu *vcpu, u64 spte)
 {
-	u64 kvm_gen, spte_gen;
+	u64 kvm_gen, spte_gen, gen;
+
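+	/*
+	 * Punt while a memslots update is in progress; the flag is not
+	 * part of the generation stored in the spte.
+	 */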
+	gen = kvm_vcpu_memslots(vcpu)->generation;
+	if (unlikely(gen & KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS))
+		return false;
 
-	kvm_gen = kvm_current_mmio_generation(vcpu);
+	kvm_gen = gen & MMIO_SPTE_GEN_MASK;
 	spte_gen = get_mmio_spte_generation(spte);
 
 	trace_check_mmio_spte(spte, kvm_gen, spte_gen);
@@ -5895,13 +5930,13 @@ static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
 
 void kvm_mmu_invalidate_mmio_sptes(struct kvm *kvm, u64 gen)
 {
-	gen &= MMIO_GEN_MASK;
+	gen &= MMIO_SPTE_GEN_MASK;
 
 	/*
-	 * Shift to eliminate the "update in-progress" flag, which isn't
-	 * included in the spte's generation number.
+	 * Shift out the "update in-progress" flag, which isn't included
+	 * in the MMIO generation number.
 	 */
-	gen >>= 1;
+	gen >>= MMIO_SPTE_GEN_SHIFT;
 
 	/*
 	 * Generation numbers are incremented in multiples of the number of