// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Copyright 2013 Red Hat Inc.
 *
 * Authors: Jérôme Glisse <jglisse@redhat.com>
 */
/*
 * Refer to include/linux/hmm.h for information about heterogeneous memory
 * management or HMM for short.
 */
#include <linux/mm.h>
#include <linux/hmm.h>
#include <linux/init.h>
#include <linux/rmap.h>
#include <linux/swap.h>
#include <linux/slab.h>
#include <linux/sched.h>
#include <linux/mmzone.h>
#include <linux/pagemap.h>
#include <linux/swapops.h>
#include <linux/hugetlb.h>
#include <linux/memremap.h>
#include <linux/sched/mm.h>
#include <linux/jump_label.h>
#include <linux/dma-mapping.h>
#include <linux/mmu_notifier.h>
#include <linux/memory_hotplug.h>

static const struct mmu_notifier_ops hmm_mmu_notifier_ops;

/**
 * hmm_get_or_create - register HMM against an mm (HMM internal)
 *
 * @mm: mm struct to attach to
 * Returns: an HMM object, either by referencing the existing
 *          (per-process) object, or by creating a new one.
 *
 * This is not intended to be used directly by device drivers. If mm already
 * has an HMM struct then it gets a reference on it and returns it. Otherwise
 * it allocates an HMM struct, initializes it, associates it with the mm and
 * returns it.
 */
static struct hmm *hmm_get_or_create(struct mm_struct *mm)
{
	struct hmm *hmm;

	lockdep_assert_held_write(&mm->mmap_sem);

	/* Abuse the page_table_lock to also protect mm->hmm. */
	spin_lock(&mm->page_table_lock);
	hmm = mm->hmm;
	if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
		goto out_unlock;
	spin_unlock(&mm->page_table_lock);

	hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
	if (!hmm)
		return NULL;
	init_waitqueue_head(&hmm->wq);
	INIT_LIST_HEAD(&hmm->mirrors);
	init_rwsem(&hmm->mirrors_sem);
	hmm->mmu_notifier.ops = NULL;
	INIT_LIST_HEAD(&hmm->ranges);
	spin_lock_init(&hmm->ranges_lock);
	kref_init(&hmm->kref);
	hmm->notifiers = 0;
	hmm->mm = mm;

	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
	if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
		kfree(hmm);
		return NULL;
	}

	mmgrab(hmm->mm);

	/*
	 * We hold the exclusive mmap_sem here so we know that mm->hmm is
	 * still NULL or 0 kref, and is safe to update.
	 */
	spin_lock(&mm->page_table_lock);
	mm->hmm = hmm;

out_unlock:
	spin_unlock(&mm->page_table_lock);
	return hmm;
}

static void hmm_free_rcu(struct rcu_head *rcu)
{
	struct hmm *hmm = container_of(rcu, struct hmm, rcu);

	mmdrop(hmm->mm);
	kfree(hmm);
}

static void hmm_free(struct kref *kref)
{
	struct hmm *hmm = container_of(kref, struct hmm, kref);

	spin_lock(&hmm->mm->page_table_lock);
	if (hmm->mm->hmm == hmm)
		hmm->mm->hmm = NULL;
	spin_unlock(&hmm->mm->page_table_lock);

	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
	mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
}

static inline void hmm_put(struct hmm *hmm)
{
	kref_put(&hmm->kref, hmm_free);
}

static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
	struct hmm_mirror *mirror;

	/* Bail out if hmm is in the process of being freed */
	if (!kref_get_unless_zero(&hmm->kref))
		return;

	/*
	 * Since hmm_range_register() holds the mmget() lock hmm_release() is
	 * prevented as long as a range exists.
	 */
	WARN_ON(!list_empty_careful(&hmm->ranges));

	down_read(&hmm->mirrors_sem);
	list_for_each_entry(mirror, &hmm->mirrors, list) {
		/*
		 * Note: The driver is not allowed to trigger
		 * hmm_mirror_unregister() from this thread.
		 */
		if (mirror->ops->release)
			mirror->ops->release(mirror);
	}
	up_read(&hmm->mirrors_sem);

	hmm_put(hmm);
}

static void notifiers_decrement(struct hmm *hmm)
{
	unsigned long flags;

	spin_lock_irqsave(&hmm->ranges_lock, flags);
	hmm->notifiers--;
	if (!hmm->notifiers) {
		struct hmm_range *range;

		list_for_each_entry(range, &hmm->ranges, list) {
			if (range->valid)
				continue;
			range->valid = true;
		}
		wake_up_all(&hmm->wq);
	}
	spin_unlock_irqrestore(&hmm->ranges_lock, flags);
}

static int hmm_invalidate_range_start(struct mmu_notifier *mn,
			const struct mmu_notifier_range *nrange)
{
	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
	struct hmm_mirror *mirror;
	struct hmm_update update;
	struct hmm_range *range;
	unsigned long flags;
	int ret = 0;

	if (!kref_get_unless_zero(&hmm->kref))
		return 0;

	update.start = nrange->start;
	update.end = nrange->end;
	update.event = HMM_UPDATE_INVALIDATE;
	update.blockable = mmu_notifier_range_blockable(nrange);

	spin_lock_irqsave(&hmm->ranges_lock, flags);
	hmm->notifiers++;
	list_for_each_entry(range, &hmm->ranges, list) {
		if (update.end < range->start || update.start >= range->end)
			continue;

		range->valid = false;
	}
	spin_unlock_irqrestore(&hmm->ranges_lock, flags);

	if (mmu_notifier_range_blockable(nrange))
		down_read(&hmm->mirrors_sem);
	else if (!down_read_trylock(&hmm->mirrors_sem)) {
		ret = -EAGAIN;
		goto out;
	}

	list_for_each_entry(mirror, &hmm->mirrors, list) {
		int rc;

		rc = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
		if (rc) {
			if (WARN_ON(update.blockable || rc != -EAGAIN))
				continue;
			ret = -EAGAIN;
			break;
		}
	}
	up_read(&hmm->mirrors_sem);

out:
	if (ret)
		notifiers_decrement(hmm);
	hmm_put(hmm);
	return ret;
}

static void hmm_invalidate_range_end(struct mmu_notifier *mn,
			const struct mmu_notifier_range *nrange)
{
	struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);

	if (!kref_get_unless_zero(&hmm->kref))
		return;

	notifiers_decrement(hmm);
	hmm_put(hmm);
}

static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
	.release		= hmm_release,
	.invalidate_range_start	= hmm_invalidate_range_start,
	.invalidate_range_end	= hmm_invalidate_range_end,
};

/*
 * hmm_mirror_register() - register a mirror against an mm
 *
 * @mirror: new mirror struct to register
 * @mm: mm to register against
 * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments
 *
 * To start mirroring a process address space, the device driver must register
 * an HMM mirror struct.
 */
int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
{
	lockdep_assert_held_write(&mm->mmap_sem);

	/* Sanity check */
	if (!mm || !mirror || !mirror->ops)
		return -EINVAL;

	mirror->hmm = hmm_get_or_create(mm);
	if (!mirror->hmm)
		return -ENOMEM;

	down_write(&mirror->hmm->mirrors_sem);
	list_add(&mirror->list, &mirror->hmm->mirrors);
	up_write(&mirror->hmm->mirrors_sem);

	return 0;
}
EXPORT_SYMBOL(hmm_mirror_register);
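
/*
 * Editorial usage sketch, not part of the original file: roughly how a device
 * driver might wire up a mirror. Everything named example_* is hypothetical;
 * only struct hmm_mirror_ops, hmm_mirror_register() and
 * hmm_mirror_unregister() come from the HMM API, and mmap_sem must be held
 * for write across registration as asserted above.
 */
#if 0	/* illustrative only, never compiled */
static int example_sync_cpu_device_pagetables(struct hmm_mirror *mirror,
					      const struct hmm_update *update)
{
	/*
	 * Invalidate the device page table for [update->start, update->end).
	 * If update->blockable is false and the invalidation would sleep,
	 * return -EAGAIN so the notifier path can retry.
	 */
	return 0;
}

static void example_release(struct hmm_mirror *mirror)
{
	/* The address space is going away; quiesce device access to it. */
}

static const struct hmm_mirror_ops example_mirror_ops = {
	.sync_cpu_device_pagetables	= example_sync_cpu_device_pagetables,
	.release			= example_release,
};

static int example_start_mirroring(struct hmm_mirror *mirror,
				   struct mm_struct *mm)
{
	int ret;

	mirror->ops = &example_mirror_ops;
	down_write(&mm->mmap_sem);
	ret = hmm_mirror_register(mirror, mm);
	up_write(&mm->mmap_sem);
	return ret;
}
#endif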

/*
 * hmm_mirror_unregister() - unregister a mirror
 *
 * @mirror: mirror struct to unregister
 *
 * Stop mirroring a process address space, and cleanup.
 */
void hmm_mirror_unregister(struct hmm_mirror *mirror)
{
	struct hmm *hmm = mirror->hmm;

	down_write(&hmm->mirrors_sem);
	list_del(&mirror->list);
	up_write(&hmm->mirrors_sem);
	hmm_put(hmm);
}
EXPORT_SYMBOL(hmm_mirror_unregister);

struct hmm_vma_walk {
	struct hmm_range	*range;
	struct dev_pagemap	*pgmap;
	unsigned long		last;
	bool			fault;
	bool			block;
};

static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
			    bool write_fault, uint64_t *pfn)
{
	unsigned int flags = FAULT_FLAG_REMOTE;
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;
	vm_fault_t ret;

	flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
	flags |= write_fault ? FAULT_FLAG_WRITE : 0;
	ret = handle_mm_fault(vma, addr, flags);
	if (ret & VM_FAULT_RETRY)
		return -EAGAIN;
	if (ret & VM_FAULT_ERROR) {
		*pfn = range->values[HMM_PFN_ERROR];
		return -EFAULT;
	}

	return -EBUSY;
}

static int hmm_pfns_bad(unsigned long addr,
			unsigned long end,
			struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	uint64_t *pfns = range->pfns;
	unsigned long i;

	i = (addr - range->start) >> PAGE_SHIFT;
	for (; addr < end; addr += PAGE_SIZE, i++)
		pfns[i] = range->values[HMM_PFN_ERROR];

	return 0;
}

/*
 * hmm_vma_walk_hole() - handle a range lacking valid pmd or pte(s)
 * @start: range virtual start address (inclusive)
 * @end: range virtual end address (exclusive)
 * @fault: should we fault or not ?
 * @write_fault: write fault ?
 * @walk: mm_walk structure
 * Return: 0 on success, -EBUSY after page fault, or page fault error
 *
 * This function will be called whenever pmd_none() or pte_none() returns true,
 * or whenever there is no page directory covering the virtual address range.
 */
static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
			      bool fault, bool write_fault,
			      struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	uint64_t *pfns = range->pfns;
	unsigned long i, page_size;

	hmm_vma_walk->last = addr;
	page_size = hmm_range_page_size(range);
	i = (addr - range->start) >> range->page_shift;

	for (; addr < end; addr += page_size, i++) {
		pfns[i] = range->values[HMM_PFN_NONE];
		if (fault || write_fault) {
			int ret;

			ret = hmm_vma_do_fault(walk, addr, write_fault,
					       &pfns[i]);
			if (ret != -EBUSY)
				return ret;
		}
	}

	return (fault || write_fault) ? -EBUSY : 0;
}

static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
				      uint64_t pfns, uint64_t cpu_flags,
				      bool *fault, bool *write_fault)
{
	struct hmm_range *range = hmm_vma_walk->range;

	if (!hmm_vma_walk->fault)
		return;

	/*
	 * So we not only consider the individual per-page request, we also
	 * consider the default flags requested for the range. The API can
	 * be used in 2 fashions. The first one is where the HMM user
	 * coalesces multiple page faults into one request and sets flags
	 * per pfn for those faults. The second one is where the HMM user
	 * wants to pre-fault a range with specific flags. For the latter
	 * one it is a waste to have the user pre-fill the pfn array with a
	 * default flags value.
	 */
	pfns = (pfns & range->pfn_flags_mask) | range->default_flags;

	/* We aren't asked to do anything ... */
	if (!(pfns & range->flags[HMM_PFN_VALID]))
		return;
	/* If this is device memory then only fault if explicitly requested */
	if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
		/* Do we fault on device memory ? */
		if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
			*write_fault = pfns & range->flags[HMM_PFN_WRITE];
			*fault = true;
		}
		return;
	}

	/* If CPU page table is not valid then we need to fault */
	*fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
	/* Need to write fault ? */
	if ((pfns & range->flags[HMM_PFN_WRITE]) &&
	    !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
		*write_fault = true;
		*fault = true;
	}
}

static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
				 const uint64_t *pfns, unsigned long npages,
				 uint64_t cpu_flags, bool *fault,
				 bool *write_fault)
{
	unsigned long i;

	if (!hmm_vma_walk->fault) {
		*fault = *write_fault = false;
		return;
	}

	*fault = *write_fault = false;
	for (i = 0; i < npages; ++i) {
		hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
				   fault, write_fault);
		if ((*write_fault))
			return;
	}
}

static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
			     struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	bool fault, write_fault;
	unsigned long i, npages;
	uint64_t *pfns;

	i = (addr - range->start) >> PAGE_SHIFT;
	npages = (end - addr) >> PAGE_SHIFT;
	pfns = &range->pfns[i];
	hmm_range_need_fault(hmm_vma_walk, pfns, npages,
			     0, &fault, &write_fault);
	return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
}

static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
{
	if (pmd_protnone(pmd))
		return 0;
	return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
				range->flags[HMM_PFN_WRITE] :
				range->flags[HMM_PFN_VALID];
}

static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
{
	if (!pud_present(pud))
		return 0;
	return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
				range->flags[HMM_PFN_WRITE] :
				range->flags[HMM_PFN_VALID];
}

static int hmm_vma_handle_pmd(struct mm_walk *walk,
			      unsigned long addr,
			      unsigned long end,
			      uint64_t *pfns,
			      pmd_t pmd)
{
#ifdef CONFIG_TRANSPARENT_HUGEPAGE
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	unsigned long pfn, npages, i;
	bool fault, write_fault;
	uint64_t cpu_flags;

	npages = (end - addr) >> PAGE_SHIFT;
	cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
	hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
			     &fault, &write_fault);

	if (pmd_protnone(pmd) || fault || write_fault)
		return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);

	pfn = pmd_pfn(pmd) + pte_index(addr);
	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
		if (pmd_devmap(pmd)) {
			hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
					      hmm_vma_walk->pgmap);
			if (unlikely(!hmm_vma_walk->pgmap))
				return -EBUSY;
		}
		pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
	}
	if (hmm_vma_walk->pgmap) {
		put_dev_pagemap(hmm_vma_walk->pgmap);
		hmm_vma_walk->pgmap = NULL;
	}
	hmm_vma_walk->last = end;
	return 0;
#else
	/* If THP is not enabled then we should never reach this code! */
	return -EINVAL;
#endif
}

static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
{
	if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
		return 0;
	return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
				range->flags[HMM_PFN_WRITE] :
				range->flags[HMM_PFN_VALID];
}

static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
			      unsigned long end, pmd_t *pmdp, pte_t *ptep,
			      uint64_t *pfn)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;
	bool fault, write_fault;
	uint64_t cpu_flags;
	pte_t pte = *ptep;
	uint64_t orig_pfn = *pfn;

	*pfn = range->values[HMM_PFN_NONE];
	fault = write_fault = false;

	if (pte_none(pte)) {
		hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
				   &fault, &write_fault);
		if (fault || write_fault)
			goto fault;
		return 0;
	}

	if (!pte_present(pte)) {
		swp_entry_t entry = pte_to_swp_entry(pte);

		if (!non_swap_entry(entry)) {
			if (fault || write_fault)
				goto fault;
			return 0;
		}

		/*
		 * This is a special swap entry, ignore migration, use
		 * device and report anything else as error.
		 */
		if (is_device_private_entry(entry)) {
			cpu_flags = range->flags[HMM_PFN_VALID] |
				range->flags[HMM_PFN_DEVICE_PRIVATE];
			cpu_flags |= is_write_device_private_entry(entry) ?
				range->flags[HMM_PFN_WRITE] : 0;
			hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
					   &fault, &write_fault);
			if (fault || write_fault)
				goto fault;
			*pfn = hmm_device_entry_from_pfn(range,
					    swp_offset(entry));
			*pfn |= cpu_flags;
			return 0;
		}

		if (is_migration_entry(entry)) {
			if (fault || write_fault) {
				pte_unmap(ptep);
				hmm_vma_walk->last = addr;
				migration_entry_wait(vma->vm_mm,
						     pmdp, addr);
				return -EBUSY;
			}
			return 0;
		}

		/* Report error for everything else */
		*pfn = range->values[HMM_PFN_ERROR];
		return -EFAULT;
	} else {
		cpu_flags = pte_to_hmm_pfn_flags(range, pte);
		hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
				   &fault, &write_fault);
	}

	if (fault || write_fault)
		goto fault;

	if (pte_devmap(pte)) {
		hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
					      hmm_vma_walk->pgmap);
		if (unlikely(!hmm_vma_walk->pgmap))
			return -EBUSY;
	} else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
		*pfn = range->values[HMM_PFN_SPECIAL];
		return -EFAULT;
	}

	*pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
	return 0;

fault:
	if (hmm_vma_walk->pgmap) {
		put_dev_pagemap(hmm_vma_walk->pgmap);
		hmm_vma_walk->pgmap = NULL;
	}
	pte_unmap(ptep);
	/* Fault any virtual address we were asked to fault */
	return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
}

static int hmm_vma_walk_pmd(pmd_t *pmdp,
			    unsigned long start,
			    unsigned long end,
			    struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;
	uint64_t *pfns = range->pfns;
	unsigned long addr = start, i;
	pte_t *ptep;
	pmd_t pmd;


again:
	pmd = READ_ONCE(*pmdp);
	if (pmd_none(pmd))
		return hmm_vma_walk_hole(start, end, walk);

	if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
		return hmm_pfns_bad(start, end, walk);

	if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
		bool fault, write_fault;
		unsigned long npages;
		uint64_t *pfns;

		i = (addr - range->start) >> PAGE_SHIFT;
		npages = (end - addr) >> PAGE_SHIFT;
		pfns = &range->pfns[i];

		hmm_range_need_fault(hmm_vma_walk, pfns, npages,
				     0, &fault, &write_fault);
		if (fault || write_fault) {
			hmm_vma_walk->last = addr;
			pmd_migration_entry_wait(vma->vm_mm, pmdp);
			return -EBUSY;
		}
		return 0;
	} else if (!pmd_present(pmd))
		return hmm_pfns_bad(start, end, walk);

	if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
		/*
		 * No need to take the pmd_lock here; even if some other thread
		 * is splitting the huge pmd we will get that event through
		 * the mmu_notifier callback.
		 *
		 * So just read the pmd value and check again that it is a
		 * transparent huge or device mapping one and compute the
		 * corresponding pfn values.
		 */
		pmd = pmd_read_atomic(pmdp);
		barrier();
		if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
			goto again;

		i = (addr - range->start) >> PAGE_SHIFT;
		return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
	}

	/*
	 * We have handled all the valid cases above, i.e. either none,
	 * migration, huge or transparent huge. At this point either it is a
	 * valid pmd entry pointing to a pte directory or it is a bad pmd that
	 * will not recover.
	 */
	if (pmd_bad(pmd))
		return hmm_pfns_bad(start, end, walk);

	ptep = pte_offset_map(pmdp, addr);
	i = (addr - range->start) >> PAGE_SHIFT;
	for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
		int r;

		r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
		if (r) {
			/* hmm_vma_handle_pte() did unmap pte directory */
			hmm_vma_walk->last = addr;
			return r;
		}
	}
	if (hmm_vma_walk->pgmap) {
		/*
		 * We do put_dev_pagemap() here and not in hmm_vma_handle_pte()
		 * so that we can leverage get_dev_pagemap() optimization which
		 * will not re-take a reference on a pgmap if we already have
		 * one.
		 */
		put_dev_pagemap(hmm_vma_walk->pgmap);
		hmm_vma_walk->pgmap = NULL;
	}
	pte_unmap(ptep - 1);

	hmm_vma_walk->last = addr;
	return 0;
}

static int hmm_vma_walk_pud(pud_t *pudp,
			    unsigned long start,
			    unsigned long end,
			    struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	unsigned long addr = start, next;
	pmd_t *pmdp;
	pud_t pud;
	int ret;

again:
	pud = READ_ONCE(*pudp);
	if (pud_none(pud))
		return hmm_vma_walk_hole(start, end, walk);

	if (pud_huge(pud) && pud_devmap(pud)) {
		unsigned long i, npages, pfn;
		uint64_t *pfns, cpu_flags;
		bool fault, write_fault;

		if (!pud_present(pud))
			return hmm_vma_walk_hole(start, end, walk);

		i = (addr - range->start) >> PAGE_SHIFT;
		npages = (end - addr) >> PAGE_SHIFT;
		pfns = &range->pfns[i];

		cpu_flags = pud_to_hmm_pfn_flags(range, pud);
		hmm_range_need_fault(hmm_vma_walk, pfns, npages,
				     cpu_flags, &fault, &write_fault);
		if (fault || write_fault)
			return hmm_vma_walk_hole_(addr, end, fault,
						write_fault, walk);

		pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
		for (i = 0; i < npages; ++i, ++pfn) {
			hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
					      hmm_vma_walk->pgmap);
			if (unlikely(!hmm_vma_walk->pgmap))
				return -EBUSY;
			pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
				  cpu_flags;
		}
		if (hmm_vma_walk->pgmap) {
			put_dev_pagemap(hmm_vma_walk->pgmap);
			hmm_vma_walk->pgmap = NULL;
		}
		hmm_vma_walk->last = end;
		return 0;
	}

	split_huge_pud(walk->vma, pudp, addr);
	if (pud_none(*pudp))
		goto again;

	pmdp = pmd_offset(pudp, addr);
	do {
		next = pmd_addr_end(addr, end);
		ret = hmm_vma_walk_pmd(pmdp, addr, next, walk);
		if (ret)
			return ret;
	} while (pmdp++, addr = next, addr != end);

	return 0;
}

static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
				      unsigned long start, unsigned long end,
				      struct mm_walk *walk)
{
#ifdef CONFIG_HUGETLB_PAGE
	unsigned long addr = start, i, pfn, mask, size, pfn_inc;
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;
	struct hstate *h = hstate_vma(vma);
	uint64_t orig_pfn, cpu_flags;
	bool fault, write_fault;
	spinlock_t *ptl;
	pte_t entry;
	int ret = 0;

	size = 1UL << huge_page_shift(h);
	mask = size - 1;
	if (range->page_shift != PAGE_SHIFT) {
		/* Make sure we are looking at a full page. */
		if (start & mask)
			return -EINVAL;
		if (end < (start + size))
			return -EINVAL;
		pfn_inc = size >> PAGE_SHIFT;
	} else {
		pfn_inc = 1;
		size = PAGE_SIZE;
	}


	ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
	entry = huge_ptep_get(pte);

	i = (start - range->start) >> range->page_shift;
	orig_pfn = range->pfns[i];
	range->pfns[i] = range->values[HMM_PFN_NONE];
	cpu_flags = pte_to_hmm_pfn_flags(range, entry);
	fault = write_fault = false;
	hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
			   &fault, &write_fault);
	if (fault || write_fault) {
		ret = -ENOENT;
		goto unlock;
	}

	pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift);
	for (; addr < end; addr += size, i++, pfn += pfn_inc)
		range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
				 cpu_flags;
	hmm_vma_walk->last = end;

unlock:
	spin_unlock(ptl);

	if (ret == -ENOENT)
		return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);

	return ret;
#else /* CONFIG_HUGETLB_PAGE */
	return -EINVAL;
#endif
}

static void hmm_pfns_clear(struct hmm_range *range,
			   uint64_t *pfns,
			   unsigned long addr,
			   unsigned long end)
{
	for (; addr < end; addr += PAGE_SIZE, pfns++)
		*pfns = range->values[HMM_PFN_NONE];
}

/*
 * hmm_range_register() - start tracking change to CPU page table over a range
 * @range: range
 * @mirror: the mirror struct that tracks the mm for this range of virtual addresses
 * @start: start virtual address (inclusive)
 * @end: end virtual address (exclusive)
 * @page_shift: expected page shift for the range
 * Returns 0 on success, -EFAULT if the address space is no longer valid
 *
 * Track updates to the CPU page table; see include/linux/hmm.h.
 */
int hmm_range_register(struct hmm_range *range,
		       struct hmm_mirror *mirror,
		       unsigned long start,
		       unsigned long end,
		       unsigned page_shift)
{
	unsigned long mask = ((1UL << page_shift) - 1UL);
	struct hmm *hmm = mirror->hmm;
	unsigned long flags;

	range->valid = false;
	range->hmm = NULL;

	if ((start & mask) || (end & mask))
		return -EINVAL;
	if (start >= end)
		return -EINVAL;

	range->page_shift = page_shift;
	range->start = start;
	range->end = end;

	/* Prevent hmm_release() from running while the range is valid */
	if (!mmget_not_zero(hmm->mm))
		return -EFAULT;

	/* Initialize range to track CPU page table updates. */
	spin_lock_irqsave(&hmm->ranges_lock, flags);

	range->hmm = hmm;
	kref_get(&hmm->kref);
	list_add(&range->list, &hmm->ranges);

	/*
	 * If there are any concurrent notifiers we have to wait for them for
	 * the range to be valid (see hmm_range_wait_until_valid()).
	 */
	if (!hmm->notifiers)
		range->valid = true;
	spin_unlock_irqrestore(&hmm->ranges_lock, flags);

	return 0;
}
EXPORT_SYMBOL(hmm_range_register);
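
/*
 * Editorial usage sketch, not part of the original file: registering a range
 * against a mirror and waiting for it to become valid before use. The names
 * example_* and the 1000ms timeout are hypothetical; hmm_range_register(),
 * hmm_range_unregister() and hmm_range_wait_until_valid() (the latter
 * declared in include/linux/hmm.h) are the assumed API.
 */
#if 0	/* illustrative only, never compiled */
static int example_track_range(struct hmm_mirror *mirror,
			       struct hmm_range *range,
			       unsigned long start, unsigned long end)
{
	int ret;

	ret = hmm_range_register(range, mirror, start, end, PAGE_SHIFT);
	if (ret)
		return ret;

	/*
	 * If a notifier is running concurrently, range->valid starts out
	 * false; wait for the notifier count to drop back to zero.
	 */
	if (!hmm_range_wait_until_valid(range, 1000 /* ms, arbitrary */)) {
		hmm_range_unregister(range);
		return -EBUSY;
	}

	/* ... snapshot or fault the range here, then ... */
	hmm_range_unregister(range);
	return 0;
}
#endif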

/*
 * hmm_range_unregister() - stop tracking change to CPU page table over a range
 * @range: range
 *
 * Range struct is used to track updates to the CPU page table after a call to
 * hmm_range_register(). See include/linux/hmm.h for how to use it.
 */
void hmm_range_unregister(struct hmm_range *range)
{
	struct hmm *hmm = range->hmm;
	unsigned long flags;

	spin_lock_irqsave(&hmm->ranges_lock, flags);
	list_del_init(&range->list);
	spin_unlock_irqrestore(&hmm->ranges_lock, flags);

	/* Drop reference taken by hmm_range_register() */
	mmput(hmm->mm);
	hmm_put(hmm);

	/*
	 * The range is now invalid and the ref on the hmm is dropped, so
	 * poison the pointer.  Leave other fields in place, for the caller's
	 * use.
	 */
	range->valid = false;
	memset(&range->hmm, POISON_INUSE, sizeof(range->hmm));
}
EXPORT_SYMBOL(hmm_range_unregister);

/*
 * hmm_range_snapshot() - snapshot CPU page table for a range
 * @range: range
 * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
 *          permission (for instance asking for write and range is read only),
 *          -EBUSY if you need to retry, -EFAULT invalid (i.e. either no valid
 *          vma or it is illegal to access that range), number of valid pages
 *          in range->pfns[] (from range start address).
 *
 * This snapshots the CPU page table for a range of virtual addresses. Snapshot
 * validity is tracked by the range struct. See include/linux/hmm.h for an
 * example of how to use it.
 */
long hmm_range_snapshot(struct hmm_range *range)
{
	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
	unsigned long start = range->start, end;
	struct hmm_vma_walk hmm_vma_walk;
	struct hmm *hmm = range->hmm;
	struct vm_area_struct *vma;
	struct mm_walk mm_walk;

	lockdep_assert_held(&hmm->mm->mmap_sem);
	do {
		/* If range is no longer valid force retry. */
		if (!range->valid)
			return -EBUSY;

		vma = find_vma(hmm->mm, start);
		if (vma == NULL || (vma->vm_flags & device_vma))
			return -EFAULT;

		if (is_vm_hugetlb_page(vma)) {
			if (huge_page_shift(hstate_vma(vma)) !=
				    range->page_shift &&
			    range->page_shift != PAGE_SHIFT)
				return -EINVAL;
		} else {
			if (range->page_shift != PAGE_SHIFT)
				return -EINVAL;
		}

		if (!(vma->vm_flags & VM_READ)) {
			/*
			 * If the vma does not allow read access, then assume that it
			 * does not allow write access, either. HMM does not
			 * support architectures that allow write without read.
			 */
			hmm_pfns_clear(range, range->pfns,
				range->start, range->end);
			return -EPERM;
		}

		range->vma = vma;
		hmm_vma_walk.pgmap = NULL;
		hmm_vma_walk.last = start;
		hmm_vma_walk.fault = false;
		hmm_vma_walk.range = range;
		mm_walk.private = &hmm_vma_walk;
		end = min(range->end, vma->vm_end);

		mm_walk.vma = vma;
		mm_walk.mm = vma->vm_mm;
		mm_walk.pte_entry = NULL;
		mm_walk.test_walk = NULL;
		mm_walk.hugetlb_entry = NULL;
		mm_walk.pud_entry = hmm_vma_walk_pud;
		mm_walk.pmd_entry = hmm_vma_walk_pmd;
		mm_walk.pte_hole = hmm_vma_walk_hole;
		mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;

		walk_page_range(start, end, &mm_walk);
		start = end;
	} while (start < range->end);

	return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
}
EXPORT_SYMBOL(hmm_range_snapshot);
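
/*
 * Editorial usage sketch, not part of the original file: the retry loop that
 * the -EBUSY return above implies. example_snapshot() and the timeout are
 * hypothetical; hmm_range_wait_until_valid() is assumed from
 * include/linux/hmm.h, and mmap_sem must be held as asserted above.
 */
#if 0	/* illustrative only, never compiled */
static long example_snapshot(struct hmm_range *range, struct mm_struct *mm)
{
	long ret;

again:
	down_read(&mm->mmap_sem);
	ret = hmm_range_snapshot(range);
	up_read(&mm->mmap_sem);
	if (ret == -EBUSY) {
		/* A concurrent invalidation raced with us; wait and retry. */
		if (!hmm_range_wait_until_valid(range, 1000 /* ms */))
			return -EBUSY;
		goto again;
	}

	/*
	 * Even on success the driver must recheck range->valid under its own
	 * page table lock before committing range->pfns[] to the device.
	 */
	return ret;
}
#endif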

/*
 * hmm_range_fault() - try to fault some address in a virtual address range
 * @range: range being faulted
 * @block: allow blocking on fault (if true it sleeps and does not drop mmap_sem)
 * Return: number of valid pages in range->pfns[] (from range start
 *          address). This may be zero. If the return value is negative,
 *          then one of the following values may be returned:
 *
 *           -EINVAL: invalid arguments or mm or virtual address are in an
 *                    invalid vma (for instance device file vma).
 *           -ENOMEM: Out of memory.
 *           -EPERM:  Invalid permission (for instance asking for write and
 *                    range is read only).
 *           -EAGAIN: If you need to retry and mmap_sem was dropped. This can
 *                    only happen if the block argument is false.
 *           -EBUSY:  If the range is being invalidated and you should wait
 *                    for invalidation to finish.
 *           -EFAULT: Invalid (i.e. either no valid vma or it is illegal to access
 *                    that range), number of valid pages in range->pfns[] (from
 *                    range start address).
 *
 * This is similar to a regular CPU page fault except that it will not trigger
 * any memory migration if the memory being faulted is not accessible by CPUs
 * and the caller does not ask for migration.
 *
 * On error, for one virtual address in the range, the function will mark the
 * corresponding HMM pfn entry with an error flag.
 */
long hmm_range_fault(struct hmm_range *range, bool block)
{
	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
	unsigned long start = range->start, end;
	struct hmm_vma_walk hmm_vma_walk;
	struct hmm *hmm = range->hmm;
	struct vm_area_struct *vma;
	struct mm_walk mm_walk;
	int ret;

	lockdep_assert_held(&hmm->mm->mmap_sem);

	do {
		/* If range is no longer valid force retry. */
		if (!range->valid)
			return -EBUSY;

		vma = find_vma(hmm->mm, start);
		if (vma == NULL || (vma->vm_flags & device_vma))
			return -EFAULT;

		if (is_vm_hugetlb_page(vma)) {
			if (huge_page_shift(hstate_vma(vma)) !=
			    range->page_shift &&
			    range->page_shift != PAGE_SHIFT)
				return -EINVAL;
		} else {
			if (range->page_shift != PAGE_SHIFT)
				return -EINVAL;
		}

		if (!(vma->vm_flags & VM_READ)) {
			/*
			 * If the vma does not allow read access, then assume that it
			 * does not allow write access, either. HMM does not
			 * support architectures that allow write without read.
			 */
			hmm_pfns_clear(range, range->pfns,
				range->start, range->end);
			return -EPERM;
		}

		range->vma = vma;
		hmm_vma_walk.pgmap = NULL;
		hmm_vma_walk.last = start;
		hmm_vma_walk.fault = true;
		hmm_vma_walk.block = block;
		hmm_vma_walk.range = range;
		mm_walk.private = &hmm_vma_walk;
		end = min(range->end, vma->vm_end);

		mm_walk.vma = vma;
		mm_walk.mm = vma->vm_mm;
		mm_walk.pte_entry = NULL;
		mm_walk