diff --git a/mm/memcontrol.c b/mm/memcontrol.c index c3d98ab41f1f..c03d4787d466 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -805,12 +805,17 @@ static long memcg_state_val_in_pages(int idx, long val) * Used in mod_memcg_state() and mod_memcg_lruvec_state() to avoid race with * reparenting of non-hierarchical state_locals. */ -static inline struct mem_cgroup *get_non_dying_memcg_start(struct mem_cgroup *memcg) +static inline struct mem_cgroup *get_non_dying_memcg_start(struct mem_cgroup *memcg, + bool *rcu_locked) { - if (cgroup_subsys_on_dfl(memory_cgrp_subsys)) + /* Rebinding can cause this value to be changed at runtime */ + if (cgroup_subsys_on_dfl(memory_cgrp_subsys)) { + *rcu_locked = false; return memcg; + } rcu_read_lock(); + *rcu_locked = true; while (memcg_is_dying(memcg)) memcg = parent_mem_cgroup(memcg); @@ -818,20 +823,21 @@ static inline struct mem_cgroup *get_non_dying_memcg_start(struct mem_cgroup *me return memcg; } -static inline void get_non_dying_memcg_end(void) +static inline void get_non_dying_memcg_end(bool rcu_locked) { - if (cgroup_subsys_on_dfl(memory_cgrp_subsys)) + if (!rcu_locked) return; rcu_read_unlock(); } #else -static inline struct mem_cgroup *get_non_dying_memcg_start(struct mem_cgroup *memcg) +static inline struct mem_cgroup *get_non_dying_memcg_start(struct mem_cgroup *memcg, + bool *rcu_locked) { return memcg; } -static inline void get_non_dying_memcg_end(void) +static inline void get_non_dying_memcg_end(bool rcu_locked) { } #endif @@ -865,12 +871,14 @@ static void __mod_memcg_state(struct mem_cgroup *memcg, void mod_memcg_state(struct mem_cgroup *memcg, enum memcg_stat_item idx, int val) { + bool rcu_locked = false; + if (mem_cgroup_disabled()) return; - memcg = get_non_dying_memcg_start(memcg); + memcg = get_non_dying_memcg_start(memcg, &rcu_locked); __mod_memcg_state(memcg, idx, val); - get_non_dying_memcg_end(); + get_non_dying_memcg_end(rcu_locked); } #ifdef CONFIG_MEMCG_V1 @@ -933,14 +941,15 @@ static void mod_memcg_lruvec_state(struct lruvec *lruvec, struct pglist_data *pgdat = lruvec_pgdat(lruvec); struct mem_cgroup_per_node *pn; struct mem_cgroup *memcg; + bool rcu_locked = false; pn = container_of(lruvec, struct mem_cgroup_per_node, lruvec); - memcg = get_non_dying_memcg_start(pn->memcg); + memcg = get_non_dying_memcg_start(pn->memcg, &rcu_locked); pn = memcg->nodeinfo[pgdat->node_id]; __mod_memcg_lruvec_state(pn, idx, val); - get_non_dying_memcg_end(); + get_non_dying_memcg_end(rcu_locked); } /**