Diffstat (limited to 'src/aristaproto/plugin/parser.py')
-rw-r--r-- | src/aristaproto/plugin/parser.py | 221 |
1 file changed, 221 insertions, 0 deletions
diff --git a/src/aristaproto/plugin/parser.py b/src/aristaproto/plugin/parser.py
new file mode 100644
index 0000000..f761af6
--- /dev/null
+++ b/src/aristaproto/plugin/parser.py
@@ -0,0 +1,221 @@
+import pathlib
+import sys
+from typing import (
+    Generator,
+    List,
+    Set,
+    Tuple,
+    Union,
+)
+
+from aristaproto.lib.google.protobuf import (
+    DescriptorProto,
+    EnumDescriptorProto,
+    FieldDescriptorProto,
+    FileDescriptorProto,
+    ServiceDescriptorProto,
+)
+from aristaproto.lib.google.protobuf.compiler import (
+    CodeGeneratorRequest,
+    CodeGeneratorResponse,
+    CodeGeneratorResponseFeature,
+    CodeGeneratorResponseFile,
+)
+
+from .compiler import outputfile_compiler
+from .models import (
+    EnumDefinitionCompiler,
+    FieldCompiler,
+    MapEntryCompiler,
+    MessageCompiler,
+    OneOfFieldCompiler,
+    OutputTemplate,
+    PluginRequestCompiler,
+    PydanticOneOfFieldCompiler,
+    ServiceCompiler,
+    ServiceMethodCompiler,
+    is_map,
+    is_oneof,
+)
+
+
+def traverse(
+    proto_file: FileDescriptorProto,
+) -> Generator[
+    Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
+]:
+    # Todo: Keep information about nested hierarchy
+    def _traverse(
+        path: List[int],
+        items: Union[List[EnumDescriptorProto], List[DescriptorProto]],
+        prefix: str = "",
+    ) -> Generator[
+        Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None
+    ]:
+        for i, item in enumerate(items):
+            # Adjust the name since we flatten the hierarchy.
+            # Todo: don't change the name, but include full name in returned tuple
+            item.name = next_prefix = f"{prefix}_{item.name}"
+            yield item, [*path, i]
+
+            if isinstance(item, DescriptorProto):
+                # Get nested types.
+                yield from _traverse([*path, i, 4], item.enum_type, next_prefix)
+                yield from _traverse([*path, i, 3], item.nested_type, next_prefix)
+
+    yield from _traverse([5], proto_file.enum_type)
+    yield from _traverse([4], proto_file.message_type)
+
+
+def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
+    response = CodeGeneratorResponse()
+
+    plugin_options = request.parameter.split(",") if request.parameter else []
+    response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
+
+    request_data = PluginRequestCompiler(plugin_request_obj=request)
+    # Gather output packages
+    for proto_file in request.proto_file:
+        output_package_name = proto_file.package
+        if output_package_name not in request_data.output_packages:
+            # Create a new output if there is no output for this package
+            request_data.output_packages[output_package_name] = OutputTemplate(
+                parent_request=request_data, package_proto_obj=proto_file
+            )
+        # Add this input file to the output corresponding to this package
+        request_data.output_packages[output_package_name].input_files.append(proto_file)
+
+        if (
+            proto_file.package == "google.protobuf"
+            and "INCLUDE_GOOGLE" not in plugin_options
+        ):
+            # If not INCLUDE_GOOGLE,
+            # skip outputting Google's well-known types
+            request_data.output_packages[output_package_name].output = False
+
+        if "pydantic_dataclasses" in plugin_options:
+            request_data.output_packages[
+                output_package_name
+            ].pydantic_dataclasses = True
+
+    # Read Messages and Enums
+    # We need to read Messages before Services so that we can
+    # get the references to input/output messages for each service
+    for output_package_name, output_package in request_data.output_packages.items():
+        for proto_input_file in output_package.input_files:
+            for item, path in traverse(proto_input_file):
+                read_protobuf_type(
+                    source_file=proto_input_file,
+                    item=item,
+                    path=path,
+                    output_package=output_package,
+                )
+
+    # Read Services
+    for output_package_name, output_package in request_data.output_packages.items():
+        for proto_input_file in output_package.input_files:
+            for index, service in enumerate(proto_input_file.service):
+                read_protobuf_service(service, index, output_package)
+
+    # Generate output files
+    output_paths: Set[pathlib.Path] = set()
+    for output_package_name, output_package in request_data.output_packages.items():
+        if not output_package.output:
+            continue
+
+        # Add files to the response object
+        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
+        output_paths.add(output_path)
+
+        response.file.append(
+            CodeGeneratorResponseFile(
+                name=str(output_path),
+                # Render and then format the output file
+                content=outputfile_compiler(output_file=output_package),
+            )
+        )
+
+    # Make each output directory a package with __init__ file
+    init_files = {
+        directory.joinpath("__init__.py")
+        for path in output_paths
+        for directory in path.parents
+        if not directory.joinpath("__init__.py").exists()
+    } - output_paths
+
+    for init_file in init_files:
+        response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
+
+    for output_package_name in sorted(output_paths.union(init_files)):
+        print(f"Writing {output_package_name}", file=sys.stderr)
+
+    return response
+
+
+def _make_one_of_field_compiler(
+    output_package: OutputTemplate,
+    source_file: "FileDescriptorProto",
+    parent: MessageCompiler,
+    proto_obj: "FieldDescriptorProto",
+    path: List[int],
+) -> FieldCompiler:
+    pydantic = output_package.pydantic_dataclasses
+    Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
+    return Cls(
+        source_file=source_file,
+        parent=parent,
+        proto_obj=proto_obj,
+        path=path,
+    )
+
+
+def read_protobuf_type(
+    item: DescriptorProto,
+    path: List[int],
+    source_file: "FileDescriptorProto",
+    output_package: OutputTemplate,
+) -> None:
+    if isinstance(item, DescriptorProto):
+        if item.options.map_entry:
+            # Skip generated map entry messages since we just use dicts
+            return
+        # Process Message
+        message_data = MessageCompiler(
+            source_file=source_file, parent=output_package, proto_obj=item, path=path
+        )
+        for index, field in enumerate(item.field):
+            if is_map(field, item):
+                MapEntryCompiler(
+                    source_file=source_file,
+                    parent=message_data,
+                    proto_obj=field,
+                    path=path + [2, index],
+                )
+            elif is_oneof(field):
+                _make_one_of_field_compiler(
+                    output_package, source_file, message_data, field, path + [2, index]
+                )
+            else:
+                FieldCompiler(
+                    source_file=source_file,
+                    parent=message_data,
+                    proto_obj=field,
+                    path=path + [2, index],
+                )
+    elif isinstance(item, EnumDescriptorProto):
+        # Enum
+        EnumDefinitionCompiler(
+            source_file=source_file, parent=output_package, proto_obj=item, path=path
+        )
+
+
+def read_protobuf_service(
+    service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
+) -> None:
+    service_data = ServiceCompiler(
+        parent=output_package, proto_obj=service, path=[6, index]
+    )
+    for j, method in enumerate(service.method):
+        ServiceMethodCompiler(
+            parent=service_data, proto_obj=method, path=[6, index, 2, j]
+        )
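
A note on the literal integers in the path lists above ([5] and [4] in traverse, path + [2, index] for fields, [6, index] and [6, index, 2, j] for services and methods): they are field numbers from google/protobuf/descriptor.proto, so each generated path lines up with the SourceCodeInfo location paths that protoc attaches to the request. As a reference (not part of the diff):

    # Field numbers from google/protobuf/descriptor.proto encoded in the paths:
    #   FileDescriptorProto:    message_type = 4, enum_type = 5, service = 6
    #   DescriptorProto:        field = 2, nested_type = 3, enum_type = 4
    #   ServiceDescriptorProto: method = 2
    # Example: the path [4, 0, 3, 1] addresses the second nested message
    # inside the first top-level message of a file.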
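
For context, generate_code implements the standard protoc plugin contract: protoc writes a serialized CodeGeneratorRequest to the plugin's stdin and reads a serialized CodeGeneratorResponse from its stdout. The entrypoint that performs that I/O is not part of this diff; the following is a minimal sketch of such a driver, assuming aristaproto messages expose betterproto's parse()/bytes() serialization API:

    import sys

    from aristaproto.lib.google.protobuf.compiler import CodeGeneratorRequest
    from aristaproto.plugin.parser import generate_code


    def main() -> None:
        # protoc sends the serialized request on stdin...
        request = CodeGeneratorRequest().parse(sys.stdin.buffer.read())
        # ...and expects the serialized response on stdout.
        sys.stdout.buffer.write(bytes(generate_code(request)))


    if __name__ == "__main__":
        main()

Plugin options such as INCLUDE_GOOGLE and pydantic_dataclasses arrive through request.parameter, which protoc fills from the plugin's _opt command-line flag (for example --aristaproto_opt=pydantic_dataclasses, assuming the plugin is registered under that name).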