8977a4e8a7f8c9db2f60952a9d55d6d3245eaff9
[firefly-linux-kernel-4.4.55.git] / fs / userfaultfd.c
1 /*
2  *  fs/userfaultfd.c
3  *
4  *  Copyright (C) 2007  Davide Libenzi <davidel@xmailserver.org>
5  *  Copyright (C) 2008-2009 Red Hat, Inc.
6  *  Copyright (C) 2015  Red Hat, Inc.
7  *
8  *  This work is licensed under the terms of the GNU GPL, version 2. See
9  *  the COPYING file in the top-level directory.
10  *
11  *  Some part derived from fs/eventfd.c (anon inode setup) and
12  *  mm/ksm.c (mm hashing).
13  */
14
15 #include <linux/hashtable.h>
16 #include <linux/sched.h>
17 #include <linux/mm.h>
18 #include <linux/poll.h>
19 #include <linux/slab.h>
20 #include <linux/seq_file.h>
21 #include <linux/file.h>
22 #include <linux/bug.h>
23 #include <linux/anon_inodes.h>
24 #include <linux/syscalls.h>
25 #include <linux/userfaultfd_k.h>
26 #include <linux/mempolicy.h>
27 #include <linux/ioctl.h>
28 #include <linux/security.h>
29
30 static struct kmem_cache *userfaultfd_ctx_cachep __read_mostly;
31
32 enum userfaultfd_state {
33         UFFD_STATE_WAIT_API,
34         UFFD_STATE_RUNNING,
35 };
36
37 /*
38  * Start with fault_pending_wqh and fault_wqh so they're more likely
39  * to be in the same cacheline.
40  */
41 struct userfaultfd_ctx {
42         /* waitqueue head for the pending (i.e. not read) userfaults */
43         wait_queue_head_t fault_pending_wqh;
44         /* waitqueue head for the userfaults */
45         wait_queue_head_t fault_wqh;
46         /* waitqueue head for the pseudo fd to wakeup poll/read */
47         wait_queue_head_t fd_wqh;
48         /* pseudo fd refcounting */
49         atomic_t refcount;
50         /* userfaultfd syscall flags */
51         unsigned int flags;
52         /* state machine */
53         enum userfaultfd_state state;
54         /* released */
55         bool released;
56         /* mm with one ore more vmas attached to this userfaultfd_ctx */
57         struct mm_struct *mm;
58 };
59
60 struct userfaultfd_wait_queue {
61         struct uffd_msg msg;
62         wait_queue_t wq;
63         struct userfaultfd_ctx *ctx;
64 };
65
66 struct userfaultfd_wake_range {
67         unsigned long start;
68         unsigned long len;
69 };
70
71 static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode,
72                                      int wake_flags, void *key)
73 {
74         struct userfaultfd_wake_range *range = key;
75         int ret;
76         struct userfaultfd_wait_queue *uwq;
77         unsigned long start, len;
78
79         uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
80         ret = 0;
81         /* len == 0 means wake all */
82         start = range->start;
83         len = range->len;
84         if (len && (start > uwq->msg.arg.pagefault.address ||
85                     start + len <= uwq->msg.arg.pagefault.address))
86                 goto out;
87         ret = wake_up_state(wq->private, mode);
88         if (ret)
89                 /*
90                  * Wake only once, autoremove behavior.
91                  *
92                  * After the effect of list_del_init is visible to the
93                  * other CPUs, the waitqueue may disappear from under
94                  * us, see the !list_empty_careful() in
95                  * handle_userfault(). try_to_wake_up() has an
96                  * implicit smp_mb__before_spinlock, and the
97                  * wq->private is read before calling the extern
98                  * function "wake_up_state" (which in turns calls
99                  * try_to_wake_up). While the spin_lock;spin_unlock;
100                  * wouldn't be enough, the smp_mb__before_spinlock is
101                  * enough to avoid an explicit smp_mb() here.
102                  */
103                 list_del_init(&wq->task_list);
104 out:
105         return ret;
106 }
107
108 /**
109  * userfaultfd_ctx_get - Acquires a reference to the internal userfaultfd
110  * context.
111  * @ctx: [in] Pointer to the userfaultfd context.
112  *
113  * Returns: In case of success, returns not zero.
114  */
115 static void userfaultfd_ctx_get(struct userfaultfd_ctx *ctx)
116 {
117         if (!atomic_inc_not_zero(&ctx->refcount))
118                 BUG();
119 }
120
121 /**
122  * userfaultfd_ctx_put - Releases a reference to the internal userfaultfd
123  * context.
124  * @ctx: [in] Pointer to userfaultfd context.
125  *
126  * The userfaultfd context reference must have been previously acquired either
127  * with userfaultfd_ctx_get() or userfaultfd_ctx_fdget().
128  */
129 static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx)
130 {
131         if (atomic_dec_and_test(&ctx->refcount)) {
132                 VM_BUG_ON(spin_is_locked(&ctx->fault_pending_wqh.lock));
133                 VM_BUG_ON(waitqueue_active(&ctx->fault_pending_wqh));
134                 VM_BUG_ON(spin_is_locked(&ctx->fault_wqh.lock));
135                 VM_BUG_ON(waitqueue_active(&ctx->fault_wqh));
136                 VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock));
137                 VM_BUG_ON(waitqueue_active(&ctx->fd_wqh));
138                 mmput(ctx->mm);
139                 kmem_cache_free(userfaultfd_ctx_cachep, ctx);
140         }
141 }
142
143 static inline void msg_init(struct uffd_msg *msg)
144 {
145         BUILD_BUG_ON(sizeof(struct uffd_msg) != 32);
146         /*
147          * Must use memset to zero out the paddings or kernel data is
148          * leaked to userland.
149          */
150         memset(msg, 0, sizeof(struct uffd_msg));
151 }
152
153 static inline struct uffd_msg userfault_msg(unsigned long address,
154                                             unsigned int flags,
155                                             unsigned long reason)
156 {
157         struct uffd_msg msg;
158         msg_init(&msg);
159         msg.event = UFFD_EVENT_PAGEFAULT;
160         msg.arg.pagefault.address = address;
161         if (flags & FAULT_FLAG_WRITE)
162                 /*
163                  * If UFFD_FEATURE_PAGEFAULT_FLAG_WRITE was set in the
164                  * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WRITE
165                  * was not set in a UFFD_EVENT_PAGEFAULT, it means it
166                  * was a read fault, otherwise if set it means it's
167                  * a write fault.
168                  */
169                 msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WRITE;
170         if (reason & VM_UFFD_WP)
171                 /*
172                  * If UFFD_FEATURE_PAGEFAULT_FLAG_WP was set in the
173                  * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WP was
174                  * not set in a UFFD_EVENT_PAGEFAULT, it means it was
175                  * a missing fault, otherwise if set it means it's a
176                  * write protect fault.
177                  */
178                 msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WP;
179         return msg;
180 }
181
182 /*
183  * The locking rules involved in returning VM_FAULT_RETRY depending on
184  * FAULT_FLAG_ALLOW_RETRY, FAULT_FLAG_RETRY_NOWAIT and
185  * FAULT_FLAG_KILLABLE are not straightforward. The "Caution"
186  * recommendation in __lock_page_or_retry is not an understatement.
187  *
188  * If FAULT_FLAG_ALLOW_RETRY is set, the mmap_sem must be released
189  * before returning VM_FAULT_RETRY only if FAULT_FLAG_RETRY_NOWAIT is
190  * not set.
191  *
192  * If FAULT_FLAG_ALLOW_RETRY is set but FAULT_FLAG_KILLABLE is not
193  * set, VM_FAULT_RETRY can still be returned if and only if there are
194  * fatal_signal_pending()s, and the mmap_sem must be released before
195  * returning it.
196  */
197 int handle_userfault(struct vm_area_struct *vma, unsigned long address,
198                      unsigned int flags, unsigned long reason)
199 {
200         struct mm_struct *mm = vma->vm_mm;
201         struct userfaultfd_ctx *ctx;
202         struct userfaultfd_wait_queue uwq;
203         int ret;
204
205         BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
206
207         ret = VM_FAULT_SIGBUS;
208         ctx = vma->vm_userfaultfd_ctx.ctx;
209         if (!ctx)
210                 goto out;
211
212         BUG_ON(ctx->mm != mm);
213
214         VM_BUG_ON(reason & ~(VM_UFFD_MISSING|VM_UFFD_WP));
215         VM_BUG_ON(!(reason & VM_UFFD_MISSING) ^ !!(reason & VM_UFFD_WP));
216
217         /*
218          * If it's already released don't get it. This avoids to loop
219          * in __get_user_pages if userfaultfd_release waits on the
220          * caller of handle_userfault to release the mmap_sem.
221          */
222         if (unlikely(ACCESS_ONCE(ctx->released)))
223                 goto out;
224
225         /*
226          * Check that we can return VM_FAULT_RETRY.
227          *
228          * NOTE: it should become possible to return VM_FAULT_RETRY
229          * even if FAULT_FLAG_TRIED is set without leading to gup()
230          * -EBUSY failures, if the userfaultfd is to be extended for
231          * VM_UFFD_WP tracking and we intend to arm the userfault
232          * without first stopping userland access to the memory. For
233          * VM_UFFD_MISSING userfaults this is enough for now.
234          */
235         if (unlikely(!(flags & FAULT_FLAG_ALLOW_RETRY))) {
236                 /*
237                  * Validate the invariant that nowait must allow retry
238                  * to be sure not to return SIGBUS erroneously on
239                  * nowait invocations.
240                  */
241                 BUG_ON(flags & FAULT_FLAG_RETRY_NOWAIT);
242 #ifdef CONFIG_DEBUG_VM
243                 if (printk_ratelimit()) {
244                         printk(KERN_WARNING
245                                "FAULT_FLAG_ALLOW_RETRY missing %x\n", flags);
246                         dump_stack();
247                 }
248 #endif
249                 goto out;
250         }
251
252         /*
253          * Handle nowait, not much to do other than tell it to retry
254          * and wait.
255          */
256         ret = VM_FAULT_RETRY;
257         if (flags & FAULT_FLAG_RETRY_NOWAIT)
258                 goto out;
259
260         /* take the reference before dropping the mmap_sem */
261         userfaultfd_ctx_get(ctx);
262
263         /* be gentle and immediately relinquish the mmap_sem */
264         up_read(&mm->mmap_sem);
265
266         init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function);
267         uwq.wq.private = current;
268         uwq.msg = userfault_msg(address, flags, reason);
269         uwq.ctx = ctx;
270
271         spin_lock(&ctx->fault_pending_wqh.lock);
272         /*
273          * After the __add_wait_queue the uwq is visible to userland
274          * through poll/read().
275          */
276         __add_wait_queue(&ctx->fault_pending_wqh, &uwq.wq);
277         /*
278          * The smp_mb() after __set_current_state prevents the reads
279          * following the spin_unlock to happen before the list_add in
280          * __add_wait_queue.
281          */
282         set_current_state(TASK_KILLABLE);
283         spin_unlock(&ctx->fault_pending_wqh.lock);
284
285         if (likely(!ACCESS_ONCE(ctx->released) &&
286                    !fatal_signal_pending(current))) {
287                 wake_up_poll(&ctx->fd_wqh, POLLIN);
288                 schedule();
289                 ret |= VM_FAULT_MAJOR;
290         }
291
292         __set_current_state(TASK_RUNNING);
293
294         /*
295          * Here we race with the list_del; list_add in
296          * userfaultfd_ctx_read(), however because we don't ever run
297          * list_del_init() to refile across the two lists, the prev
298          * and next pointers will never point to self. list_add also
299          * would never let any of the two pointers to point to
300          * self. So list_empty_careful won't risk to see both pointers
301          * pointing to self at any time during the list refile. The
302          * only case where list_del_init() is called is the full
303          * removal in the wake function and there we don't re-list_add
304          * and it's fine not to block on the spinlock. The uwq on this
305          * kernel stack can be released after the list_del_init.
306          */
307         if (!list_empty_careful(&uwq.wq.task_list)) {
308                 spin_lock(&ctx->fault_pending_wqh.lock);
309                 /*
310                  * No need of list_del_init(), the uwq on the stack
311                  * will be freed shortly anyway.
312                  */
313                 list_del(&uwq.wq.task_list);
314                 spin_unlock(&ctx->fault_pending_wqh.lock);
315         }
316
317         /*
318          * ctx may go away after this if the userfault pseudo fd is
319          * already released.
320          */
321         userfaultfd_ctx_put(ctx);
322
323 out:
324         return ret;
325 }
326
327 static int userfaultfd_release(struct inode *inode, struct file *file)
328 {
329         struct userfaultfd_ctx *ctx = file->private_data;
330         struct mm_struct *mm = ctx->mm;
331         struct vm_area_struct *vma, *prev;
332         /* len == 0 means wake all */
333         struct userfaultfd_wake_range range = { .len = 0, };
334         unsigned long new_flags;
335
336         ACCESS_ONCE(ctx->released) = true;
337
338         /*
339          * Flush page faults out of all CPUs. NOTE: all page faults
340          * must be retried without returning VM_FAULT_SIGBUS if
341          * userfaultfd_ctx_get() succeeds but vma->vma_userfault_ctx
342          * changes while handle_userfault released the mmap_sem. So
343          * it's critical that released is set to true (above), before
344          * taking the mmap_sem for writing.
345          */
346         down_write(&mm->mmap_sem);
347         prev = NULL;
348         for (vma = mm->mmap; vma; vma = vma->vm_next) {
349                 cond_resched();
350                 BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
351                        !!(vma->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
352                 if (vma->vm_userfaultfd_ctx.ctx != ctx) {
353                         prev = vma;
354                         continue;
355                 }
356                 new_flags = vma->vm_flags & ~(VM_UFFD_MISSING | VM_UFFD_WP);
357                 prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
358                                  new_flags, vma->anon_vma,
359                                  vma->vm_file, vma->vm_pgoff,
360                                  vma_policy(vma),
361                                  NULL_VM_UFFD_CTX);
362                 if (prev)
363                         vma = prev;
364                 else
365                         prev = vma;
366                 vma->vm_flags = new_flags;
367                 vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
368         }
369         up_write(&mm->mmap_sem);
370
371         /*
372          * After no new page faults can wait on this fault_*wqh, flush
373          * the last page faults that may have been already waiting on
374          * the fault_*wqh.
375          */
376         spin_lock(&ctx->fault_pending_wqh.lock);
377         __wake_up_locked_key(&ctx->fault_pending_wqh, TASK_NORMAL, 0, &range);
378         __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, &range);
379         spin_unlock(&ctx->fault_pending_wqh.lock);
380
381         wake_up_poll(&ctx->fd_wqh, POLLHUP);
382         userfaultfd_ctx_put(ctx);
383         return 0;
384 }
385
386 /* fault_pending_wqh.lock must be hold by the caller */
387 static inline struct userfaultfd_wait_queue *find_userfault(
388         struct userfaultfd_ctx *ctx)
389 {
390         wait_queue_t *wq;
391         struct userfaultfd_wait_queue *uwq;
392
393         VM_BUG_ON(!spin_is_locked(&ctx->fault_pending_wqh.lock));
394
395         uwq = NULL;
396         if (!waitqueue_active(&ctx->fault_pending_wqh))
397                 goto out;
398         /* walk in reverse to provide FIFO behavior to read userfaults */
399         wq = list_last_entry(&ctx->fault_pending_wqh.task_list,
400                              typeof(*wq), task_list);
401         uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
402 out:
403         return uwq;
404 }
405
406 static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
407 {
408         struct userfaultfd_ctx *ctx = file->private_data;
409         unsigned int ret;
410
411         poll_wait(file, &ctx->fd_wqh, wait);
412
413         switch (ctx->state) {
414         case UFFD_STATE_WAIT_API:
415                 return POLLERR;
416         case UFFD_STATE_RUNNING:
417                 /*
418                  * poll() never guarantees that read won't block.
419                  * userfaults can be waken before they're read().
420                  */
421                 if (unlikely(!(file->f_flags & O_NONBLOCK)))
422                         return POLLERR;
423                 /*
424                  * lockless access to see if there are pending faults
425                  * __pollwait last action is the add_wait_queue but
426                  * the spin_unlock would allow the waitqueue_active to
427                  * pass above the actual list_add inside
428                  * add_wait_queue critical section. So use a full
429                  * memory barrier to serialize the list_add write of
430                  * add_wait_queue() with the waitqueue_active read
431                  * below.
432                  */
433                 ret = 0;
434                 smp_mb();
435                 if (waitqueue_active(&ctx->fault_pending_wqh))
436                         ret = POLLIN;
437                 return ret;
438         default:
439                 BUG();
440         }
441 }
442
443 static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
444                                     struct uffd_msg *msg)
445 {
446         ssize_t ret;
447         DECLARE_WAITQUEUE(wait, current);
448         struct userfaultfd_wait_queue *uwq;
449
450         /* always take the fd_wqh lock before the fault_pending_wqh lock */
451         spin_lock(&ctx->fd_wqh.lock);
452         __add_wait_queue(&ctx->fd_wqh, &wait);
453         for (;;) {
454                 set_current_state(TASK_INTERRUPTIBLE);
455                 spin_lock(&ctx->fault_pending_wqh.lock);
456                 uwq = find_userfault(ctx);
457                 if (uwq) {
458                         /*
459                          * The fault_pending_wqh.lock prevents the uwq
460                          * to disappear from under us.
461                          *
462                          * Refile this userfault from
463                          * fault_pending_wqh to fault_wqh, it's not
464                          * pending anymore after we read it.
465                          *
466                          * Use list_del() by hand (as
467                          * userfaultfd_wake_function also uses
468                          * list_del_init() by hand) to be sure nobody
469                          * changes __remove_wait_queue() to use
470                          * list_del_init() in turn breaking the
471                          * !list_empty_careful() check in
472                          * handle_userfault(). The uwq->wq.task_list
473                          * must never be empty at any time during the
474                          * refile, or the waitqueue could disappear
475                          * from under us. The "wait_queue_head_t"
476                          * parameter of __remove_wait_queue() is unused
477                          * anyway.
478                          */
479                         list_del(&uwq->wq.task_list);
480                         __add_wait_queue(&ctx->fault_wqh, &uwq->wq);
481
482                         /* careful to always initialize msg if ret == 0 */
483                         *msg = uwq->msg;
484                         spin_unlock(&ctx->fault_pending_wqh.lock);
485                         ret = 0;
486                         break;
487                 }
488                 spin_unlock(&ctx->fault_pending_wqh.lock);
489                 if (signal_pending(current)) {
490                         ret = -ERESTARTSYS;
491                         break;
492                 }
493                 if (no_wait) {
494                         ret = -EAGAIN;
495                         break;
496                 }
497                 spin_unlock(&ctx->fd_wqh.lock);
498                 schedule();
499                 spin_lock(&ctx->fd_wqh.lock);
500         }
501         __remove_wait_queue(&ctx->fd_wqh, &wait);
502         __set_current_state(TASK_RUNNING);
503         spin_unlock(&ctx->fd_wqh.lock);
504
505         return ret;
506 }
507
508 static ssize_t userfaultfd_read(struct file *file, char __user *buf,
509                                 size_t count, loff_t *ppos)
510 {
511         struct userfaultfd_ctx *ctx = file->private_data;
512         ssize_t _ret, ret = 0;
513         struct uffd_msg msg;
514         int no_wait = file->f_flags & O_NONBLOCK;
515
516         if (ctx->state == UFFD_STATE_WAIT_API)
517                 return -EINVAL;
518         BUG_ON(ctx->state != UFFD_STATE_RUNNING);
519
520         for (;;) {
521                 if (count < sizeof(msg))
522                         return ret ? ret : -EINVAL;
523                 _ret = userfaultfd_ctx_read(ctx, no_wait, &msg);
524                 if (_ret < 0)
525                         return ret ? ret : _ret;
526                 if (copy_to_user((__u64 __user *) buf, &msg, sizeof(msg)))
527                         return ret ? ret : -EFAULT;
528                 ret += sizeof(msg);
529                 buf += sizeof(msg);
530                 count -= sizeof(msg);
531                 /*
532                  * Allow to read more than one fault at time but only
533                  * block if waiting for the very first one.
534                  */
535                 no_wait = O_NONBLOCK;
536         }
537 }
538
539 static void __wake_userfault(struct userfaultfd_ctx *ctx,
540                              struct userfaultfd_wake_range *range)
541 {
542         unsigned long start, end;
543
544         start = range->start;
545         end = range->start + range->len;
546
547         spin_lock(&ctx->fault_pending_wqh.lock);
548         /* wake all in the range and autoremove */
549         if (waitqueue_active(&ctx->fault_pending_wqh))
550                 __wake_up_locked_key(&ctx->fault_pending_wqh, TASK_NORMAL, 0,
551                                      range);
552         if (waitqueue_active(&ctx->fault_wqh))
553                 __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, range);
554         spin_unlock(&ctx->fault_pending_wqh.lock);
555 }
556
557 static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
558                                            struct userfaultfd_wake_range *range)
559 {
560         /*
561          * To be sure waitqueue_active() is not reordered by the CPU
562          * before the pagetable update, use an explicit SMP memory
563          * barrier here. PT lock release or up_read(mmap_sem) still
564          * have release semantics that can allow the
565          * waitqueue_active() to be reordered before the pte update.
566          */
567         smp_mb();
568
569         /*
570          * Use waitqueue_active because it's very frequent to
571          * change the address space atomically even if there are no
572          * userfaults yet. So we take the spinlock only when we're
573          * sure we've userfaults to wake.
574          */
575         if (waitqueue_active(&ctx->fault_pending_wqh) ||
576             waitqueue_active(&ctx->fault_wqh))
577                 __wake_userfault(ctx, range);
578 }
579
580 static __always_inline int validate_range(struct mm_struct *mm,
581                                           __u64 start, __u64 len)
582 {
583         __u64 task_size = mm->task_size;
584
585         if (start & ~PAGE_MASK)
586                 return -EINVAL;
587         if (len & ~PAGE_MASK)
588                 return -EINVAL;
589         if (!len)
590                 return -EINVAL;
591         if (start < mmap_min_addr)
592                 return -EINVAL;
593         if (start >= task_size)
594                 return -EINVAL;
595         if (len > task_size - start)
596                 return -EINVAL;
597         return 0;
598 }
599
600 static int userfaultfd_register(struct userfaultfd_ctx *ctx,
601                                 unsigned long arg)
602 {
603         struct mm_struct *mm = ctx->mm;
604         struct vm_area_struct *vma, *prev, *cur;
605         int ret;
606         struct uffdio_register uffdio_register;
607         struct uffdio_register __user *user_uffdio_register;
608         unsigned long vm_flags, new_flags;
609         bool found;
610         unsigned long start, end, vma_end;
611
612         user_uffdio_register = (struct uffdio_register __user *) arg;
613
614         ret = -EFAULT;
615         if (copy_from_user(&uffdio_register, user_uffdio_register,
616                            sizeof(uffdio_register)-sizeof(__u64)))
617                 goto out;
618
619         ret = -EINVAL;
620         if (!uffdio_register.mode)
621                 goto out;
622         if (uffdio_register.mode & ~(UFFDIO_REGISTER_MODE_MISSING|
623                                      UFFDIO_REGISTER_MODE_WP))
624                 goto out;
625         vm_flags = 0;
626         if (uffdio_register.mode & UFFDIO_REGISTER_MODE_MISSING)
627                 vm_flags |= VM_UFFD_MISSING;
628         if (uffdio_register.mode & UFFDIO_REGISTER_MODE_WP) {
629                 vm_flags |= VM_UFFD_WP;
630                 /*
631                  * FIXME: remove the below error constraint by
632                  * implementing the wprotect tracking mode.
633                  */
634                 ret = -EINVAL;
635                 goto out;
636         }
637
638         ret = validate_range(mm, uffdio_register.range.start,
639                              uffdio_register.range.len);
640         if (ret)
641                 goto out;
642
643         start = uffdio_register.range.start;
644         end = start + uffdio_register.range.len;
645
646         down_write(&mm->mmap_sem);
647         vma = find_vma_prev(mm, start, &prev);
648
649         ret = -ENOMEM;
650         if (!vma)
651                 goto out_unlock;
652
653         /* check that there's at least one vma in the range */
654         ret = -EINVAL;
655         if (vma->vm_start >= end)
656                 goto out_unlock;
657
658         /*
659          * Search for not compatible vmas.
660          *
661          * FIXME: this shall be relaxed later so that it doesn't fail
662          * on tmpfs backed vmas (in addition to the current allowance
663          * on anonymous vmas).
664          */
665         found = false;
666         for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
667                 cond_resched();
668
669                 BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
670                        !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
671
672                 /* check not compatible vmas */
673                 ret = -EINVAL;
674                 if (cur->vm_ops)
675                         goto out_unlock;
676
677                 /*
678                  * Check that this vma isn't already owned by a
679                  * different userfaultfd. We can't allow more than one
680                  * userfaultfd to own a single vma simultaneously or we
681                  * wouldn't know which one to deliver the userfaults to.
682                  */
683                 ret = -EBUSY;
684                 if (cur->vm_userfaultfd_ctx.ctx &&
685                     cur->vm_userfaultfd_ctx.ctx != ctx)
686                         goto out_unlock;
687
688                 found = true;
689         }
690         BUG_ON(!found);
691
692         if (vma->vm_start < start)
693                 prev = vma;
694
695         ret = 0;
696         do {
697                 cond_resched();
698
699                 BUG_ON(vma->vm_ops);
700                 BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
701                        vma->vm_userfaultfd_ctx.ctx != ctx);
702
703                 /*
704                  * Nothing to do: this vma is already registered into this
705                  * userfaultfd and with the right tracking mode too.
706                  */
707                 if (vma->vm_userfaultfd_ctx.ctx == ctx &&
708                     (vma->vm_flags & vm_flags) == vm_flags)
709                         goto skip;
710
711                 if (vma->vm_start > start)
712                         start = vma->vm_start;
713                 vma_end = min(end, vma->vm_end);
714
715                 new_flags = (vma->vm_flags & ~vm_flags) | vm_flags;
716                 prev = vma_merge(mm, prev, start, vma_end, new_flags,
717                                  vma->anon_vma, vma->vm_file, vma->vm_pgoff,
718                                  vma_policy(vma),
719                                  ((struct vm_userfaultfd_ctx){ ctx }));
720                 if (prev) {
721                         vma = prev;
722                         goto next;
723                 }
724                 if (vma->vm_start < start) {
725                         ret = split_vma(mm, vma, start, 1);
726                         if (ret)
727                                 break;
728                 }
729                 if (vma->vm_end > end) {
730                         ret = split_vma(mm, vma, end, 0);
731                         if (ret)
732                                 break;
733                 }
734         next:
735                 /*
736                  * In the vma_merge() successful mprotect-like case 8:
737                  * the next vma was merged into the current one and
738                  * the current one has not been updated yet.
739                  */
740                 vma->vm_flags = new_flags;
741                 vma->vm_userfaultfd_ctx.ctx = ctx;
742
743         skip:
744                 prev = vma;
745                 start = vma->vm_end;
746                 vma = vma->vm_next;
747         } while (vma && vma->vm_start < end);
748 out_unlock:
749         up_write(&mm->mmap_sem);
750         if (!ret) {
751                 /*
752                  * Now that we scanned all vmas we can already tell
753                  * userland which ioctls methods are guaranteed to
754                  * succeed on this range.
755                  */
756                 if (put_user(UFFD_API_RANGE_IOCTLS,
757                              &user_uffdio_register->ioctls))
758                         ret = -EFAULT;
759         }
760 out:
761         return ret;
762 }
763
764 static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
765                                   unsigned long arg)
766 {
767         struct mm_struct *mm = ctx->mm;
768         struct vm_area_struct *vma, *prev, *cur;
769         int ret;
770         struct uffdio_range uffdio_unregister;
771         unsigned long new_flags;
772         bool found;
773         unsigned long start, end, vma_end;
774         const void __user *buf = (void __user *)arg;
775
776         ret = -EFAULT;
777         if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
778                 goto out;
779
780         ret = validate_range(mm, uffdio_unregister.start,
781                              uffdio_unregister.len);
782         if (ret)
783                 goto out;
784
785         start = uffdio_unregister.start;
786         end = start + uffdio_unregister.len;
787
788         down_write(&mm->mmap_sem);
789         vma = find_vma_prev(mm, start, &prev);
790
791         ret = -ENOMEM;
792         if (!vma)
793                 goto out_unlock;
794
795         /* check that there's at least one vma in the range */
796         ret = -EINVAL;
797         if (vma->vm_start >= end)
798                 goto out_unlock;
799
800         /*
801          * Search for not compatible vmas.
802          *
803          * FIXME: this shall be relaxed later so that it doesn't fail
804          * on tmpfs backed vmas (in addition to the current allowance
805          * on anonymous vmas).
806          */
807         found = false;
808         ret = -EINVAL;
809         for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
810                 cond_resched();
811
812                 BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
813                        !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
814
815                 /*
816                  * Check not compatible vmas, not strictly required
817                  * here as not compatible vmas cannot have an
818                  * userfaultfd_ctx registered on them, but this
819                  * provides for more strict behavior to notice
820                  * unregistration errors.
821                  */
822                 if (cur->vm_ops)
823                         goto out_unlock;
824
825                 found = true;
826         }
827         BUG_ON(!found);
828
829         if (vma->vm_start < start)
830                 prev = vma;
831
832         ret = 0;
833         do {
834                 cond_resched();
835
836                 BUG_ON(vma->vm_ops);
837
838                 /*
839                  * Nothing to do: this vma is already registered into this
840                  * userfaultfd and with the right tracking mode too.
841                  */
842                 if (!vma->vm_userfaultfd_ctx.ctx)
843                         goto skip;
844
845                 if (vma->vm_start > start)
846                         start = vma->vm_start;
847                 vma_end = min(end, vma->vm_end);
848
849                 new_flags = vma->vm_flags & ~(VM_UFFD_MISSING | VM_UFFD_WP);
850                 prev = vma_merge(mm, prev, start, vma_end, new_flags,
851                                  vma->anon_vma, vma->vm_file, vma->vm_pgoff,
852                                  vma_policy(vma),
853                                  NULL_VM_UFFD_CTX);
854                 if (prev) {
855                         vma = prev;
856                         goto next;
857                 }
858                 if (vma->vm_start < start) {
859                         ret = split_vma(mm, vma, start, 1);
860                         if (ret)
861                                 break;
862                 }
863                 if (vma->vm_end > end) {
864                         ret = split_vma(mm, vma, end, 0);
865                         if (ret)
866                                 break;
867                 }
868         next:
869                 /*
870                  * In the vma_merge() successful mprotect-like case 8:
871                  * the next vma was merged into the current one and
872                  * the current one has not been updated yet.
873                  */
874                 vma->vm_flags = new_flags;
875                 vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
876
877         skip:
878                 prev = vma;
879                 start = vma->vm_end;
880                 vma = vma->vm_next;
881         } while (vma && vma->vm_start < end);
882 out_unlock:
883         up_write(&mm->mmap_sem);
884 out:
885         return ret;
886 }
887
888 /*
889  * userfaultfd_wake is needed in case an userfault is in flight by the
890  * time a UFFDIO_COPY (or other ioctl variants) completes. The page
891  * may be well get mapped and the page fault if repeated wouldn't lead
892  * to a userfault anymore, but before scheduling in TASK_KILLABLE mode
893  * handle_userfault() doesn't recheck the pagetables and it doesn't
894  * serialize against UFFDO_COPY (or other ioctl variants). Ultimately
895  * the knowledge of which pages are mapped is left to userland who is
896  * responsible for handling the race between read() userfaults and
897  * background UFFDIO_COPY (or other ioctl variants), if done by
898  * separate concurrent threads.
899  *
900  * userfaultfd_wake may be used in combination with the
901  * UFFDIO_*_MODE_DONTWAKE to wakeup userfaults in batches.
902  */
903 static int userfaultfd_wake(struct userfaultfd_ctx *ctx,
904                             unsigned long arg)
905 {
906         int ret;
907         struct uffdio_range uffdio_wake;
908         struct userfaultfd_wake_range range;
909         const void __user *buf = (void __user *)arg;
910
911         ret = -EFAULT;
912         if (copy_from_user(&uffdio_wake, buf, sizeof(uffdio_wake)))
913                 goto out;
914
915         ret = validate_range(ctx->mm, uffdio_wake.start, uffdio_wake.len);
916         if (ret)
917                 goto out;
918
919         range.start = uffdio_wake.start;
920         range.len = uffdio_wake.len;
921
922         /*
923          * len == 0 means wake all and we don't want to wake all here,
924          * so check it again to be sure.
925          */
926         VM_BUG_ON(!range.len);
927
928         wake_userfault(ctx, &range);
929         ret = 0;
930
931 out:
932         return ret;
933 }
934
935 /*
936  * userland asks for a certain API version and we return which bits
937  * and ioctl commands are implemented in this kernel for such API
938  * version or -EINVAL if unknown.
939  */
940 static int userfaultfd_api(struct userfaultfd_ctx *ctx,
941                            unsigned long arg)
942 {
943         struct uffdio_api uffdio_api;
944         void __user *buf = (void __user *)arg;
945         int ret;
946
947         ret = -EINVAL;
948         if (ctx->state != UFFD_STATE_WAIT_API)
949                 goto out;
950         ret = -EFAULT;
951         if (copy_from_user(&uffdio_api, buf, sizeof(uffdio_api)))
952                 goto out;
953         if (uffdio_api.api != UFFD_API || uffdio_api.features) {
954                 memset(&uffdio_api, 0, sizeof(uffdio_api));
955                 if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
956                         goto out;
957                 ret = -EINVAL;
958                 goto out;
959         }
960         uffdio_api.features = UFFD_API_FEATURES;
961         uffdio_api.ioctls = UFFD_API_IOCTLS;
962         ret = -EFAULT;
963         if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
964                 goto out;
965         ctx->state = UFFD_STATE_RUNNING;
966         ret = 0;
967 out:
968         return ret;
969 }
970
971 static long userfaultfd_ioctl(struct file *file, unsigned cmd,
972                               unsigned long arg)
973 {
974         int ret = -EINVAL;
975         struct userfaultfd_ctx *ctx = file->private_data;
976
977         switch(cmd) {
978         case UFFDIO_API:
979                 ret = userfaultfd_api(ctx, arg);
980                 break;
981         case UFFDIO_REGISTER:
982                 ret = userfaultfd_register(ctx, arg);
983                 break;
984         case UFFDIO_UNREGISTER:
985                 ret = userfaultfd_unregister(ctx, arg);
986                 break;
987         case UFFDIO_WAKE:
988                 ret = userfaultfd_wake(ctx, arg);
989                 break;
990         }
991         return ret;
992 }
993
994 #ifdef CONFIG_PROC_FS
995 static void userfaultfd_show_fdinfo(struct seq_file *m, struct file *f)
996 {
997         struct userfaultfd_ctx *ctx = f->private_data;
998         wait_queue_t *wq;
999         struct userfaultfd_wait_queue *uwq;
1000         unsigned long pending = 0, total = 0;
1001
1002         spin_lock(&ctx->fault_pending_wqh.lock);
1003         list_for_each_entry(wq, &ctx->fault_pending_wqh.task_list, task_list) {
1004                 uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
1005                 pending++;
1006                 total++;
1007         }
1008         list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
1009                 uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
1010                 total++;
1011         }
1012         spin_unlock(&ctx->fault_pending_wqh.lock);
1013
1014         /*
1015          * If more protocols will be added, there will be all shown
1016          * separated by a space. Like this:
1017          *      protocols: aa:... bb:...
1018          */
1019         seq_printf(m, "pending:\t%lu\ntotal:\t%lu\nAPI:\t%Lx:%x:%Lx\n",
1020                    pending, total, UFFD_API, UFFD_API_FEATURES,
1021                    UFFD_API_IOCTLS|UFFD_API_RANGE_IOCTLS);
1022 }
1023 #endif
1024
1025 static const struct file_operations userfaultfd_fops = {
1026 #ifdef CONFIG_PROC_FS
1027         .show_fdinfo    = userfaultfd_show_fdinfo,
1028 #endif
1029         .release        = userfaultfd_release,
1030         .poll           = userfaultfd_poll,
1031         .read           = userfaultfd_read,
1032         .unlocked_ioctl = userfaultfd_ioctl,
1033         .compat_ioctl   = userfaultfd_ioctl,
1034         .llseek         = noop_llseek,
1035 };
1036
1037 static void init_once_userfaultfd_ctx(void *mem)
1038 {
1039         struct userfaultfd_ctx *ctx = (struct userfaultfd_ctx *) mem;
1040
1041         init_waitqueue_head(&ctx->fault_pending_wqh);
1042         init_waitqueue_head(&ctx->fault_wqh);
1043         init_waitqueue_head(&ctx->fd_wqh);
1044 }
1045
1046 /**
1047  * userfaultfd_file_create - Creates an userfaultfd file pointer.
1048  * @flags: Flags for the userfaultfd file.
1049  *
1050  * This function creates an userfaultfd file pointer, w/out installing
1051  * it into the fd table. This is useful when the userfaultfd file is
1052  * used during the initialization of data structures that require
1053  * extra setup after the userfaultfd creation. So the userfaultfd
1054  * creation is split into the file pointer creation phase, and the
1055  * file descriptor installation phase.  In this way races with
1056  * userspace closing the newly installed file descriptor can be
1057  * avoided.  Returns an userfaultfd file pointer, or a proper error
1058  * pointer.
1059  */
1060 static struct file *userfaultfd_file_create(int flags)
1061 {
1062         struct file *file;
1063         struct userfaultfd_ctx *ctx;
1064
1065         BUG_ON(!current->mm);
1066
1067         /* Check the UFFD_* constants for consistency.  */
1068         BUILD_BUG_ON(UFFD_CLOEXEC != O_CLOEXEC);
1069         BUILD_BUG_ON(UFFD_NONBLOCK != O_NONBLOCK);
1070
1071         file = ERR_PTR(-EINVAL);
1072         if (flags & ~UFFD_SHARED_FCNTL_FLAGS)
1073                 goto out;
1074
1075         file = ERR_PTR(-ENOMEM);
1076         ctx = kmem_cache_alloc(userfaultfd_ctx_cachep, GFP_KERNEL);
1077         if (!ctx)
1078                 goto out;
1079
1080         atomic_set(&ctx->refcount, 1);
1081         ctx->flags = flags;
1082         ctx->state = UFFD_STATE_WAIT_API;
1083         ctx->released = false;
1084         ctx->mm = current->mm;
1085         /* prevent the mm struct to be freed */
1086         atomic_inc(&ctx->mm->mm_users);
1087
1088         file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
1089                                   O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS));
1090         if (IS_ERR(file))
1091                 kmem_cache_free(userfaultfd_ctx_cachep, ctx);
1092 out:
1093         return file;
1094 }
1095
1096 SYSCALL_DEFINE1(userfaultfd, int, flags)
1097 {
1098         int fd, error;
1099         struct file *file;
1100
1101         error = get_unused_fd_flags(flags & UFFD_SHARED_FCNTL_FLAGS);
1102         if (error < 0)
1103                 return error;
1104         fd = error;
1105
1106         file = userfaultfd_file_create(flags);
1107         if (IS_ERR(file)) {
1108                 error = PTR_ERR(file);
1109                 goto err_put_unused_fd;
1110         }
1111         fd_install(fd, file);
1112
1113         return fd;
1114
1115 err_put_unused_fd:
1116         put_unused_fd(fd);
1117
1118         return error;
1119 }
1120
1121 static int __init userfaultfd_init(void)
1122 {
1123         userfaultfd_ctx_cachep = kmem_cache_create("userfaultfd_ctx_cache",
1124                                                 sizeof(struct userfaultfd_ctx),
1125                                                 0,
1126                                                 SLAB_HWCACHE_ALIGN|SLAB_PANIC,
1127                                                 init_once_userfaultfd_ctx);
1128         return 0;
1129 }
1130 __initcall(userfaultfd_init);