diff --git a/tools/testing/selftests/bpf/prog_tests/probe_user.c b/tools/testing/selftests/bpf/prog_tests/probe_user.c index 8721671321de..b3e17a329cb4 100644 --- a/tools/testing/selftests/bpf/prog_tests/probe_user.c +++ b/tools/testing/selftests/bpf/prog_tests/probe_user.c @@ -20,6 +20,11 @@ void serial_test_probe_user(void) struct bpf_program *kprobe_progs[prog_count]; struct bpf_object *obj; static const int zero = 0; + struct test_pro_bss { + struct sockaddr_in old; + __u32 test_pid; + }; + struct test_pro_bss results = {}; size_t i; obj = bpf_object__open_file(obj_file, &opts); @@ -34,6 +39,23 @@ void serial_test_probe_user(void) goto cleanup; } + { + struct bpf_map *bss_map; + struct test_pro_bss bss_init = {}; + + bss_init.test_pid = getpid(); + bss_map = bpf_object__find_map_by_name(obj, "test_pro.bss"); + if (!ASSERT_OK_PTR(bss_map, "find_bss_map")) + goto cleanup; + if (!ASSERT_EQ(bpf_map__value_size(bss_map), sizeof(bss_init), + "bss_size")) + goto cleanup; + err = bpf_map__set_initial_value(bss_map, &bss_init, + sizeof(bss_init)); + if (!ASSERT_OK(err, "set_bss_init")) + goto cleanup; + } + err = bpf_object__load(obj); if (CHECK(err, "obj_load", "err %d\n", err)) goto cleanup; @@ -62,11 +84,13 @@ void serial_test_probe_user(void) connect(sock_fd, &curr, sizeof(curr)); close(sock_fd); - err = bpf_map_lookup_elem(results_map_fd, &zero, &tmp); + err = bpf_map_lookup_elem(results_map_fd, &zero, &results); if (CHECK(err, "get_kprobe_res", "failed to get kprobe res: %d\n", err)) goto cleanup; + memcpy(&tmp, &results.old, sizeof(tmp)); + in = (struct sockaddr_in *)&tmp; if (CHECK(memcmp(&tmp, &orig, sizeof(orig)), "check_kprobe_res", "wrong kprobe res from probe read: %s:%u\n", diff --git a/tools/testing/selftests/bpf/progs/test_probe_user.c b/tools/testing/selftests/bpf/progs/test_probe_user.c index a8e501af9604..4bc86c7654b1 100644 --- a/tools/testing/selftests/bpf/progs/test_probe_user.c +++ b/tools/testing/selftests/bpf/progs/test_probe_user.c @@ -5,13 +5,22 @@ #include #include "bpf_misc.h" -static struct sockaddr_in old; +struct test_pro_bss { + struct sockaddr_in old; + __u32 test_pid; +}; + +struct test_pro_bss bss; static int handle_sys_connect_common(struct sockaddr_in *uservaddr) { struct sockaddr_in new; + __u32 cur = bpf_get_current_pid_tgid() >> 32; - bpf_probe_read_user(&old, sizeof(old), uservaddr); + if (bss.test_pid && cur != bss.test_pid) + return 0; + + bpf_probe_read_user(&bss.old, sizeof(bss.old), uservaddr); __builtin_memset(&new, 0xab, sizeof(new)); bpf_probe_write_user(uservaddr, &new, sizeof(new));