#!/usr/bin/env python
"""Generate reference and plugin outputs for the standard test cases.

Run as a script; see ``HELP`` for the accepted command-line arguments.
"""
import asyncio
import os
import platform
import shutil
import sys
from pathlib import Path
from typing import Set

from tests.util import (
    get_directories,
    inputs_path,
    output_path_aristaproto,
    output_path_aristaproto_pydantic,
    output_path_reference,
    protoc,
)


# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"


def clear_directory(dir_path: Path):
    """Remove every direct child (file or directory tree) of *dir_path*.

    The directory itself is kept.
    """
    for file_or_directory in dir_path.glob("*"):
        if file_or_directory.is_dir():
            shutil.rmtree(file_or_directory)
        else:
            file_or_directory.unlink()


async def generate(whitelist: Set[str], verbose: bool):
    """Generate outputs for all test cases, or only the whitelisted ones.

    Each whitelist entry is either the name of a test case under
    ``inputs_path`` or a path (relative or absolute) to a test-case
    directory. An empty whitelist selects every test case.

    Writes the names of failed test cases to stderr and exits the process
    with status 1 if any generation returned a non-zero code.
    """
    test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}

    path_whitelist = set()
    name_whitelist = set()
    for item in whitelist:
        if item in test_case_names:
            name_whitelist.add(item)
            continue
        # Keep the raw string, and also its resolved form: the comparison
        # below is against a resolved path, so a relative directory such as
        # "inputs/bool" would otherwise never match.
        path_whitelist.add(item)
        path_whitelist.add(str(Path(item).resolve()))

    # Names are recorded in lockstep with the tasks so that results can be
    # attributed to the right test case even when the whitelist skips some.
    selected_test_case_names = []
    generation_tasks = []
    for test_case_name in sorted(test_case_names):
        test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
        if (
            whitelist
            and str(test_case_input_path) not in path_whitelist
            and test_case_name not in name_whitelist
        ):
            continue
        selected_test_case_names.append(test_case_name)
        generation_tasks.append(
            generate_test_case_output(test_case_input_path, test_case_name, verbose)
        )

    # Wait for all subprocs and match any failures to names to report.
    # asyncio.gather returns results in the order the awaitables were passed.
    results = await asyncio.gather(*generation_tasks)
    failed_test_cases = [
        test_case_name
        for test_case_name, result in zip(selected_test_case_names, results)
        if result != 0
    ]

    if failed_test_cases:
        sys.stderr.write(
            "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
        )
        for failed_test_case in failed_test_cases:
            sys.stderr.write(f"- {failed_test_case}\n")

        sys.exit(1)


def _report_protoc_result(
    description: str,
    stream_label: str,
    test_case_name: str,
    code: int,
    out: bytes,
    err: bytes,
    verbose: bool,
) -> None:
    """Print the outcome of one protoc run and, if verbose, its raw output."""
    if code == 0:
        print(
            f"\033[31;1;4mGenerated {description} output for {test_case_name!r}\033[0m"
        )
    else:
        print(
            f"\033[31;1;4mFailed to generate {description} output for {test_case_name!r}\033[0m"
        )

    if verbose:
        if out:
            print(f"{stream_label} stdout:")
            # Captured output is written through the binary buffer unchanged.
            sys.stdout.buffer.write(out)
            sys.stdout.buffer.flush()

        if err:
            print(f"{stream_label} stderr:")
            sys.stderr.buffer.write(err)
            sys.stderr.buffer.flush()


async def generate_test_case_output(
    test_case_input_path: Path, test_case_name: str, verbose: bool
) -> int:
    """
    Run the three protoc generations for one test case and report on each.

    Returns the max of the subprocess return values
    """
    # NOTE(review): if protoc can report negative codes (e.g. a process killed
    # by a signal), max() could hide that failure behind a 0 - confirm.

    test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
    # Unlike the reference output, the plugin outputs are not placed in a
    # per-test-case subdirectory here.
    test_case_output_path_aristaproto = output_path_aristaproto
    test_case_output_path_aristaproto_pyd = output_path_aristaproto_pydantic

    os.makedirs(test_case_output_path_reference, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto, exist_ok=True)
    os.makedirs(test_case_output_path_aristaproto_pyd, exist_ok=True)

    clear_directory(test_case_output_path_reference)
    clear_directory(test_case_output_path_aristaproto)

    # The positional booleans presumably select reference vs. plugin output
    # and the pydantic variant - verify against tests.util.protoc.
    (
        (ref_out, ref_err, ref_code),
        (plg_out, plg_err, plg_code),
        (plg_out_pyd, plg_err_pyd, plg_code_pyd),
    ) = await asyncio.gather(
        protoc(test_case_input_path, test_case_output_path_reference, True),
        protoc(test_case_input_path, test_case_output_path_aristaproto, False),
        protoc(
            test_case_input_path, test_case_output_path_aristaproto_pyd, False, True
        ),
    )

    _report_protoc_result(
        "reference", "Reference", test_case_name, ref_code, ref_out, ref_err, verbose
    )
    _report_protoc_result(
        "plugin", "Plugin", test_case_name, plg_code, plg_out, plg_err, verbose
    )
    _report_protoc_result(
        "plugin (pydantic compatible)",
        "Plugin",
        test_case_name,
        plg_code_pyd,
        plg_out_pyd,
        plg_err_pyd,
        verbose,
    )

    return max(ref_code, plg_code, plg_code_pyd)


HELP = "\n".join(
    (
        "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
        "Generate python classes for standard tests.",
        "",
        "DIRECTORIES    One or more relative or absolute directories of test-cases to generate classes for.",
        "               python generate.py inputs/bool inputs/double inputs/enum",
        "",
        "NAMES          One or more test-case names to generate classes for.",
        "               python generate.py bool double enums",
    )
)


def main():
    """Parse ``sys.argv`` and run the generation.

    ``-h``/``--help`` anywhere prints ``HELP`` and returns. ``-v`` is only
    recognised as the first argument; every remaining argument is treated as
    a whitelist entry.
    """
    if set(sys.argv).intersection({"-h", "--help"}):
        print(HELP)
        return
    if sys.argv[1:2] == ["-v"]:
        verbose = True
        whitelist = set(sys.argv[2:])
    else:
        verbose = False
        whitelist = set(sys.argv[1:])

    if platform.system() == "Windows":
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    asyncio.run(generate(whitelist, verbose))


if __name__ == "__main__":
    main()