summaryrefslogtreecommitdiffstats
path: root/src/core/bpf/socket_bind/socket-bind.bpf.c
blob: b7972a887a4c8948bc1e62512e952839bbd6a23a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/* SPDX-License-Identifier: LGPL-2.1-or-later */

/* The SPDX header above is correct: this file is licensed LGPL-2.1-or-later.
 * Since the kernel does not consider that license GPL-compatible, we declare
 * it as GPL in the license section at the end of this file instead. That is
 * fine, because LGPL-2.1-or-later code may be relicensed under the GPL.
 */

#include "socket-bind-api.bpf.h"
/* <linux/types.h> must precede <bpf/bpf_helpers.h> because, by design,
 * <bpf/bpf_helpers.h> does not include the type header itself.
 */
#include <linux/types.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_helpers.h>
#include <linux/bpf.h>
#include <netinet/in.h>
#include <stdbool.h>

/*
 * Map layout shared by sd_bind_allow and sd_bind_deny: an array of
 * struct socket_bind_rule values keyed by rule position (0, 1, 2, ...).
 *
 * max_entries is deliberately not declared here. It is set from user space
 * with the bpf_map__set_max_entries() helper, which libbpf only honours before
 * the object is loaded.
 */
struct socket_bind_map_t {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __type(key, __u32);
        __type(value, struct socket_bind_rule);
};

/* Verdicts returned by the cgroup/bind4 and cgroup/bind6 programs below.
 * The numeric values are what the hook hands back to the kernel, so they
 * are spelled out explicitly rather than left to implicit numbering. */
enum socket_bind_action {
        SOCKET_BIND_ALLOW = 1,  /* let bind() proceed */
        SOCKET_BIND_DENY = 0,   /* reject bind() */
};

/* Allow rules: bind_socket() consults these first, and any match permits the bind. */
struct socket_bind_map_t sd_bind_allow SEC(".maps");
/* Deny rules: consulted only when no allow rule matched; a match rejects the bind. */
struct socket_bind_map_t sd_bind_deny SEC(".maps");

/* Address-family check for a single rule.
 * A rule whose family is AF_UNSPEC is a wildcard and accepts any family;
 * otherwise the socket's family must equal the rule's exactly. */
static __always_inline bool match_af(
                __u8 address_family, const struct socket_bind_rule *r) {
        if (r->address_family == AF_UNSPEC)
                return true;

        return r->address_family == address_family;
}

/* IP-protocol check for a single rule.
 * A rule protocol of 0 is a wildcard; any other value must equal the
 * socket's protocol. */
static __always_inline bool match_protocol(
                __u32 protocol, const struct socket_bind_rule *r) {
        if (!r->protocol)
                return true;

        return protocol == r->protocol;
}

/* Port check for a single rule; port is expected in host byte order.
 * nr_ports == 0 is a wildcard. Otherwise the rule covers the half-open range
 * [port_min, port_min + nr_ports). The upper bound is computed in 32 bits so
 * it cannot wrap for ranges ending at 65535. */
static __always_inline bool match_user_port(
                __u16 port, const struct socket_bind_rule *r) {
        __u32 first, end;

        if (r->nr_ports == 0)
                return true;

        first = r->port_min;
        end = first + r->nr_ports;

        return port >= first && port < end;
}

/* A rule matches a socket only when family, protocol and port all match.
 * The checks run in that order and stop at the first mismatch. */
static __always_inline bool match(
                __u8 address_family,
                __u32 protocol,
                __u16 port,
                const struct socket_bind_rule *r) {
        if (!match_af(address_family, r))
                return false;

        if (!match_protocol(protocol, r))
                return false;

        return match_user_port(port, r);
}

/* Walk one rule map in index order and report whether any rule matches the
 * address being bound.
 *
 * ctx:   bind request; user_family, protocol and user_port are read.
 * rules: the rule array to scan (sd_bind_allow or sd_bind_deny).
 * Returns true at the first matching rule. Returns false if no rule matches,
 * whether the scan stops on a failed lookup or after SOCKET_BIND_MAX_RULES
 * iterations.
 *
 * NOTE(review): a zero-filled entry (address_family == AF_UNSPEC, protocol 0,
 * nr_ports 0) matches every socket, so every slot within the map's
 * max_entries presumably must hold a real rule. Confirm in the user-space
 * loader.
 */
static __always_inline bool match_rules(
                struct bpf_sock_addr *ctx,
                struct socket_bind_map_t *rules) {
        /* NOTE(review): volatile presumably forces a full 32-bit load of
         * ctx->user_port, which the verifier expects, instead of a narrowed
         * load. Do not simplify without re-checking verification. */
        volatile __u32 user_port = ctx->user_port;
        /* user_port is in network byte order; rules store host-order ports. */
        __u16 port = (__u16)bpf_ntohs(user_port);

        /* The constant upper bound keeps the loop provably finite for the
         * verifier. */
        for (__u32 i = 0; i < SOCKET_BIND_MAX_RULES; ++i) {
                const __u32 key = i;
                const struct socket_bind_rule *rule = bpf_map_lookup_elem(rules, &key);

                /* For an array map, lookup returns NULL once key reaches
                 * max_entries, i.e. past the last slot user space sized the
                 * map for. */
                if (!rule)
                        break;

                if (match(ctx->user_family, ctx->protocol, port, rule))
                        return true;
        }

        return false;
}

/* Compute the verdict for one bind() request.
 * A matching allow rule always wins. Failing that, a matching deny rule
 * rejects the bind. With no match in either map the bind is permitted.
 * The deny map is only scanned when the allow map produced no match. */
static __always_inline int bind_socket(struct bpf_sock_addr *ctx) {
        const bool allowed = match_rules(ctx, &sd_bind_allow);

        if (!allowed && match_rules(ctx, &sd_bind_deny))
                return SOCKET_BIND_DENY;

        return SOCKET_BIND_ALLOW;
}

/* cgroup/bind4 entry point. Rules are applied only when both the address
 * passed to bind() and the socket itself are AF_INET; every other combination
 * is allowed through. */
SEC("cgroup/bind4")
int sd_bind4(struct bpf_sock_addr *ctx) {
        const bool ipv4 = ctx->user_family == AF_INET && ctx->family == AF_INET;

        return ipv4 ? bind_socket(ctx) : SOCKET_BIND_ALLOW;
}

/* cgroup/bind6 entry point. Rules are applied only when both the address
 * passed to bind() and the socket itself are AF_INET6; every other
 * combination is allowed through. */
SEC("cgroup/bind6")
int sd_bind6(struct bpf_sock_addr *ctx) {
        const bool ipv6 = ctx->user_family == AF_INET6 && ctx->family == AF_INET6;

        return ipv6 ? bind_socket(ctx) : SOCKET_BIND_ALLOW;
}

char _license[] SEC("license") = "GPL";