diff options
Diffstat (limited to 'tests/topotests/ospfapi/ctester.py')
-rwxr-xr-x | tests/topotests/ospfapi/ctester.py | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/tests/topotests/ospfapi/ctester.py b/tests/topotests/ospfapi/ctester.py new file mode 100755 index 0000000..ab23744 --- /dev/null +++ b/tests/topotests/ospfapi/ctester.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 eval: (blacken-mode 1) -*- +# SPDX-License-Identifier: MIT +# +# January 17 2022, Christian Hopps <chopps@labn.net> +# +# Copyright 2022, LabN Consulting, L.L.C. +import argparse +import asyncio +import logging +import os +import sys + +CWD = os.path.dirname(os.path.realpath(__file__)) + +CLIENTDIR = os.path.abspath(os.path.join(CWD, "../../../ospfclient")) +if not os.path.exists(CLIENTDIR): + CLIENTDIR = os.path.join(CWD, "/usr/lib/frr") +assert os.path.exists( + os.path.join(CLIENTDIR, "ospfclient.py") +), "can't locate ospfclient.py" + +sys.path[0:0] = [CLIENTDIR] + +import ospfclient as api # pylint: disable=E0401 # noqa: E402 + + +async def do_monitor(c, args): + cv = asyncio.Condition() + + async def cb(new_router_id, _): + assert new_router_id == c.router_id + logging.info("NEW ROUTER ID: %s", new_router_id) + sys.stdout.flush() + async with cv: + cv.notify_all() + + logging.debug("API using monitor router ID callback") + await c.monitor_router_id(callback=cb) + + for check in args.monitor: + logging.info("Waiting for %s", check) + + while True: + async with cv: + got = c.router_id + if str(check) == str(got): + break + logging.debug("expected '%s' != '%s'\nwaiting on notify", check, got) + await cv.wait() + + logging.info("SUCCESS: %s", check) + print("SUCCESS: {}".format(check)) + sys.stdout.flush() + + +async def do_wait(c, args): + cv = asyncio.Condition() + + async def cb(added, removed): + logging.debug("callback: added: %s removed: %s", added, removed) + sys.stdout.flush() + async with cv: + cv.notify_all() + + logging.debug("API using monitor reachable callback") + await c.monitor_reachable(callback=cb) + + for w in args.wait: + check = ",".join(sorted(list(w.split(",")))) + logging.info("Waiting for %s", check) + + while True: + async with cv: + got = ",".join(sorted([str(x) for x in c.reachable_routers])) + if check == got: + break + logging.debug("expected '%s' != '%s'\nwaiting on notify", check, got) + await cv.wait() + + logging.info("SUCCESS: %s", check) + print("SUCCESS: {}".format(check)) + sys.stdout.flush() + + +async def async_main(args): + c = api.OspfOpaqueClient(args.server) + await c.connect() + if sys.version_info[1] > 6: + asyncio.create_task(c._handle_msg_loop()) # pylint: disable=W0212 + else: + asyncio.get_event_loop().create_task( + c._handle_msg_loop() # pylint: disable=W0212 + ) + + if args.monitor: + await do_monitor(c, args) + if args.wait: + await do_wait(c, args) + return 0 + + +def main(*args): + ap = argparse.ArgumentParser(args) + ap.add_argument( + "--monitor", action="append", help="monitor and wait for this router ID" + ) + ap.add_argument("--server", default="localhost", help="OSPF API server") + ap.add_argument( + "--wait", action="append", help="wait for comma-sep set of reachable routers" + ) + ap.add_argument("-v", "--verbose", action="store_true", help="be verbose") + args = ap.parse_args() + + level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=level, format="%(asctime)s %(levelname)s: TESTER: %(name)s: %(message)s" + ) + + # We need to flush this output to stdout right away + h = logging.StreamHandler(sys.stdout) + h.flush = sys.stdout.flush + f = logging.Formatter("%(asctime)s %(name)s: %(levelname)s: %(message)s") + h.setFormatter(f) + logger = logging.getLogger("ospfclient") + logger.addHandler(h) + logger.propagate = False + + logging.info("ctester: starting") + sys.stdout.flush() + + status = 3 + try: + if sys.version_info[1] > 6: + status = asyncio.run(async_main(args)) + else: + loop = asyncio.get_event_loop() + try: + status = loop.run_until_complete(async_main(args)) + finally: + loop.close() + except KeyboardInterrupt: + logging.info("Exiting, received KeyboardInterrupt in main") + except Exception as error: + logging.info("Exiting, unexpected exception %s", error, exc_info=True) + else: + logging.info("api: clean exit") + + return status + + +if __name__ == "__main__": + exit_status = main() + sys.exit(exit_status) |