diff options
Diffstat (limited to 'tests/topotests/lib/fe_client.py')
-rwxr-xr-x | tests/topotests/lib/fe_client.py | 420 |
1 files changed, 420 insertions, 0 deletions
diff --git a/tests/topotests/lib/fe_client.py b/tests/topotests/lib/fe_client.py new file mode 100755 index 0000000..07059cc --- /dev/null +++ b/tests/topotests/lib/fe_client.py @@ -0,0 +1,420 @@ +#!/usr/bin/env python +# -*- coding: utf-8 eval: (blacken-mode 1) -*- +# SPDX-License-Identifier: GPL-2.0-or-later +# +# November 27 2023, Christian Hopps <chopps@labn.net> +# +# Copyright (c) 2023, LabN Consulting, L.L.C. +# +# noqa: E501 +# +import argparse +import json +import logging +import os +import socket +import struct +import sys +import time +from pathlib import Path + +CWD = os.path.dirname(os.path.realpath(__file__)) + +# This is painful but works if you have installed protobuf would be better if we +# actually built and installed these but ... python packaging. +try: + sys.path.append(os.path.dirname(CWD)) + from munet.base import commander + + commander.cmd_raises(f"protoc --python_out={CWD} -I {CWD}/../../../lib mgmt.proto") +except Exception as error: + logging.error("can't create protobuf definition modules %s", error) + raise + +try: + sys.path[0:0] = "." + import mgmt_pb2 +except Exception as error: + logging.error("can't import proto definition modules %s", error) + raise + +CANDIDATE_DS = mgmt_pb2.DatastoreId.CANDIDATE_DS +OPERATIONAL_DS = mgmt_pb2.DatastoreId.OPERATIONAL_DS +RUNNING_DS = mgmt_pb2.DatastoreId.RUNNING_DS +STARTUP_DS = mgmt_pb2.DatastoreId.STARTUP_DS + +# ===================== +# Native message values +# ===================== + +MGMT_MSG_MARKER_PROTOBUF = b"\000###" +MGMT_MSG_MARKER_NATIVE = b"\001###" + +# +# Native message formats +# +MSG_HDR_FMT = "=H2xIQQ" +HDR_FIELD_CODE = 0 +HDR_FIELD_VSPLIT = 1 +HDR_FIELD_SESS_ID = 2 +HDR_FIELD_REQ_ID = 3 + +MSG_ERROR_FMT = "=h6x" +ERROR_FIELD_ERROR = 0 + +# MSG_GET_TREE_FMT = "=B7x" +# GET_TREE_FIELD_RESULT_TYPE = 0 + +MSG_TREE_DATA_FMT = "=bBB5x" +TREE_DATA_FIELD_PARTIAL_ERROR = 0 +TREE_DATA_FIELD_RESULT_TYPE = 1 +TREE_DATA_FIELD_MORE = 2 + +MSG_GET_DATA_FMT = "=BB6x" +GET_DATA_FIELD_RESULT_TYPE = 0 +GET_DATA_FIELD_FLAGS = 1 +GET_DATA_FLAG_STATE = 0x1 +GET_DATA_FLAG_CONFIG = 0x2 +GET_DATA_FLAG_EXACT = 0x4 + +MSG_NOTIFY_FMT = "=B7x" +NOTIFY_FIELD_RESULT_TYPE = 0 + +# +# Native message codes +# +MSG_CODE_ERROR = 0 +# MSG_CODE_GET_TREE = 1 +MSG_CODE_TREE_DATA = 2 +MSG_CODE_GET_DATA = 3 +MSG_CODE_NOTIFY = 4 + +msg_native_formats = { + MSG_CODE_ERROR: MSG_ERROR_FMT, + # MSG_CODE_GET_TREE: MSG_GET_TREE_FMT, + MSG_CODE_TREE_DATA: MSG_TREE_DATA_FMT, + MSG_CODE_GET_DATA: MSG_GET_DATA_FMT, + MSG_CODE_NOTIFY: MSG_NOTIFY_FMT, +} + + +# Result formats +MSG_FORMAT_XML = 1 +MSG_FORMAT_JSON = 2 +MSG_FORMAT_LYB = 3 + + +def cstr(mdata): + assert mdata[-1] == 0 + return mdata[:-1] + + +class FEClientError(Exception): + pass + + +class PBMessageError(FEClientError): + def __init__(self, msg, errstr): + self.msg = msg + # self.sess_id = mhdr[HDR_FIELD_SESS_ID] + # self.req_id = mhdr[HDR_FIELD_REQ_ID] + self.error = -1 + self.errstr = errstr + super().__init__(f"PBMessageError: {self.errstr}: {msg}") + + +class NativeMessageError(FEClientError): + def __init__(self, mhdr, mfixed, mdata): + self.mhdr = mhdr + self.sess_id = mhdr[HDR_FIELD_SESS_ID] + self.req_id = mhdr[HDR_FIELD_REQ_ID] + self.error = mfixed[0] + self.errstr = cstr(mdata) + super().__init__( + "NativeMessageError: " + f"session {self.sess_id} reqid {self.req_id} " + f"error {self.error}: {self.errstr}" + ) + + +# +# Low-level socket functions +# + + +def recv_wait(sock, size): + """Receive a fixed number of bytes from a stream socket.""" + data = b"" + while len(data) < size: + newdata = sock.recv(size - len(data)) + if not newdata: + raise Exception("Socket closed") + data += newdata + return data + + +def recv_msg(sock): + marker = recv_wait(sock, 4) + assert marker in (MGMT_MSG_MARKER_PROTOBUF, MGMT_MSG_MARKER_NATIVE) + + msize = int.from_bytes(recv_wait(sock, 4), byteorder=sys.byteorder) + assert msize >= 8 + mdata = recv_wait(sock, msize - 8) if msize > 8 else b"" + + return mdata, marker == MGMT_MSG_MARKER_NATIVE + + +def send_msg(sock, marker, mdata): + """Send a mgmtd native message to a stream socket.""" + msize = int.to_bytes(len(mdata) + 8, byteorder=sys.byteorder, length=4) + sock.send(marker) + sock.send(msize) + sock.send(mdata) + + +class Session: + """A session to the mgmtd server.""" + + client_id = 1 + + def __init__(self, sock): + self.sock = sock + self.next_req_id = 1 + + req = mgmt_pb2.FeMessage() + req.register_req.client_name = "test-client" + self.send_pb_msg(req) + logging.debug("Sent FeRegisterReq: %s", req) + + req = mgmt_pb2.FeMessage() + req.session_req.create = 1 + req.session_req.client_conn_id = Session.client_id + Session.client_id += 1 + self.send_pb_msg(req) + logging.debug("Sent FeSessionReq: %s", req) + + reply = self.recv_pb_msg(mgmt_pb2.FeMessage()) + logging.debug("Received FeSessionReply: %s", repr(reply)) + + assert reply.session_reply.success + self.sess_id = reply.session_reply.session_id + + def close(self, clean=True): + if clean: + req = mgmt_pb2.FeMessage() + req.session_req.create = 0 + req.session_req.sess_id = self.sess_id + self.send_pb_msg(req) + self.sock.close() + self.sock = None + + def get_next_req_id(self): + req_id = self.next_req_id + self.next_req_id += 1 + return req_id + + # -------------------------- + # Protobuf message functions + # -------------------------- + + def recv_pb_msg(self, msg): + """Receive a protobuf message.""" + mdata, native = recv_msg(self.sock) + assert not native + + msg.ParseFromString(mdata) + + req = getattr(msg, msg.WhichOneof("message")) + if req.HasField("success"): + if not req.success: + raise PBMessageError(msg, req.error_if_any) + + return msg + + def send_pb_msg(self, msg): + """Send a protobuf message.""" + mdata = msg.SerializeToString() + return send_msg(self.sock, MGMT_MSG_MARKER_PROTOBUF, mdata) + + # ------------------------ + # Native message functions + # ------------------------ + + def recv_native_msg(self): + """Send a native message.""" + mdata, native = recv_msg(self.sock) + assert native + + hlen = struct.calcsize(MSG_HDR_FMT) + hdata = mdata[:hlen] + mhdr = struct.unpack(MSG_HDR_FMT, hdata) + code = mhdr[0] + + if code not in msg_native_formats: + raise Exception(f"Unknown native msg code {code} rcvd") + + mfmt = msg_native_formats[code] + flen = struct.calcsize(mfmt) + fdata = mdata[hlen : hlen + flen] + mfixed = struct.unpack(mfmt, fdata) + mdata = mdata[hlen + flen :] + + if code == MSG_ERROR_FMT: + raise NativeMessageError(mhdr, mfixed, mdata) + + return mhdr, mfixed, mdata + + def send_native_msg(self, mdata): + """Send a native message.""" + return send_msg(self.sock, MGMT_MSG_MARKER_NATIVE, mdata) + + def get_native_msg_header(self, msg_code): + req_id = self.get_next_req_id() + hdata = struct.pack(MSG_HDR_FMT, msg_code, 0, self.sess_id, req_id) + return hdata, req_id + + # ----------------------- + # Front-end API Fountains + # ----------------------- + + def lock(self, lock=True, ds_id=mgmt_pb2.CANDIDATE_DS): + req = mgmt_pb2.FeMessage() + req.lockds_req.session_id = self.sess_id + req.lockds_req.req_id = self.get_next_req_id() + req.lockds_req.ds_id = ds_id + req.lockds_req.lock = lock + self.send_pb_msg(req) + logging.debug("Sent LockDsReq: %s", req) + + reply = self.recv_pb_msg(mgmt_pb2.FeMessage()) + logging.debug("Received Reply: %s", repr(reply)) + assert reply.lockds_reply.success + + def get_data(self, query, data=True, config=False): + # Create the message + mdata, req_id = self.get_native_msg_header(MSG_CODE_GET_DATA) + flags = GET_DATA_FLAG_STATE if data else 0 + flags |= GET_DATA_FLAG_CONFIG if config else 0 + mdata += struct.pack(MSG_GET_DATA_FMT, MSG_FORMAT_JSON, flags) + mdata += query.encode("utf-8") + b"\x00" + + self.send_native_msg(mdata) + logging.debug("Sent GET-TREE") + + mhdr, mfixed, mdata = self.recv_native_msg() + assert mdata[-1] == 0 + result = mdata[:-1].decode("utf-8") + + logging.debug("Received GET: %s: %s", mfixed, mdata) + return result + + # def subscribe(self, notif_xpath): + # # Create the message + # mdata, req_id = self.get_native_msg_header(MSG_CODE_SUBSCRIBE) + # mdata += struct.pack(MSG_SUBSCRIBE_FMT, MSG_FORMAT_JSON) + # mdata += notif_xpath.encode("utf-8") + b"\x00" + + # self.send_native_msg(mdata) + # logging.debug("Sent SUBSCRIBE") + + def recv_notify(self, xpaths=None): + while True: + logging.debug("Waiting for Notify Message") + mhdr, mfixed, mdata = self.recv_native_msg() + if mhdr[HDR_FIELD_CODE] == MSG_CODE_NOTIFY: + logging.debug("Received Notify Message: %s: %s", mfixed, mdata) + else: + raise Exception(f"Received NON-NOTIFY Message: {mfixed}: {mdata}") + + vsplit = mhdr[HDR_FIELD_VSPLIT] + assert mdata[vsplit - 1] == 0 + xpath = mdata[: vsplit - 1].decode("utf-8") + + assert mdata[-1] == 0 + result = mdata[vsplit:-1].decode("utf-8") + + if not xpaths: + return result + js = json.loads(result) + key = [x for x in js.keys()][0] + for xpath in xpaths: + if key.startswith(xpath): + return result + logging.debug("'%s' didn't match xpath filters", key) + + +def __parse_args(): + MPATH = "/var/run/frr/mgmtd_fe.sock" + parser = argparse.ArgumentParser() + parser.add_argument( + "-l", "--listen", nargs="*", metavar="XPATH", help="xpath[s] to listen for" + ) + parser.add_argument( + "--notify-count", + type=int, + default=1, + help="Number of notifications to listen for 0 for infinite", + ) + parser.add_argument( + "-b", "--both", action="store_true", help="return both config and data" + ) + parser.add_argument( + "-c", "--config-only", action="store_true", help="return config only" + ) + parser.add_argument( + "-q", "--query", nargs="+", metavar="XPATH", help="xpath[s] to query" + ) + parser.add_argument("-s", "--server", default=MPATH, help="path to server socket") + parser.add_argument("-v", "--verbose", action="store_true", help="Be verbose") + args = parser.parse_args() + + level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=level, format="%(asctime)s %(levelname)s: %(message)s") + + return args + + +def __server_connect(spath): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + logging.debug("Connecting to server on %s", spath) + while ec := sock.connect_ex(str(spath)): + logging.warn("retry server connection in .5s (%s)", os.strerror(ec)) + time.sleep(0.5) + logging.info("Connected to server on %s", spath) + return sock + + +def __main(): + args = __parse_args() + sock = __server_connect(Path(args.server)) + sess = Session(sock) + + if args.query: + # Performa an xpath query + # query = "/frr-interface:lib/interface/state/mtu" + for query in args.query: + logging.info("Sending query: %s", query) + result = sess.get_data( + query, data=not args.config_only, config=(args.both or args.config_only) + ) + print(result) + + if args.listen is not None: + i = args.notify_count + while i > 0 or args.notify_count == 0: + notif = sess.recv_notify(args.listen) + print(notif) + i -= 1 + + +def main(): + try: + __main() + except KeyboardInterrupt: + logging.info("Exiting") + except Exception as error: + logging.error("Unexpected error exiting: %s", error, exc_info=True) + + +if __name__ == "__main__": + main() |