diff options
Diffstat (limited to 'src/libknot/xdp/protocols.h')
-rw-r--r-- | src/libknot/xdp/protocols.h | 457 |
1 files changed, 457 insertions, 0 deletions
diff --git a/src/libknot/xdp/protocols.h b/src/libknot/xdp/protocols.h new file mode 100644 index 0000000..1a18601 --- /dev/null +++ b/src/libknot/xdp/protocols.h @@ -0,0 +1,457 @@ +/* Copyright (C) 2023 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +#pragma once + +#include <assert.h> +#include <linux/ip.h> +#include <linux/ipv6.h> +#include <linux/tcp.h> +#include <linux/udp.h> +#include <string.h> + +#include "libknot/endian.h" +#include "libknot/xdp/bpf-consts.h" +#include "libknot/xdp/msg.h" + +/* Don't fragment flag. */ +#define IP_DF 0x4000 + +#define HDR_8021Q_LEN 4; + +/* + * Following prot_read_*() functions do not check sanity of parsed packet. + * Broken packets have to be dropped by BPF filter prior getting here. + */ + +inline static void *prot_read_udp(void *data, uint16_t *src_port, uint16_t *dst_port) +{ + const struct udphdr *udp = data; + + *src_port = udp->source; + *dst_port = udp->dest; + + return data + sizeof(*udp); +} + +enum { + PROT_TCP_OPT_ENDOP = 0, + PROT_TCP_OPT_NOOP = 1, + PROT_TCP_OPT_MSS = 2, + PROT_TCP_OPT_WSC = 3, // window scale + + PROT_TCP_OPT_LEN_MSS = 4, + PROT_TCP_OPT_LEN_WSC = 3, +}; + +inline static void *prot_read_tcp(void *data, knot_xdp_msg_t *msg, uint16_t *src_port, uint16_t *dst_port) +{ + const struct tcphdr *tcp = data; + + msg->flags |= KNOT_XDP_MSG_TCP; + + if (tcp->syn) { + msg->flags |= KNOT_XDP_MSG_SYN; + } + if (tcp->ack) { + msg->flags |= KNOT_XDP_MSG_ACK; + } + if (tcp->fin) { + msg->flags |= KNOT_XDP_MSG_FIN; + } + if (tcp->rst) { + msg->flags |= KNOT_XDP_MSG_RST; + } + + msg->seqno = be32toh(tcp->seq); + msg->ackno = be32toh(tcp->ack_seq); + msg->win = be16toh(tcp->window); + + *src_port = tcp->source; + *dst_port = tcp->dest; + + uint8_t *opts = data + sizeof(*tcp), *hdr_end = data + tcp->doff * 4; + while (opts < hdr_end) { + if (opts[0] == PROT_TCP_OPT_ENDOP || opts[0] == PROT_TCP_OPT_NOOP) { + opts++; + continue; + } + + if (opts + 1 > hdr_end || opts + opts[1] > hdr_end) { + // Malformed option. + break; + } + + if (opts[0] == PROT_TCP_OPT_MSS && opts[1] == PROT_TCP_OPT_LEN_MSS) { + msg->flags |= KNOT_XDP_MSG_MSS; + memcpy(&msg->mss, &opts[2], sizeof(msg->mss)); + msg->mss = be16toh(msg->mss); + } + + if (opts[0] == PROT_TCP_OPT_WSC && opts[1] == PROT_TCP_OPT_LEN_WSC) { + msg->flags |= KNOT_XDP_MSG_WSC; + msg->win_scale = opts[2]; + } + + opts += opts[1]; + } + + return hdr_end; +} + +inline static void *prot_read_ipv4(void *data, knot_xdp_msg_t *msg, void **data_end) +{ + const struct iphdr *ip4 = data; + + // Conditions ensured by the BPF filter. + assert(ip4->version == 4); + assert(ip4->frag_off == 0 || ip4->frag_off == __constant_htons(IP_DF)); + // IPv4 header checksum is not verified! + + struct sockaddr_in *src = (struct sockaddr_in *)&msg->ip_from; + struct sockaddr_in *dst = (struct sockaddr_in *)&msg->ip_to; + memcpy(&src->sin_addr, &ip4->saddr, sizeof(src->sin_addr)); + memcpy(&dst->sin_addr, &ip4->daddr, sizeof(dst->sin_addr)); + src->sin_family = AF_INET; + dst->sin_family = AF_INET; + + msg->ecn = (ip4->tos & 0x3); + + *data_end = data + be16toh(ip4->tot_len); + data += ip4->ihl * 4; + + if (ip4->protocol == IPPROTO_TCP) { + return prot_read_tcp(data, msg, &src->sin_port, &dst->sin_port); + } else { + assert(ip4->protocol == IPPROTO_UDP); + return prot_read_udp(data, &src->sin_port, &dst->sin_port); + } +} + +inline static void *prot_read_ipv6(void *data, knot_xdp_msg_t *msg, void **data_end) +{ + const struct ipv6hdr *ip6 = data; + + msg->flags |= KNOT_XDP_MSG_IPV6; + + // Conditions ensured by the BPF filter. + assert(ip6->version == 6); + + struct sockaddr_in6 *src = (struct sockaddr_in6 *)&msg->ip_from; + struct sockaddr_in6 *dst = (struct sockaddr_in6 *)&msg->ip_to; + memcpy(&src->sin6_addr, &ip6->saddr, sizeof(src->sin6_addr)); + memcpy(&dst->sin6_addr, &ip6->daddr, sizeof(dst->sin6_addr)); + src->sin6_family = AF_INET6; + dst->sin6_family = AF_INET6; + src->sin6_flowinfo = 0; + dst->sin6_flowinfo = 0; + // Scope ID is ignored. + + msg->ecn = ((ip6->flow_lbl[0] & 0x30) >> 4); + + data += sizeof(*ip6); + *data_end = data + be16toh(ip6->payload_len); + + if (ip6->nexthdr == IPPROTO_TCP) { + return prot_read_tcp(data, msg, &src->sin6_port, &dst->sin6_port); + } else { + assert(ip6->nexthdr == IPPROTO_UDP); + return prot_read_udp(data, &src->sin6_port, &dst->sin6_port); + } +} + +inline static void *prot_read_eth(void *data, knot_xdp_msg_t *msg, void **data_end, + const uint16_t *vlan_map, unsigned vlan_map_max) +{ + const struct ethhdr *eth = data; + knot_xdp_info_t *info = data - KNOT_XDP_PKT_ALIGNMENT - sizeof(*info); + + memcpy(msg->eth_from, eth->h_source, ETH_ALEN); + memcpy(msg->eth_to, eth->h_dest, ETH_ALEN); + msg->flags = 0; + + if (eth->h_proto == __constant_htons(ETH_P_8021Q)) { + if (info->out_if_index > 0 && info->out_if_index <= vlan_map_max) { + assert(vlan_map); + msg->vlan_tci = vlan_map[info->out_if_index]; + } else { + memcpy(&msg->vlan_tci, data + sizeof(*eth), sizeof(msg->vlan_tci)); + } + data += HDR_8021Q_LEN; + eth = data; + } + + data += sizeof(*eth); + + if (eth->h_proto == __constant_htons(ETH_P_IPV6)) { + return prot_read_ipv6(data, msg, data_end); + } else { + assert(eth->h_proto == __constant_htons(ETH_P_IP)); + return prot_read_ipv4(data, msg, data_end); + } +} + +inline static size_t prot_write_hdrs_len(const knot_xdp_msg_t *msg) +{ + size_t res = sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr); + + if (msg->vlan_tci != 0 || msg->flags & KNOT_XDP_MSG_VLAN) { + res += HDR_8021Q_LEN; + } + + if (msg->flags & KNOT_XDP_MSG_IPV6) { + res += sizeof(struct ipv6hdr) - sizeof(struct iphdr); + } + + if (msg->flags & KNOT_XDP_MSG_TCP) { + res += sizeof(struct tcphdr) - sizeof(struct udphdr); + + if (msg->flags & KNOT_XDP_MSG_MSS) { + res += PROT_TCP_OPT_LEN_MSS; + } + if (msg->flags & KNOT_XDP_MSG_WSC) { + res += PROT_TCP_OPT_LEN_WSC + 1; // 1 == align + } + } + + return res; +} + +/* Checksum endianness implementation notes for ipv4_checksum() and checksum(). + * + * The basis for checksum is addition on big-endian 16-bit words, with bit 16 carrying + * over to bit 0. That can be viewed as first byte carrying to the second and the + * second one carrying back to the first one, i.e. a symmetrical situation. + * Therefore the result is the same even when arithmetics is done on little-endian (!) + */ + +inline static void checksum(uint32_t *result, const void *_data, uint32_t _data_len) +{ + assert(!(_data_len & 1)); + const uint16_t *data = _data; + uint32_t len = _data_len / 2; + while (len-- > 0) { + *result += *data++; + } +} + +inline static void checksum_uint16(uint32_t *result, uint16_t x) +{ + checksum(result, &x, sizeof(x)); +} + +inline static void checksum_payload(uint32_t *result, void *payload, size_t pay_len) +{ + if (pay_len & 1) { + ((uint8_t *)payload)[pay_len++] = 0; + } + checksum(result, payload, pay_len); +} + +inline static uint16_t checksum_finish(uint32_t result, bool nonzero) +{ + while (result > 0xffff) { + result = (result & 0xffff) + (result >> 16); + } + if (!nonzero || result != 0xffff) { + result = ~result; + } + return result; +} + +inline static void prot_write_udp(void *data, const knot_xdp_msg_t *msg, void *data_end, + uint16_t src_port, uint16_t dst_port, uint32_t chksum) +{ + struct udphdr *udp = data; + + udp->len = htobe16(data_end - data); + udp->source = src_port; + udp->dest = dst_port; + + if (msg->flags & KNOT_XDP_MSG_IPV6) { + udp->check = 0; + checksum(&chksum, &udp->len, sizeof(udp->len)); + checksum_uint16(&chksum, htobe16(IPPROTO_UDP)); + checksum_payload(&chksum, data, data_end - data); + udp->check = checksum_finish(chksum, true); + } else { + udp->check = 0; // UDP over IPv4 doesn't require checksum. + } + + assert(data + sizeof(*udp) == msg->payload.iov_base); +} + +inline static void prot_write_tcp(void *data, const knot_xdp_msg_t *msg, void *data_end, + uint16_t src_port, uint16_t dst_port, uint32_t chksum, + uint16_t mss) +{ + struct tcphdr *tcp = data; + + tcp->source = src_port; + tcp->dest = dst_port; + tcp->seq = htobe32(msg->seqno); + tcp->ack_seq = htobe32(msg->ackno); + tcp->window = htobe16(msg->win); + tcp->check = 0; // Temporarily initialize before checksum calculation. + + tcp->res1 = 0; + tcp->urg = 0; + tcp->ece = 0; + tcp->cwr = 0; + tcp->urg_ptr = 0; + + tcp->syn = ((msg->flags & KNOT_XDP_MSG_SYN) ? 1 : 0); + tcp->ack = ((msg->flags & KNOT_XDP_MSG_ACK) ? 1 : 0); + tcp->fin = ((msg->flags & KNOT_XDP_MSG_FIN) ? 1 : 0); + tcp->rst = ((msg->flags & KNOT_XDP_MSG_RST) ? 1 : 0); + + uint8_t *hdr_end = data + sizeof(*tcp); + if (msg->flags & KNOT_XDP_MSG_WSC) { + hdr_end[0] = PROT_TCP_OPT_WSC; + hdr_end[1] = PROT_TCP_OPT_LEN_WSC; + hdr_end[2] = msg->win_scale; + hdr_end += PROT_TCP_OPT_LEN_WSC; + *hdr_end++ = PROT_TCP_OPT_NOOP; // align + } + if (msg->flags & KNOT_XDP_MSG_MSS) { + mss = htobe16(mss); + hdr_end[0] = PROT_TCP_OPT_MSS; + hdr_end[1] = PROT_TCP_OPT_LEN_MSS; + memcpy(&hdr_end[2], &mss, sizeof(mss)); + hdr_end += PROT_TCP_OPT_LEN_MSS; + } + + tcp->psh = ((data_end - (void *)hdr_end > 0) ? 1 : 0); + tcp->doff = (hdr_end - (uint8_t *)tcp) / 4; + assert((hdr_end - (uint8_t *)tcp) % 4 == 0); + + checksum_uint16(&chksum, htobe16(IPPROTO_TCP)); + checksum_uint16(&chksum, htobe16(data_end - data)); + checksum_payload(&chksum, data, data_end - data); + tcp->check = checksum_finish(chksum, false); + + assert(hdr_end == msg->payload.iov_base); +} + +inline static uint16_t from32to16(uint32_t sum) +{ + sum = (sum & 0xffff) + (sum >> 16); + sum = (sum & 0xffff) + (sum >> 16); + return sum; +} + +inline static uint16_t ipv4_checksum(const uint16_t *ipv4_hdr) +{ + uint32_t sum32 = 0; + for (int i = 0; i < 10; ++i) { + if (i != 5) { + sum32 += ipv4_hdr[i]; + } + } + return ~from32to16(sum32); +} + +inline static void prot_write_ipv4(void *data, const knot_xdp_msg_t *msg, + void *data_end, uint16_t tcp_mss) +{ + struct iphdr *ip4 = data; + + ip4->version = 4; + ip4->ihl = sizeof(*ip4) / 4; + ip4->tos = msg->ecn; + ip4->tot_len = htobe16(data_end - data); + ip4->id = 0; + ip4->frag_off = 0; + ip4->ttl = IPDEFTTL; + ip4->protocol = ((msg->flags & KNOT_XDP_MSG_TCP) ? IPPROTO_TCP : IPPROTO_UDP); + + const struct sockaddr_in *src = (const struct sockaddr_in *)&msg->ip_from; + const struct sockaddr_in *dst = (const struct sockaddr_in *)&msg->ip_to; + memcpy(&ip4->saddr, &src->sin_addr, sizeof(src->sin_addr)); + memcpy(&ip4->daddr, &dst->sin_addr, sizeof(dst->sin_addr)); + + ip4->check = ipv4_checksum(data); + + data += sizeof(*ip4); + + if (msg->flags & KNOT_XDP_MSG_TCP) { + uint32_t chk = 0; + checksum(&chk, &src->sin_addr, sizeof(src->sin_addr)); + checksum(&chk, &dst->sin_addr, sizeof(dst->sin_addr)); + + prot_write_tcp(data, msg, data_end, src->sin_port, dst->sin_port, chk, tcp_mss); + } else { + prot_write_udp(data, msg, data_end, src->sin_port, dst->sin_port, 0); // IPv4/UDP requires no checksum + } +} + +inline static void prot_write_ipv6(void *data, const knot_xdp_msg_t *msg, + void *data_end, uint16_t tcp_mss) +{ + struct ipv6hdr *ip6 = data; + + ip6->version = 6; + ip6->priority = 0; + ip6->flow_lbl[0] = (msg->ecn << 4); + ip6->payload_len = htobe16(data_end - data - sizeof(*ip6)); + ip6->nexthdr = ((msg->flags & KNOT_XDP_MSG_TCP) ? IPPROTO_TCP : IPPROTO_UDP); + ip6->hop_limit = IPDEFTTL; + + memset(ip6->flow_lbl, 0, sizeof(ip6->flow_lbl)); + + const struct sockaddr_in6 *src = (const struct sockaddr_in6 *)&msg->ip_from; + const struct sockaddr_in6 *dst = (const struct sockaddr_in6 *)&msg->ip_to; + memcpy(&ip6->saddr, &src->sin6_addr, sizeof(src->sin6_addr)); + memcpy(&ip6->daddr, &dst->sin6_addr, sizeof(dst->sin6_addr)); + + data += sizeof(*ip6); + + uint32_t chk = 0; + checksum(&chk, &src->sin6_addr, sizeof(src->sin6_addr)); + checksum(&chk, &dst->sin6_addr, sizeof(dst->sin6_addr)); + + if (msg->flags & KNOT_XDP_MSG_TCP) { + prot_write_tcp(data, msg, data_end, src->sin6_port, dst->sin6_port, chk, tcp_mss); + } else { + prot_write_udp(data, msg, data_end, src->sin6_port, dst->sin6_port, chk); + } +} + +inline static void prot_write_eth(void *data, const knot_xdp_msg_t *msg, + void *data_end, uint16_t tcp_mss) +{ + struct ethhdr *eth = data; + + memcpy(eth->h_source, msg->eth_from, ETH_ALEN); + memcpy(eth->h_dest, msg->eth_to, ETH_ALEN); + + if (msg->vlan_tci != 0) { + eth->h_proto = __constant_htons(ETH_P_8021Q); + memcpy(data + sizeof(*eth), &msg->vlan_tci, sizeof(msg->vlan_tci)); + data += HDR_8021Q_LEN; + eth = data; + } + + data += sizeof(*eth); + + if (msg->flags & KNOT_XDP_MSG_IPV6) { + eth->h_proto = __constant_htons(ETH_P_IPV6); + prot_write_ipv6(data, msg, data_end, tcp_mss); + } else { + eth->h_proto = __constant_htons(ETH_P_IP); + prot_write_ipv4(data, msg, data_end, tcp_mss); + } +} |