diff --git a/include/linux/bpf-cgroup.h b/include/linux/bpf-cgroup.h
index 8b77d08d4b47f333d8d9be9c14706dffd0267fca..a74cd1c3bd8713f13c9e4c016ad0711a4c88b12e 100644
--- a/include/linux/bpf-cgroup.h
+++ b/include/linux/bpf-cgroup.h
@@ -27,19 +27,6 @@ struct task_struct;
 extern struct static_key_false cgroup_bpf_enabled_key[MAX_BPF_ATTACH_TYPE];
 #define cgroup_bpf_enabled(type) static_branch_unlikely(&cgroup_bpf_enabled_key[type])
 
-#define BPF_CGROUP_STORAGE_NEST_MAX	8
-
-struct bpf_cgroup_storage_info {
-	struct task_struct *task;
-	struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE];
-};
-
-/* For each cpu, permit maximum BPF_CGROUP_STORAGE_NEST_MAX number of tasks
- * to use bpf cgroup storage simultaneously.
- */
-DECLARE_PER_CPU(struct bpf_cgroup_storage_info,
-		bpf_cgroup_storage_info[BPF_CGROUP_STORAGE_NEST_MAX]);
-
 #define for_each_cgroup_storage_type(stype) \
 	for (stype = 0; stype < MAX_BPF_CGROUP_STORAGE_TYPE; stype++)
 
@@ -172,44 +159,6 @@ static inline enum bpf_cgroup_storage_type cgroup_storage_type(
 	return BPF_CGROUP_STORAGE_SHARED;
 }
 
-static inline int bpf_cgroup_storage_set(struct bpf_cgroup_storage
-					 *storage[MAX_BPF_CGROUP_STORAGE_TYPE])
-{
-	enum bpf_cgroup_storage_type stype;
-	int i, err = 0;
-
-	preempt_disable();
-	for (i = 0; i < BPF_CGROUP_STORAGE_NEST_MAX; i++) {
-		if (unlikely(this_cpu_read(bpf_cgroup_storage_info[i].task) != NULL))
-			continue;
-
-		this_cpu_write(bpf_cgroup_storage_info[i].task, current);
-		for_each_cgroup_storage_type(stype)
-			this_cpu_write(bpf_cgroup_storage_info[i].storage[stype],
-				       storage[stype]);
-		goto out;
-	}
-	err = -EBUSY;
-	WARN_ON_ONCE(1);
-
-out:
-	preempt_enable();
-	return err;
-}
-
-static inline void bpf_cgroup_storage_unset(void)
-{
-	int i;
-
-	for (i = 0; i < BPF_CGROUP_STORAGE_NEST_MAX; i++) {
-		if (unlikely(this_cpu_read(bpf_cgroup_storage_info[i].task) != current))
-			continue;
-
-		this_cpu_write(bpf_cgroup_storage_info[i].task, NULL);
-		return;
-	}
-}
-
 struct bpf_cgroup_storage *
 cgroup_storage_lookup(struct bpf_cgroup_storage_map *map,
 		      void *key, bool locked);
@@ -487,9 +436,6 @@ static inline int cgroup_bpf_prog_query(const union bpf_attr *attr,
 	return -EINVAL;
 }
 
-static inline int bpf_cgroup_storage_set(
-	struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE]) { return 0; }
-static inline void bpf_cgroup_storage_unset(void) {}
 static inline int bpf_cgroup_storage_assign(struct bpf_prog_aux *aux,
 					    struct bpf_map *map) { return 0; }
 static inline struct bpf_cgroup_storage *bpf_cgroup_storage_alloc(
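
The block removed above implemented cgroup storage passing through a fixed
table of per-CPU slots: every program invocation scanned up to
BPF_CGROUP_STORAGE_NEST_MAX entries with preemption disabled, and nesting
more than eight set()/unset() pairs on one CPU failed outright. A hedged
illustration of that hard cap, written against the helpers removed above
(nest_cap_demo is an illustrative name, not part of this patch):

	/* Illustration of the removed scheme's limit: the first eight nested
	 * bpf_cgroup_storage_set() calls on a CPU each claim a free slot;
	 * the ninth finds every slot's task non-NULL and fails with -EBUSY.
	 */
	static void nest_cap_demo(struct bpf_cgroup_storage *s[MAX_BPF_CGROUP_STORAGE_TYPE])
	{
		int i;

		for (i = 0; i < BPF_CGROUP_STORAGE_NEST_MAX; i++)
			WARN_ON(bpf_cgroup_storage_set(s) != 0);	/* slots 0..7 */
		WARN_ON(bpf_cgroup_storage_set(s) != -EBUSY);	/* no slot left */
		for (i = 0; i < BPF_CGROUP_STORAGE_NEST_MAX; i++)
			bpf_cgroup_storage_unset();		/* release each slot */
	}
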
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 0edff8f5177ebbb4ede7013c63bf9339795b608c..978ebd16ae60cca8d4b4557a519f451ecdcd4613 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -1142,38 +1142,40 @@ int bpf_prog_array_copy(struct bpf_prog_array *old_array,
 			struct bpf_prog *include_prog,
 			struct bpf_prog_array **new_array);
 
+struct bpf_run_ctx {};
+
+struct bpf_cg_run_ctx {
+	struct bpf_run_ctx run_ctx;
+	struct bpf_prog_array_item *prog_item;
+};
+
 /* BPF program asks to bypass CAP_NET_BIND_SERVICE in bind. */
 #define BPF_RET_BIND_NO_CAP_NET_BIND_SERVICE			(1 << 0)
 /* BPF program asks to set CN on the packet. */
 #define BPF_RET_SET_CN						(1 << 0)
 
-/* For BPF_PROG_RUN_ARRAY_FLAGS and __BPF_PROG_RUN_ARRAY,
- * if bpf_cgroup_storage_set() failed, the rest of programs
- * will not execute. This should be a really rare scenario
- * as it requires BPF_CGROUP_STORAGE_NEST_MAX number of
- * preemptions all between bpf_cgroup_storage_set() and
- * bpf_cgroup_storage_unset() on the same cpu.
- */
 #define BPF_PROG_RUN_ARRAY_FLAGS(array, ctx, func, ret_flags)		\
 	({								\
 		struct bpf_prog_array_item *_item;			\
 		struct bpf_prog *_prog;					\
 		struct bpf_prog_array *_array;				\
+		struct bpf_run_ctx *old_run_ctx;			\
+		struct bpf_cg_run_ctx run_ctx;				\
 		u32 _ret = 1;						\
 		u32 func_ret;						\
 		migrate_disable();					\
 		rcu_read_lock();					\
 		_array = rcu_dereference(array);			\
 		_item = &_array->items[0];				\
+		old_run_ctx = bpf_set_run_ctx(&run_ctx.run_ctx);	\
 		while ((_prog = READ_ONCE(_item->prog))) {		\
-			if (unlikely(bpf_cgroup_storage_set(_item->cgroup_storage)))	\
-				break;					\
+			run_ctx.prog_item = _item;			\
 			func_ret = func(_prog, ctx);			\
 			_ret &= (func_ret & 1);				\
-			*(ret_flags) |= (func_ret >> 1);			\
-			bpf_cgroup_storage_unset();			\
+			*(ret_flags) |= (func_ret >> 1);		\
 			_item++;					\
 		}							\
+		bpf_reset_run_ctx(old_run_ctx);				\
 		rcu_read_unlock();					\
 		migrate_enable();					\
 		_ret;							\
@@ -1184,6 +1186,8 @@ int bpf_prog_array_copy(struct bpf_prog_array *old_array,
 		struct bpf_prog_array_item *_item;	\
 		struct bpf_prog *_prog;			\
 		struct bpf_prog_array *_array;		\
+		struct bpf_run_ctx *old_run_ctx;	\
+		struct bpf_cg_run_ctx run_ctx;		\
 		u32 _ret = 1;				\
 		migrate_disable();			\
 		rcu_read_lock();			\
@@ -1191,17 +1195,13 @@ int bpf_prog_array_copy(struct bpf_prog_array *old_array,
 		if (unlikely(check_non_null && !_array))\
 			goto _out;			\
 		_item = &_array->items[0];		\
-		while ((_prog = READ_ONCE(_item->prog))) {		\
-			if (!set_cg_storage) {			\
-				_ret &= func(_prog, ctx);	\
-			} else {				\
-				if (unlikely(bpf_cgroup_storage_set(_item->cgroup_storage)))	\
-					break;			\
-				_ret &= func(_prog, ctx);	\
-				bpf_cgroup_storage_unset();	\
-			}				\
+		old_run_ctx = bpf_set_run_ctx(&run_ctx.run_ctx);\
+		while ((_prog = READ_ONCE(_item->prog))) {	\
+			run_ctx.prog_item = _item;	\
+			_ret &= func(_prog, ctx);	\
 			_item++;			\
 		}					\
+		bpf_reset_run_ctx(old_run_ctx);		\
 _out:							\
 		rcu_read_unlock();			\
 		migrate_enable();			\
@@ -1284,6 +1284,20 @@ static inline void bpf_enable_instrumentation(void)
 	migrate_enable();
 }
 
+static inline struct bpf_run_ctx *bpf_set_run_ctx(struct bpf_run_ctx *new_ctx)
+{
+	struct bpf_run_ctx *old_ctx;
+
+	old_ctx = current->bpf_ctx;
+	current->bpf_ctx = new_ctx;
+	return old_ctx;
+}
+
+static inline void bpf_reset_run_ctx(struct bpf_run_ctx *old_ctx)
+{
+	current->bpf_ctx = old_ctx;
+}
+
 extern const struct file_operations bpf_map_fops;
 extern const struct file_operations bpf_prog_fops;
 extern const struct file_operations bpf_iter_fops;
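
The replacement pattern introduced above is: place a bpf_cg_run_ctx on the
caller's stack, link its embedded bpf_run_ctx into current->bpf_ctx via
bpf_set_run_ctx(), point prog_item at the entry being executed, and restore
the previous pointer afterwards. A minimal sketch of a caller, mirroring what
the two macros above do per array item (run_one is a hypothetical wrapper,
not in this patch):

	/* Sketch: run a single program with its cgroup storage reachable
	 * through the run context, as the macros above do per array item.
	 */
	static u32 run_one(struct bpf_prog_array_item *item, void *ctx)
	{
		struct bpf_cg_run_ctx run_ctx = { .prog_item = item };
		struct bpf_run_ctx *old_run_ctx;
		u32 ret;

		migrate_disable();
		old_run_ctx = bpf_set_run_ctx(&run_ctx.run_ctx);
		ret = BPF_PROG_RUN(item->prog, ctx);	/* helpers may read current->bpf_ctx */
		bpf_reset_run_ctx(old_run_ctx);		/* restore, so contexts nest */
		migrate_enable();
		return ret;
	}

Unlike the removed bpf_cgroup_storage_set(), this setup cannot fail, which is
why the "rest of programs will not execute" caveat in the deleted comment no
longer applies.
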
diff --git a/include/linux/sched.h b/include/linux/sched.h
index ec8d07d88641cd9d0d974377b6b74d920a12c53b..c64119aa2e60fe98dfebeafd182a5a7d0f760225 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -42,6 +42,7 @@ struct backing_dev_info;
 struct bio_list;
 struct blk_plug;
 struct bpf_local_storage;
+struct bpf_run_ctx;
 struct capture_control;
 struct cfs_rq;
 struct fs_struct;
@@ -1379,6 +1380,8 @@ struct task_struct {
 #ifdef CONFIG_BPF_SYSCALL
 	/* Used by BPF task local storage */
 	struct bpf_local_storage __rcu	*bpf_storage;
+	/* Used for BPF run context */
+	struct bpf_run_ctx		*bpf_ctx;
 #endif
 
 #ifdef CONFIG_GCC_PLUGIN_STACKLEAK
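
Because the run-context pointer lives in task_struct rather than in per-CPU
slots, it follows the task across preemption and migration, and the
save/restore discipline allows arbitrary nesting. A small sketch, assuming
both contexts sit on the stack (nesting_demo is an illustrative name):

	/* Sketch: nested run contexts. Each scope saves the previous pointer
	 * and restores it on exit, so depth is bounded only by the stack.
	 */
	static void nesting_demo(void)
	{
		struct bpf_cg_run_ctx outer_ctx = {}, inner_ctx = {};
		struct bpf_run_ctx *saved_outer, *saved_inner;

		saved_outer = bpf_set_run_ctx(&outer_ctx.run_ctx);
		/* current->bpf_ctx == &outer_ctx.run_ctx */
		saved_inner = bpf_set_run_ctx(&inner_ctx.run_ctx);
		/* current->bpf_ctx == &inner_ctx.run_ctx */
		bpf_reset_run_ctx(saved_inner);	/* back to &outer_ctx.run_ctx */
		bpf_reset_run_ctx(saved_outer);	/* back to the original value */
	}
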
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index 9fe846ec6bd1cde4dad55d1e3c3747bf5a62af77..15746f779fe13f5626e88677c306eb9195abf09c 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -393,8 +393,6 @@ const struct bpf_func_proto bpf_get_current_ancestor_cgroup_id_proto = {
 };
 
 #ifdef CONFIG_CGROUP_BPF
-DECLARE_PER_CPU(struct bpf_cgroup_storage_info,
-		bpf_cgroup_storage_info[BPF_CGROUP_STORAGE_NEST_MAX]);
 
 BPF_CALL_2(bpf_get_local_storage, struct bpf_map *, map, u64, flags)
 {
@@ -403,17 +401,13 @@ BPF_CALL_2(bpf_get_local_storage, struct bpf_map *, map, u64, flags)
 	 * verifier checks that its value is correct.
 	 */
 	enum bpf_cgroup_storage_type stype = cgroup_storage_type(map);
-	struct bpf_cgroup_storage *storage = NULL;
+	struct bpf_cgroup_storage *storage;
+	struct bpf_cg_run_ctx *ctx;
 	void *ptr;
-	int i;
 
-	for (i = 0; i < BPF_CGROUP_STORAGE_NEST_MAX; i++) {
-		if (unlikely(this_cpu_read(bpf_cgroup_storage_info[i].task) != current))
-			continue;
-
-		storage = this_cpu_read(bpf_cgroup_storage_info[i].storage[stype]);
-		break;
-	}
+	/* get current cgroup storage from BPF run context */
+	ctx = container_of(current->bpf_ctx, struct bpf_cg_run_ctx, run_ctx);
+	storage = ctx->prog_item->cgroup_storage[stype];
 
 	if (stype == BPF_CGROUP_STORAGE_SHARED)
 		ptr = &READ_ONCE(storage->buf)->data[0];
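
The container_of() above recovers the enclosing bpf_cg_run_ctx from the
embedded bpf_run_ctx that current->bpf_ctx points to. Since run_ctx is the
first member, the offset is zero and the conversion is effectively a
type-safe cast; spelled out (cg_run_ctx_of is an illustrative helper):

	/* What container_of(run_ctx, struct bpf_cg_run_ctx, run_ctx)
	 * expands to, conceptually: subtract the member's offset from the
	 * member's address. Here that offset is 0.
	 */
	static struct bpf_cg_run_ctx *cg_run_ctx_of(struct bpf_run_ctx *run_ctx)
	{
		return (struct bpf_cg_run_ctx *)((char *)run_ctx -
				offsetof(struct bpf_cg_run_ctx, run_ctx));
	}

Note there is no NULL check: bpf_get_local_storage() is only available to
cgroup-attached program types, which run under the BPF_PROG_RUN_ARRAY macros
(or bpf_test_run() below), so a valid run context is expected to be in place.
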
diff --git a/kernel/bpf/local_storage.c b/kernel/bpf/local_storage.c
index 95d70a08325dfb727d1bc70091c884e832448546..362e814815942c443c2c7f86b3f7990e11bf90c8 100644
--- a/kernel/bpf/local_storage.c
+++ b/kernel/bpf/local_storage.c
@@ -11,9 +11,6 @@
 
 #ifdef CONFIG_CGROUP_BPF
 
-DEFINE_PER_CPU(struct bpf_cgroup_storage_info,
-	       bpf_cgroup_storage_info[BPF_CGROUP_STORAGE_NEST_MAX]);
-
 #include "../cgroup/cgroup-internal.h"
 
 #define LOCAL_STORAGE_CREATE_FLAG_MASK					\
diff --git a/kernel/fork.c b/kernel/fork.c
index bc94b2cc59956e923cb7bbe11d32a919ef7a053c..e8b41e212110f120beca355a4c5579f62e21d381 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -2083,6 +2083,7 @@ static __latent_entropy struct task_struct *copy_process(
 #endif
 #ifdef CONFIG_BPF_SYSCALL
 	RCU_INIT_POINTER(p->bpf_storage, NULL);
+	p->bpf_ctx = NULL;
 #endif
 
 	/* Perform scheduler related setup. Assign this task to a CPU. */
diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c
index cda8375bbbaf82923b8d89c98d54dbfe4a204159..8d46e2962786ac3ba261ce6d8e3f760947f65a92 100644
--- a/net/bpf/test_run.c
+++ b/net/bpf/test_run.c
@@ -88,17 +88,19 @@ static bool bpf_test_timer_continue(struct bpf_test_timer *t, u32 repeat, int *e
 static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat,
 			u32 *retval, u32 *time, bool xdp)
 {
-	struct bpf_cgroup_storage *storage[MAX_BPF_CGROUP_STORAGE_TYPE] = { NULL };
+	struct bpf_prog_array_item item = {.prog = prog};
+	struct bpf_run_ctx *old_ctx;
+	struct bpf_cg_run_ctx run_ctx;
 	struct bpf_test_timer t = { NO_MIGRATE };
 	enum bpf_cgroup_storage_type stype;
 	int ret;
 
 	for_each_cgroup_storage_type(stype) {
-		storage[stype] = bpf_cgroup_storage_alloc(prog, stype);
-		if (IS_ERR(storage[stype])) {
-			storage[stype] = NULL;
+		item.cgroup_storage[stype] = bpf_cgroup_storage_alloc(prog, stype);
+		if (IS_ERR(item.cgroup_storage[stype])) {
+			item.cgroup_storage[stype] = NULL;
 			for_each_cgroup_storage_type(stype)
-				bpf_cgroup_storage_free(storage[stype]);
+				bpf_cgroup_storage_free(item.cgroup_storage[stype]);
 			return -ENOMEM;
 		}
 	}
@@ -107,22 +109,19 @@ static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat,
 		repeat = 1;
 
 	bpf_test_timer_enter(&t);
+	old_ctx = bpf_set_run_ctx(&run_ctx.run_ctx);
 	do {
-		ret = bpf_cgroup_storage_set(storage);
-		if (ret)
-			break;
-
+		run_ctx.prog_item = &item;
 		if (xdp)
 			*retval = bpf_prog_run_xdp(prog, ctx);
 		else
 			*retval = BPF_PROG_RUN(prog, ctx);
-
-		bpf_cgroup_storage_unset();
 	} while (bpf_test_timer_continue(&t, repeat, &ret, time));
+	bpf_reset_run_ctx(old_ctx);
 	bpf_test_timer_leave(&t);
 
 	for_each_cgroup_storage_type(stype)
-		bpf_cgroup_storage_free(storage[stype]);
+		bpf_cgroup_storage_free(item.cgroup_storage[stype]);
 
 	return ret;
 }
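
Here a one-element bpf_prog_array_item on the stack stands in for a program
array entry, so the same helper path works under test: while BPF_PROG_RUN()
executes, bpf_get_local_storage() resolves through current->bpf_ctx to the
storage allocated at the top of this function. A sketch of that resolution,
valid only inside the run window (storage_under_test is an illustrative
name):

	/* Sketch: what the rewritten helper computes during the run window.
	 * The result equals item.cgroup_storage[stype] from bpf_test_run().
	 */
	static struct bpf_cgroup_storage *
	storage_under_test(enum bpf_cgroup_storage_type stype)
	{
		struct bpf_cg_run_ctx *ctx =
			container_of(current->bpf_ctx, struct bpf_cg_run_ctx, run_ctx);

		return ctx->prog_item->cgroup_storage[stype];
	}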