diff options
Diffstat (limited to 'tests/topotests/bgp_rpki_topo1/r1/rtrd.py')
-rwxr-xr-x | tests/topotests/bgp_rpki_topo1/r1/rtrd.py | 330 |
1 files changed, 330 insertions, 0 deletions
diff --git a/tests/topotests/bgp_rpki_topo1/r1/rtrd.py b/tests/topotests/bgp_rpki_topo1/r1/rtrd.py new file mode 100755 index 0000000..bca58a6 --- /dev/null +++ b/tests/topotests/bgp_rpki_topo1/r1/rtrd.py @@ -0,0 +1,330 @@ +#!/usr/bin/python3 +# SPDX-License-Identifier: GPL-2.0-or-later + +# Copyright (C) 2023 Tomas Hlavacek (tmshlvck@gmail.com) + +from typing import List, Tuple, Callable, Type +import socket +import threading +import socketserver +import struct +import ipaddress +import csv +import os +import sys + +LISTEN_HOST, LISTEN_PORT = "0.0.0.0", 15432 +VRPS_FILE = os.path.join(sys.path[0], "vrps.csv") + + +def dbg(m: str): + print(m) + sys.stdout.flush() + + +class RTRDatabase(object): + def __init__(self, vrps_file: str) -> None: + self.last_serial = 0 + self.ann4 = [] + self.ann6 = [] + self.withdraw4 = [] + self.withdraw6 = [] + + with open(vrps_file, "r") as fh: + for rasn, rnet, rmaxlen, _ in csv.reader(fh): + try: + net = ipaddress.ip_network(rnet) + asn = int(rasn[2:]) + maxlen = int(rmaxlen) + if net.version == 6: + self.ann6.append((asn, str(net), maxlen)) + elif net.version == 4: + self.ann4.append((asn, str(net), maxlen)) + else: + raise ValueError(f"Unknown AFI: {net.version}") + except Exception as e: + dbg( + f"VRPS load: ignoring {str((rasn, rnet,rmaxlen))} because {str(e)}" + ) + + def get_serial(self) -> int: + return self.last_serial + + def set_serial(self, serial: int) -> None: + self.last_serial = serial + + def get_announcements4(self, serial: int = 0) -> List[Tuple[int, str, int]]: + if serial > self.last_serial: + return self.ann4 + else: + return [] + + def get_withdrawals4(self, serial: int = 0) -> List[Tuple[int, str, int]]: + if serial > self.last_serial: + return self.withdraw4 + else: + return [] + + def get_announcements6(self, serial: int = 0) -> List[Tuple[int, str, int]]: + if serial > self.last_serial: + return self.ann6 + else: + return [] + + def get_withdrawals6(self, serial: int = 0) -> List[Tuple[int, str, int]]: + if serial > self.last_serial: + return self.withdraw6 + else: + return [] + + +class RTRConnHandler(socketserver.BaseRequestHandler): + PROTO_VERSION = 0 + + def setup(self) -> None: + self.session_id = 2345 + self.serial = 1024 + + dbg(f"New connection from: {str(self.client_address)} ") + # TODO: register for notifies + + def finish(self) -> None: + pass + # TODO: de-register + + HEADER_LEN = 8 + + def decode_header(self, buf: bytes) -> Tuple[int, int, int, int]: + # common header in all received packets + return struct.unpack("!BBHI", buf) + # reutnrs (proto_ver, pdu_type, sess_id, length) + + SERNOTIFY_TYPE = 0 + SERNOTIFY_LEN = 12 + + def send_sernotify(self, serial: int) -> None: + # serial notify PDU + dbg(f"<Serial Notify session_id={self.session_id} serial={serial}") + self.request.send( + struct.pack( + "!BBHII", + self.PROTO_VERSION, + self.SERNOTIFY_TYPE, + self.session_id, + self.SERNOTIFY_LEN, + serial, + ) + ) + + CACHERESPONSE_TYPE = 3 + CACHERESPONSE_LEN = 8 + + def send_cacheresponse(self) -> None: + # cache response PDU + dbg(f"<Cache response session_id={self.session_id}") + self.request.send( + struct.pack( + "!BBHI", + self.PROTO_VERSION, + self.CACHERESPONSE_TYPE, + self.session_id, + self.CACHERESPONSE_LEN, + ) + ) + + FLAGS_ANNOUNCE = 1 + FLAGS_WITHDRAW = 0 + + IPV4_TYPE = 4 + IPV4_LEN = 20 + + def send_ipv4(self, ipnet: str, asn: int, maxlen: int, flags: int): + # IPv4 PDU + dbg(f"<IPv4 net={ipnet} asn={asn} maxlen={maxlen} flags={flags}") + ip = ipaddress.IPv4Network(ipnet) + self.request.send( + struct.pack( + "!BBHIBBBB4sI", + self.PROTO_VERSION, + self.IPV4_TYPE, + 0, + self.IPV4_LEN, + flags, + ip.prefixlen, + maxlen, + 0, + ip.network_address.packed, + asn, + ) + ) + + def announce_ipv4(self, ipnet, asn, maxlen): + self.send_ipv4(ipnet, asn, maxlen, self.FLAGS_ANNOUNCE) + + def withdraw_ipv4(self, ipnet, asn, maxlen): + self.send_ipv4(ipnet, asn, maxlen, self.FLAGS_WITHDRAW) + + IPV6_TYPE = 6 + IPV6_LEN = 32 + + def send_ipv6(self, ipnet: str, asn: int, maxlen: int, flags: int): + # IPv6 PDU + dbg(f"<IPv6 net={ipnet} asn={asn} maxlen={maxlen} flags={flags}") + ip = ipaddress.IPv6Network(ipnet) + self.request.send( + struct.pack( + "!BBHIBBBB16sI", + self.PROTO_VERSION, + self.IPV6_TYPE, + 0, + self.IPV6_LEN, + flags, + ip.prefixlen, + maxlen, + 0, + ip.network_address.packed, + asn, + ) + ) + + def announce_ipv6(self, ipnet: str, asn: int, maxlen: int): + self.send_ipv6(ipnet, asn, maxlen, self.FLAGS_ANNOUNCE) + + def withdraw_ipv6(self, ipnet: str, asn: int, maxlen: int): + self.send_ipv6(ipnet, asn, maxlen, self.FLAGS_WITHDRAW) + + EOD_TYPE = 7 + EOD_LEN = 12 + + def send_endofdata(self, serial: int): + # end of data PDU + dbg(f"<End of Data session_id={self.session_id} serial={serial}") + self.server.db.set_serial(serial) + self.request.send( + struct.pack( + "!BBHII", + self.PROTO_VERSION, + self.EOD_TYPE, + self.session_id, + self.EOD_LEN, + serial, + ) + ) + + CACHERESET_TYPE = 8 + CACHERESET_LEN = 8 + + def send_cachereset(self): + # cache reset PDU + dbg("<Cache Reset") + self.request.send( + struct.pack( + "!BBHI", + self.PROTO_VERSION, + self.CACHERESET_TYPE, + 0, + self.CACHERESET_LEN, + ) + ) + + SERIAL_QUERY_TYPE = 1 + SERIAL_QUERY_LEN = 12 + + def handle_serial_query(self, buf: bytes, sess_id: int): + serial = struct.unpack("!I", buf)[0] + dbg(f">Serial query: {serial}") + if sess_id: + self.server.db.set_serial(serial) + else: + self.server.db.set_serial(0) + self.send_cacheresponse() + + for asn, ipnet, maxlen in self.server.db.get_announcements4(serial): + self.announce_ipv4(ipnet, asn, maxlen) + + for asn, ipnet, maxlen in self.server.db.get_withdrawals4(serial): + self.withdraw_ipv4(ipnet, asn, maxlen) + + for asn, ipnet, maxlen in self.server.db.get_announcements6(serial): + self.announce_ipv6(ipnet, asn, maxlen) + + for asn, ipnet, maxlen in self.server.db.get_withdrawals6(serial): + self.withdraw_ipv6(ipnet, asn, maxlen) + + self.send_endofdata(self.serial) + + RESET_TYPE = 2 + + def handle_reset(self): + dbg(">Reset") + self.session_id += 1 + self.server.db.set_serial(0) + self.send_cacheresponse() + + for asn, ipnet, maxlen in self.server.db.get_announcements4(self.serial): + self.announce_ipv4(ipnet, asn, maxlen) + + for asn, ipnet, maxlen in self.server.db.get_announcements6(self.serial): + self.announce_ipv6(ipnet, asn, maxlen) + + self.send_endofdata(self.serial) + + ERROR_TYPE = 10 + + def handle_error(self, buf: bytes): + dbg(f">Error: {str(buf)}") + self.server.shutdown() + self.server.stopped = True + raise ConnectionError("Received an RPKI error packet from FRR. Exiting") + + def handle(self): + while True: + b = self.request.recv(self.HEADER_LEN, socket.MSG_WAITALL) + if len(b) == 0: + break + proto_ver, pdu_type, sess_id, length = self.decode_header(b) + dbg( + f">Header proto_ver={proto_ver} pdu_type={pdu_type} sess_id={sess_id} length={length}" + ) + + if sess_id: + self.session_id = sess_id + + if pdu_type == self.SERIAL_QUERY_TYPE: + b = self.request.recv( + self.SERIAL_QUERY_LEN - self.HEADER_LEN, socket.MSG_WAITALL + ) + self.handle_serial_query(b, sess_id) + + elif pdu_type == self.RESET_TYPE: + self.handle_reset() + + elif pdu_type == self.ERROR_TYPE: + b = self.request.recv(length - self.HEADER_LEN, socket.MSG_WAITALL) + self.handle_error(b) + + +class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): + def __init__( + self, bind: Tuple[str, int], handler: Type[RTRConnHandler], db: RTRDatabase + ) -> None: + super().__init__(bind, handler) + self.db = db + + +def main(): + db = RTRDatabase(VRPS_FILE) + server = ThreadedTCPServer((LISTEN_HOST, LISTEN_PORT), RTRConnHandler, db) + dbg(f"Server listening on {LISTEN_HOST} port {LISTEN_PORT}") + server.serve_forever() + + +if __name__ == "__main__": + if len(sys.argv) > 1: + f = open(sys.argv[1], "w") + sys.__stdout__ = f + sys.stdout = f + sys.__stderr__ = f + sys.stderr = f + + main() |