oom_kill: change oom_kill.c to use for_each_thread()
authorOleg Nesterov <oleg@redhat.com>
Tue, 21 Jan 2014 23:49:58 +0000 (15:49 -0800)
committerLinus Torvalds <torvalds@linux-foundation.org>
Wed, 22 Jan 2014 00:19:46 +0000 (16:19 -0800)
Change oom_kill.c to use for_each_thread() rather than the racy
while_each_thread() which can loop forever if we race with exit.

Note also that most users were buggy even if while_each_thread() was
fine, the task can exit even _before_ rcu_read_lock().

Fortunately the new for_each_thread() only requires the stable
task_struct, so this change fixes both problems.

Signed-off-by: Oleg Nesterov <oleg@redhat.com>
Reviewed-by: Sergey Dyasly <dserrg@gmail.com>
Tested-by: Sergey Dyasly <dserrg@gmail.com>
Reviewed-by: Sameer Nanda <snanda@chromium.org>
Cc: "Eric W. Biederman" <ebiederm@xmission.com>
Cc: Frederic Weisbecker <fweisbec@gmail.com>
Cc: Mandeep Singh Baines <msb@chromium.org>
Cc: "Ma, Xindong" <xindong.ma@intel.com>
Reviewed-by: Michal Hocko <mhocko@suse.cz>
Cc: "Tu, Xiaobing" <xiaobing.tu@intel.com>
Acked-by: David Rientjes <rientjes@google.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
mm/oom_kill.c

index 1e4a600a6163645897a42defaf21f437fdf431a6..96d7945f75a6ab67292eaf7e83e28e936bd0e307 100644 (file)
@@ -59,7 +59,7 @@ static bool has_intersects_mems_allowed(struct task_struct *tsk,
 {
        struct task_struct *start = tsk;
 
-       do {
+       for_each_thread(start, tsk) {
                if (mask) {
                        /*
                         * If this is a mempolicy constrained oom, tsk's
@@ -77,7 +77,7 @@ static bool has_intersects_mems_allowed(struct task_struct *tsk,
                        if (cpuset_mems_allowed_intersects(current, tsk))
                                return true;
                }
-       } while_each_thread(start, tsk);
+       }
 
        return false;
 }
@@ -97,14 +97,14 @@ static bool has_intersects_mems_allowed(struct task_struct *tsk,
  */
 struct task_struct *find_lock_task_mm(struct task_struct *p)
 {
-       struct task_struct *t = p;
+       struct task_struct *t;
 
-       do {
+       for_each_thread(p, t) {
                task_lock(t);
                if (likely(t->mm))
                        return t;
                task_unlock(t);
-       } while_each_thread(p, t);
+       }
 
        return NULL;
 }
@@ -301,7 +301,7 @@ static struct task_struct *select_bad_process(unsigned int *ppoints,
        unsigned long chosen_points = 0;
 
        rcu_read_lock();
-       do_each_thread(g, p) {
+       for_each_process_thread(g, p) {
                unsigned int points;
 
                switch (oom_scan_process_thread(p, totalpages, nodemask,
@@ -323,7 +323,7 @@ static struct task_struct *select_bad_process(unsigned int *ppoints,
                        chosen = p;
                        chosen_points = points;
                }
-       } while_each_thread(g, p);
+       }
        if (chosen)
                get_task_struct(chosen);
        rcu_read_unlock();
@@ -406,7 +406,7 @@ void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
 {
        struct task_struct *victim = p;
        struct task_struct *child;
-       struct task_struct *t = p;
+       struct task_struct *t;
        struct mm_struct *mm;
        unsigned int victim_points = 0;
        static DEFINE_RATELIMIT_STATE(oom_rs, DEFAULT_RATELIMIT_INTERVAL,
@@ -437,7 +437,7 @@ void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
         * still freeing memory.
         */
        read_lock(&tasklist_lock);
-       do {
+       for_each_thread(p, t) {
                list_for_each_entry(child, &t->children, sibling) {
                        unsigned int child_points;
 
@@ -455,7 +455,7 @@ void oom_kill_process(struct task_struct *p, gfp_t gfp_mask, int order,
                                get_task_struct(victim);
                        }
                }
-       } while_each_thread(p, t);
+       }
        read_unlock(&tasklist_lock);
 
        rcu_read_lock();