summaryrefslogtreecommitdiffstats
path: root/tests/topotests/ospfapi/ctester.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/topotests/ospfapi/ctester.py')
-rwxr-xr-xtests/topotests/ospfapi/ctester.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/tests/topotests/ospfapi/ctester.py b/tests/topotests/ospfapi/ctester.py
new file mode 100755
index 0000000..ab23744
--- /dev/null
+++ b/tests/topotests/ospfapi/ctester.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 eval: (blacken-mode 1) -*-
+# SPDX-License-Identifier: MIT
+#
+# January 17 2022, Christian Hopps <chopps@labn.net>
+#
+# Copyright 2022, LabN Consulting, L.L.C.
+import argparse
+import asyncio
+import logging
+import os
+import sys
+
+CWD = os.path.dirname(os.path.realpath(__file__))
+
+CLIENTDIR = os.path.abspath(os.path.join(CWD, "../../../ospfclient"))
+if not os.path.exists(CLIENTDIR):
+ CLIENTDIR = os.path.join(CWD, "/usr/lib/frr")
+assert os.path.exists(
+ os.path.join(CLIENTDIR, "ospfclient.py")
+), "can't locate ospfclient.py"
+
+sys.path[0:0] = [CLIENTDIR]
+
+import ospfclient as api # pylint: disable=E0401 # noqa: E402
+
+
+async def do_monitor(c, args):
+ cv = asyncio.Condition()
+
+ async def cb(new_router_id, _):
+ assert new_router_id == c.router_id
+ logging.info("NEW ROUTER ID: %s", new_router_id)
+ sys.stdout.flush()
+ async with cv:
+ cv.notify_all()
+
+ logging.debug("API using monitor router ID callback")
+ await c.monitor_router_id(callback=cb)
+
+ for check in args.monitor:
+ logging.info("Waiting for %s", check)
+
+ while True:
+ async with cv:
+ got = c.router_id
+ if str(check) == str(got):
+ break
+ logging.debug("expected '%s' != '%s'\nwaiting on notify", check, got)
+ await cv.wait()
+
+ logging.info("SUCCESS: %s", check)
+ print("SUCCESS: {}".format(check))
+ sys.stdout.flush()
+
+
+async def do_wait(c, args):
+ cv = asyncio.Condition()
+
+ async def cb(added, removed):
+ logging.debug("callback: added: %s removed: %s", added, removed)
+ sys.stdout.flush()
+ async with cv:
+ cv.notify_all()
+
+ logging.debug("API using monitor reachable callback")
+ await c.monitor_reachable(callback=cb)
+
+ for w in args.wait:
+ check = ",".join(sorted(list(w.split(","))))
+ logging.info("Waiting for %s", check)
+
+ while True:
+ async with cv:
+ got = ",".join(sorted([str(x) for x in c.reachable_routers]))
+ if check == got:
+ break
+ logging.debug("expected '%s' != '%s'\nwaiting on notify", check, got)
+ await cv.wait()
+
+ logging.info("SUCCESS: %s", check)
+ print("SUCCESS: {}".format(check))
+ sys.stdout.flush()
+
+
+async def async_main(args):
+ c = api.OspfOpaqueClient(args.server)
+ await c.connect()
+ if sys.version_info[1] > 6:
+ asyncio.create_task(c._handle_msg_loop()) # pylint: disable=W0212
+ else:
+ asyncio.get_event_loop().create_task(
+ c._handle_msg_loop() # pylint: disable=W0212
+ )
+
+ if args.monitor:
+ await do_monitor(c, args)
+ if args.wait:
+ await do_wait(c, args)
+ return 0
+
+
+def main(*args):
+ ap = argparse.ArgumentParser(args)
+ ap.add_argument(
+ "--monitor", action="append", help="monitor and wait for this router ID"
+ )
+ ap.add_argument("--server", default="localhost", help="OSPF API server")
+ ap.add_argument(
+ "--wait", action="append", help="wait for comma-sep set of reachable routers"
+ )
+ ap.add_argument("-v", "--verbose", action="store_true", help="be verbose")
+ args = ap.parse_args()
+
+ level = logging.DEBUG if args.verbose else logging.INFO
+ logging.basicConfig(
+ level=level, format="%(asctime)s %(levelname)s: TESTER: %(name)s: %(message)s"
+ )
+
+ # We need to flush this output to stdout right away
+ h = logging.StreamHandler(sys.stdout)
+ h.flush = sys.stdout.flush
+ f = logging.Formatter("%(asctime)s %(name)s: %(levelname)s: %(message)s")
+ h.setFormatter(f)
+ logger = logging.getLogger("ospfclient")
+ logger.addHandler(h)
+ logger.propagate = False
+
+ logging.info("ctester: starting")
+ sys.stdout.flush()
+
+ status = 3
+ try:
+ if sys.version_info[1] > 6:
+ status = asyncio.run(async_main(args))
+ else:
+ loop = asyncio.get_event_loop()
+ try:
+ status = loop.run_until_complete(async_main(args))
+ finally:
+ loop.close()
+ except KeyboardInterrupt:
+ logging.info("Exiting, received KeyboardInterrupt in main")
+ except Exception as error:
+ logging.info("Exiting, unexpected exception %s", error, exc_info=True)
+ else:
+ logging.info("api: clean exit")
+
+ return status
+
+
+if __name__ == "__main__":
+ exit_status = main()
+ sys.exit(exit_status)