author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-07-29 09:40:12 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-07-29 09:40:12 +0000
commit     14b40ec77a4bf8605789cc3aff0eb87625510a41 (patch)
tree       4064d27144d6deaabfcd96df01bd996baa8b51a0 /src/aristaproto/plugin
parent     Initial commit. (diff)
download   python-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.tar.xz
           python-aristaproto-14b40ec77a4bf8605789cc3aff0eb87625510a41.zip

Adding upstream version 1.2+20240521.

Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/aristaproto/plugin')
-rw-r--r--  src/aristaproto/plugin/__init__.py    1
-rw-r--r--  src/aristaproto/plugin/__main__.py    4
-rw-r--r--  src/aristaproto/plugin/compiler.py   50
-rwxr-xr-x  src/aristaproto/plugin/main.py       52
-rw-r--r--  src/aristaproto/plugin/models.py    851
-rw-r--r--  src/aristaproto/plugin/parser.py    221
-rw-r--r--  src/aristaproto/plugin/plugin.bat     2
7 files changed, 1181 insertions, 0 deletions
diff --git a/src/aristaproto/plugin/__init__.py b/src/aristaproto/plugin/__init__.py
new file mode 100644
index 0000000..c28a133
--- /dev/null
+++ b/src/aristaproto/plugin/__init__.py
@@ -0,0 +1 @@
+from .main import main
diff --git a/src/aristaproto/plugin/__main__.py b/src/aristaproto/plugin/__main__.py
new file mode 100644
index 0000000..bd95dae
--- /dev/null
+++ b/src/aristaproto/plugin/__main__.py
@@ -0,0 +1,4 @@
+from .main import main
+
+
+main()
diff --git a/src/aristaproto/plugin/compiler.py b/src/aristaproto/plugin/compiler.py
new file mode 100644
index 0000000..4bbcc48
--- /dev/null
+++ b/src/aristaproto/plugin/compiler.py
@@ -0,0 +1,50 @@
+import os.path
+
+
+try:
+    # aristaproto[compiler] specific dependencies
+    import black
+    import isort.api
+    import jinja2
+except ImportError as err:
+    print(
+        "\033[31m"
+        f"Unable to import `{err.name}` from aristaproto plugin! "
+        "Please ensure that you've installed aristaproto as "
+        '`pip install "aristaproto[compiler]"` so that compiler dependencies '
+        "are included."
+        "\033[0m"
+    )
+    raise SystemExit(1)
+
+from .models import OutputTemplate
+
+
+def outputfile_compiler(output_file: OutputTemplate) -> str:
+    templates_folder = os.path.abspath(
+        os.path.join(os.path.dirname(__file__), "..", "templates")
+    )
+
+    env = jinja2.Environment(
+        trim_blocks=True,
+        lstrip_blocks=True,
+        loader=jinja2.FileSystemLoader(templates_folder),
+    )
+    template = env.get_template("template.py.j2")
+
+    code = template.render(output_file=output_file)
+    code = isort.api.sort_code_string(
+        code=code,
+        show_diff=False,
+        py_version=37,
+        profile="black",
+        combine_as_imports=True,
+        lines_after_imports=2,
+        quiet=True,
+        force_grid_wrap=2,
+        known_third_party=["grpclib", "aristaproto"],
+    )
+    return black.format_str(
+        src_contents=code,
+        mode=black.Mode(),
+    )
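After rendering the Jinja2 template, outputfile_compiler post-processes the result as plain text: isort normalises the imports, then black formats the whole module. A minimal standalone sketch of the same chain; the input string is a stand-in for the rendered template, and it assumes the aristaproto[compiler] extras are installed:

import black
import isort.api

# Stand-in for template.render(...); deliberately unsorted and unformatted.
rendered = "import sys\nimport os\nx=1\n"

sorted_code = isort.api.sort_code_string(code=rendered, profile="black")
formatted = black.format_str(src_contents=sorted_code, mode=black.Mode())
print(formatted)
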
+ """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/src/aristaproto/plugin/models.py b/src/aristaproto/plugin/models.py new file mode 100644 index 0000000..484b40d --- /dev/null +++ b/src/aristaproto/plugin/models.py @@ -0,0 +1,851 @@ +"""Plugin model dataclasses. + +These classes are meant to be an intermediate representation +of protobuf objects. They are used to organize the data collected during parsing. + +The general intention is to create a doubly-linked tree-like structure +with the following types of references: +- Downwards references: from message -> fields, from output package -> messages +or from service -> service methods +- Upwards references: from field -> message, message -> package. +- Input/output message references: from a service method to it's corresponding +input/output messages, which may even be in another package. + +There are convenience methods to allow climbing up and down this tree, for +example to retrieve the list of all messages that are in the same package as +the current message. + +Most of these classes take as inputs: +- proto_obj: A reference to it's corresponding protobuf object as +presented by the protoc plugin. +- parent: a reference to the parent object in the tree. + +With this information, the class is able to expose attributes, +such as a pythonized name, that will be calculated from proto_obj. + +The instantiation should also attach a reference to the new object +into the corresponding place within it's parent object. For example, +instantiating field `A` with parent message `B` should add a +reference to `A` to `B`'s `fields` attribute. 
+""" + + +import builtins +import re +import textwrap +from dataclasses import ( + dataclass, + field, +) +from typing import ( + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Type, + Union, +) + +import aristaproto +from aristaproto import which_one_of +from aristaproto.casing import sanitize_name +from aristaproto.compile.importing import ( + get_type_reference, + parse_source_type_name, +) +from aristaproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) +from aristaproto.lib.google.protobuf import ( + DescriptorProto, + EnumDescriptorProto, + Field, + FieldDescriptorProto, + FieldDescriptorProtoLabel, + FieldDescriptorProtoType, + FileDescriptorProto, + MethodDescriptorProto, +) +from aristaproto.lib.google.protobuf.compiler import CodeGeneratorRequest + +from ..compile.importing import ( + get_type_reference, + parse_source_type_name, +) +from ..compile.naming import ( + pythonize_class_name, + pythonize_enum_member_name, + pythonize_field_name, + pythonize_method_name, +) + + +# Create a unique placeholder to deal with +# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +PLACEHOLDER = object() + +# Organize proto types into categories +PROTO_FLOAT_TYPES = ( + FieldDescriptorProtoType.TYPE_DOUBLE, # 1 + FieldDescriptorProtoType.TYPE_FLOAT, # 2 +) +PROTO_INT_TYPES = ( + FieldDescriptorProtoType.TYPE_INT64, # 3 + FieldDescriptorProtoType.TYPE_UINT64, # 4 + FieldDescriptorProtoType.TYPE_INT32, # 5 + FieldDescriptorProtoType.TYPE_FIXED64, # 6 + FieldDescriptorProtoType.TYPE_FIXED32, # 7 + FieldDescriptorProtoType.TYPE_UINT32, # 13 + FieldDescriptorProtoType.TYPE_SFIXED32, # 15 + FieldDescriptorProtoType.TYPE_SFIXED64, # 16 + FieldDescriptorProtoType.TYPE_SINT32, # 17 + FieldDescriptorProtoType.TYPE_SINT64, # 18 +) +PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8 +PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9 +PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12 +PROTO_MESSAGE_TYPES = ( + FieldDescriptorProtoType.TYPE_MESSAGE, # 11 + FieldDescriptorProtoType.TYPE_ENUM, # 14 +) +PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11 +PROTO_PACKED_TYPES = ( + FieldDescriptorProtoType.TYPE_DOUBLE, # 1 + FieldDescriptorProtoType.TYPE_FLOAT, # 2 + FieldDescriptorProtoType.TYPE_INT64, # 3 + FieldDescriptorProtoType.TYPE_UINT64, # 4 + FieldDescriptorProtoType.TYPE_INT32, # 5 + FieldDescriptorProtoType.TYPE_FIXED64, # 6 + FieldDescriptorProtoType.TYPE_FIXED32, # 7 + FieldDescriptorProtoType.TYPE_BOOL, # 8 + FieldDescriptorProtoType.TYPE_UINT32, # 13 + FieldDescriptorProtoType.TYPE_SFIXED32, # 15 + FieldDescriptorProtoType.TYPE_SFIXED64, # 16 + FieldDescriptorProtoType.TYPE_SINT32, # 17 + FieldDescriptorProtoType.TYPE_SINT64, # 18 +) + + +def monkey_patch_oneof_index(): + """ + The compiler message types are written for proto2, but we read them as proto3. + For this to work in the case of the oneof_index fields, which depend on being able + to tell whether they were set, we have to treat them as oneof fields. This method + monkey patches the generated classes after the fact to force this behaviour. 
+ """ + object.__setattr__( + FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[ + "aristaproto" + ], + "group", + "oneof_index", + ) + object.__setattr__( + Field.__dataclass_fields__["oneof_index"].metadata["aristaproto"], + "group", + "oneof_index", + ) + + +def get_comment( + proto_file: "FileDescriptorProto", path: List[int], indent: int = 4 +) -> str: + pad = " " * indent + for sci_loc in proto_file.source_code_info.location: + if list(sci_loc.path) == path and sci_loc.leading_comments: + lines = sci_loc.leading_comments.strip().replace("\t", " ").split("\n") + # This is a field, message, enum, service, or method + if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: + lines[0] = lines[0].strip('"') + # rstrip to remove trailing spaces including whitespaces from empty lines. + return f'{pad}"""{lines[0]}"""' + else: + # rstrip to remove trailing spaces including empty lines. + padded = [f"\n{pad}{line}".rstrip(" ") for line in lines] + joined = "".join(padded) + return f'{pad}"""{joined}\n{pad}"""' + + return "" + + +class ProtoContentBase: + """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" + + source_file: FileDescriptorProto + path: List[int] + comment_indent: int = 4 + parent: Union["aristaproto.Message", "OutputTemplate"] + + __dataclass_fields__: Dict[str, object] + + def __post_init__(self) -> None: + """Checks that no fake default fields were left as placeholders.""" + for field_name, field_val in self.__dataclass_fields__.items(): + if field_val is PLACEHOLDER: + raise ValueError(f"`{field_name}` is a required field.") + + @property + def output_file(self) -> "OutputTemplate": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current + + @property + def request(self) -> "PluginRequestCompiler": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.parent_request + + @property + def comment(self) -> str: + """Crawl the proto source code and retrieve comments + for this object. + """ + return get_comment( + proto_file=self.source_file, path=self.path, indent=self.comment_indent + ) + + +@dataclass +class PluginRequestCompiler: + plugin_request_obj: CodeGeneratorRequest + output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) + + @property + def all_messages(self) -> List["MessageCompiler"]: + """All of the messages in this request. + + Returns + ------- + List[MessageCompiler] + List of all of the messages in this request. + """ + return [ + msg for output in self.output_packages.values() for msg in output.messages + ] + + +@dataclass +class OutputTemplate: + """Representation of an output .py file. + + Each output file corresponds to a .proto input file, + but may need references to other .proto files to be + built. 
+ """ + + parent_request: PluginRequestCompiler + package_proto_obj: FileDescriptorProto + input_files: List[str] = field(default_factory=list) + imports: Set[str] = field(default_factory=set) + datetime_imports: Set[str] = field(default_factory=set) + typing_imports: Set[str] = field(default_factory=set) + pydantic_imports: Set[str] = field(default_factory=set) + builtins_import: bool = False + messages: List["MessageCompiler"] = field(default_factory=list) + enums: List["EnumDefinitionCompiler"] = field(default_factory=list) + services: List["ServiceCompiler"] = field(default_factory=list) + imports_type_checking_only: Set[str] = field(default_factory=set) + pydantic_dataclasses: bool = False + output: bool = True + + @property + def package(self) -> str: + """Name of input package. + + Returns + ------- + str + Name of input package. + """ + return self.package_proto_obj.package + + @property + def input_filenames(self) -> Iterable[str]: + """Names of the input files used to build this output. + + Returns + ------- + Iterable[str] + Names of the input files used to build this output. + """ + return sorted(f.name for f in self.input_files) + + @property + def python_module_imports(self) -> Set[str]: + imports = set() + if any(x for x in self.messages if any(x.deprecated_fields)): + imports.add("warnings") + if self.builtins_import: + imports.add("builtins") + return imports + + +@dataclass +class MessageCompiler(ProtoContentBase): + """Representation of a protobuf message.""" + + source_file: FileDescriptorProto + parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + fields: List[Union["FieldCompiler", "MessageCompiler"]] = field( + default_factory=list + ) + deprecated: bool = field(default=False, init=False) + builtins_types: Set[str] = field(default_factory=set) + + def __post_init__(self) -> None: + # Add message to output file + if isinstance(self.parent, OutputTemplate): + if isinstance(self, EnumDefinitionCompiler): + self.output_file.enums.append(self) + else: + self.output_file.messages.append(self) + self.deprecated = self.proto_obj.options.deprecated + super().__post_init__() + + @property + def proto_name(self) -> str: + return self.proto_obj.name + + @property + def py_name(self) -> str: + return pythonize_class_name(self.proto_name) + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_name}]" + return self.py_name + + @property + def deprecated_fields(self) -> Iterator[str]: + for f in self.fields: + if f.deprecated: + yield f.py_name + + @property + def has_deprecated_fields(self) -> bool: + return any(self.deprecated_fields) + + @property + def has_oneof_fields(self) -> bool: + return any(isinstance(field, OneOfFieldCompiler) for field in self.fields) + + @property + def has_message_field(self) -> bool: + return any( + ( + field.proto_obj.type in PROTO_MESSAGE_TYPES + for field in self.fields + if isinstance(field.proto_obj, FieldDescriptorProto) + ) + ) + + +def is_map( + proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto +) -> bool: + """True if proto_field_obj is a map, otherwise False.""" + if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE: + if not hasattr(parent_message, "nested_type"): + return False + + # This might be a map... 
+
+
+def is_map(
+    proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
+) -> bool:
+    """True if proto_field_obj is a map, otherwise False."""
+    if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
+        if not hasattr(parent_message, "nested_type"):
+            return False
+
+        # This might be a map...
+        message_type = proto_field_obj.type_name.split(".").pop().lower()
+        map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
+        if message_type == map_entry:
+            for nested in parent_message.nested_type:  # parent message
+                if (
+                    nested.name.replace("_", "").lower() == map_entry
+                    and nested.options.map_entry
+                ):
+                    return True
+    return False
+
+
+def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
+    """
+    True if proto_field_obj is a OneOf, otherwise False.
+
+    .. warning::
+        Because the message from protoc is defined in proto2, and aristaproto works with
+        proto3, and interpreting the FieldDescriptorProto.oneof_index field requires
+        distinguishing between default and unset values (which proto3 doesn't support),
+        we have to hack the generated FieldDescriptorProto class for this to work.
+        The hack consists of setting group="oneof_index" in the field metadata,
+        essentially making oneof_index the sole member of a one_of group, which allows
+        us to tell whether it was set, via the which_one_of interface.
+    """
+
+    return (
+        not proto_field_obj.proto3_optional
+        and which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
+    )
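A worked illustration of the naming heuristic in is_map, runnable as-is with invented names: protoc compiles `map<string, int32> my_map = 1;` into a repeated field whose nested type is named MyMapEntry with options.map_entry set, and is_map re-derives that entry name from the field name:

field_name = "my_map"
type_name = ".example.Outer.MyMapEntry"  # hypothetical fully-qualified type name

message_type = type_name.split(".").pop().lower()          # "mymapentry"
map_entry = f"{field_name.replace('_', '').lower()}entry"  # "mymapentry"
assert message_type == map_entry  # candidate confirmed against nested_type next
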
+
+
+@dataclass
+class FieldCompiler(MessageCompiler):
+    parent: MessageCompiler = PLACEHOLDER
+    proto_obj: FieldDescriptorProto = PLACEHOLDER
+
+    def __post_init__(self) -> None:
+        # Add field to message
+        self.parent.fields.append(self)
+        # Check for new imports
+        self.add_imports_to(self.output_file)
+        super().__post_init__()  # call FieldCompiler -> MessageCompiler __post_init__
+
+    def get_field_string(self, indent: int = 4) -> str:
+        """Construct string representation of this field as a field."""
+        name = f"{self.py_name}"
+        annotations = f": {self.annotation}"
+        field_args = ", ".join(
+            ([""] + self.aristaproto_field_args) if self.aristaproto_field_args else []
+        )
+        aristaproto_field_type = (
+            f"aristaproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
+        )
+        if self.py_name in dir(builtins):
+            self.parent.builtins_types.add(self.py_name)
+        return f"{name}{annotations} = {aristaproto_field_type}"
+
+    @property
+    def aristaproto_field_args(self) -> List[str]:
+        args = []
+        if self.field_wraps:
+            args.append(f"wraps={self.field_wraps}")
+        if self.optional:
+            args.append("optional=True")
+        return args
+
+    @property
+    def datetime_imports(self) -> Set[str]:
+        imports = set()
+        annotation = self.annotation
+        # FIXME: false positives - e.g. `MyDatetimedelta`
+        if "timedelta" in annotation:
+            imports.add("timedelta")
+        if "datetime" in annotation:
+            imports.add("datetime")
+        return imports
+
+    @property
+    def typing_imports(self) -> Set[str]:
+        imports = set()
+        annotation = self.annotation
+        if "Optional[" in annotation:
+            imports.add("Optional")
+        if "List[" in annotation:
+            imports.add("List")
+        if "Dict[" in annotation:
+            imports.add("Dict")
+        return imports
+
+    @property
+    def pydantic_imports(self) -> Set[str]:
+        return set()
+
+    @property
+    def use_builtins(self) -> bool:
+        return self.py_type in self.parent.builtins_types or (
+            self.py_type == self.py_name and self.py_name in dir(builtins)
+        )
+
+    def add_imports_to(self, output_file: OutputTemplate) -> None:
+        output_file.datetime_imports.update(self.datetime_imports)
+        output_file.typing_imports.update(self.typing_imports)
+        output_file.pydantic_imports.update(self.pydantic_imports)
+        output_file.builtins_import = output_file.builtins_import or self.use_builtins
+
+    @property
+    def field_wraps(self) -> Optional[str]:
+        """Returns aristaproto wrapped field type or None."""
+        match_wrapper = re.match(
+            r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
+        )
+        if match_wrapper:
+            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
+            if hasattr(aristaproto, wrapped_type):
+                return f"aristaproto.{wrapped_type}"
+        return None
+
+    @property
+    def repeated(self) -> bool:
+        return (
+            self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
+            and not is_map(self.proto_obj, self.parent)
+        )
+
+    @property
+    def optional(self) -> bool:
+        return self.proto_obj.proto3_optional
+
+    @property
+    def mutable(self) -> bool:
+        """True if the field is a mutable type, otherwise False."""
+        return self.annotation.startswith(("List[", "Dict["))
+
+    @property
+    def field_type(self) -> str:
+        """String representation of proto field type."""
+        return (
+            FieldDescriptorProtoType(self.proto_obj.type)
+            .name.lower()
+            .replace("type_", "")
+        )
+
+    @property
+    def default_value_string(self) -> str:
+        """Python representation of the default proto value."""
+        if self.repeated:
+            return "[]"
+        if self.optional:
+            return "None"
+        if self.py_type == "int":
+            return "0"
+        if self.py_type == "float":
+            return "0.0"
+        elif self.py_type == "bool":
+            return "False"
+        elif self.py_type == "str":
+            return '""'
+        elif self.py_type == "bytes":
+            return 'b""'
+        elif self.field_type == "enum":
+            enum_proto_obj_name = self.proto_obj.type_name.split(".").pop()
+            enum = next(
+                e
+                for e in self.output_file.enums
+                if e.proto_obj.name == enum_proto_obj_name
+            )
+            return enum.default_value_string
+        else:
+            # Message type
+            return "None"
+
+    @property
+    def packed(self) -> bool:
+        """True if the wire representation is a packed format."""
+        return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
+
+    @property
+    def py_name(self) -> str:
+        """Pythonized name."""
+        return pythonize_field_name(self.proto_name)
+
+    @property
+    def proto_name(self) -> str:
+        """Original protobuf name."""
+        return self.proto_obj.name
+
+    @property
+    def py_type(self) -> str:
+        """String representation of Python type."""
+        if self.proto_obj.type in PROTO_FLOAT_TYPES:
+            return "float"
+        elif self.proto_obj.type in PROTO_INT_TYPES:
+            return "int"
+        elif self.proto_obj.type in PROTO_BOOL_TYPES:
+            return "bool"
+        elif self.proto_obj.type in PROTO_STR_TYPES:
+            return "str"
+        elif self.proto_obj.type in PROTO_BYTES_TYPES:
+            return "bytes"
+        elif self.proto_obj.type in PROTO_MESSAGE_TYPES:
+            # Type referencing another defined Message or a named enum
+            return get_type_reference(
+                package=self.output_file.package,
+                imports=self.output_file.imports,
+                source_type=self.proto_obj.type_name,
+                pydantic=self.output_file.pydantic_dataclasses,
+            )
+        else:
+            raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
+
+    @property
+    def annotation(self) -> str:
+        py_type = self.py_type
+        if self.use_builtins:
+            py_type = f"builtins.{py_type}"
+        if self.repeated:
+            return f"List[{py_type}]"
+        if self.optional:
+            return f"Optional[{py_type}]"
+        return py_type
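To make get_field_string concrete, here is the line it assembles for an assumed proto3 field `optional int32 age = 2;` (the intermediate values are hand-computed for illustration, not produced by running the plugin):

py_name, annotation = "age", "Optional[int]"
field_type, number, field_args = "int32", 2, ", optional=True"
line = f"{py_name}: {annotation} = aristaproto.{field_type}_field({number}{field_args})"
assert line == "age: Optional[int] = aristaproto.int32_field(2, optional=True)"
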
+
+
+@dataclass
+class OneOfFieldCompiler(FieldCompiler):
+    @property
+    def aristaproto_field_args(self) -> List[str]:
+        args = super().aristaproto_field_args
+        group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
+        args.append(f'group="{group}"')
+        return args
+
+
+@dataclass
+class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
+    @property
+    def optional(self) -> bool:
+        # Force the optional to be True. This will allow the pydantic dataclass
+        # to validate the object correctly by allowing the field to be left empty.
+        # We add a pydantic validator later to ensure exactly one field is defined.
+        return True
+
+    @property
+    def pydantic_imports(self) -> Set[str]:
+        return {"root_validator"}
+
+
+@dataclass
+class MapEntryCompiler(FieldCompiler):
+    py_k_type: Type = PLACEHOLDER
+    py_v_type: Type = PLACEHOLDER
+    proto_k_type: str = PLACEHOLDER
+    proto_v_type: str = PLACEHOLDER
+
+    def __post_init__(self) -> None:
+        """Explore nested types and set k_type and v_type if unset."""
+        map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
+        for nested in self.parent.proto_obj.nested_type:
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                # Get Python types
+                self.py_k_type = FieldCompiler(
+                    source_file=self.source_file,
+                    parent=self,
+                    proto_obj=nested.field[0],  # key
+                ).py_type
+                self.py_v_type = FieldCompiler(
+                    source_file=self.source_file,
+                    parent=self,
+                    proto_obj=nested.field[1],  # value
+                ).py_type
+
+                # Get proto types
+                self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
+                self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
+        super().__post_init__()  # call FieldCompiler -> MessageCompiler __post_init__
+
+    @property
+    def aristaproto_field_args(self) -> List[str]:
+        return [f"aristaproto.{self.proto_k_type}", f"aristaproto.{self.proto_v_type}"]
+
+    @property
+    def field_type(self) -> str:
+        return "map"
+
+    @property
+    def annotation(self) -> str:
+        return f"Dict[{self.py_k_type}, {self.py_v_type}]"
+
+    @property
+    def repeated(self) -> bool:
+        return False  # maps cannot be repeated
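The map counterpart, for an assumed `map<string, int32> tags = 3;`: MapEntryCompiler contributes field_type "map", annotation "Dict[str, int]", and the key/value wire types as extra arguments, so the same get_field_string machinery yields:

# proto_k_type/proto_v_type come from FieldDescriptorProtoType(...).name
proto_k_type, proto_v_type = "TYPE_STRING", "TYPE_INT32"
args = ", ".join(["", f"aristaproto.{proto_k_type}", f"aristaproto.{proto_v_type}"])
line = f"tags: Dict[str, int] = aristaproto.map_field(3{args})"
assert line == (
    "tags: Dict[str, int] = aristaproto.map_field("
    "3, aristaproto.TYPE_STRING, aristaproto.TYPE_INT32)"
)
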
+
+
+@dataclass
+class EnumDefinitionCompiler(MessageCompiler):
+    """Representation of a proto Enum definition."""
+
+    proto_obj: EnumDescriptorProto = PLACEHOLDER
+    entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
+
+    @dataclass(unsafe_hash=True)
+    class EnumEntry:
+        """Representation of an Enum entry."""
+
+        name: str
+        value: int
+        comment: str
+
+    def __post_init__(self) -> None:
+        # Get entries/allowed values for this Enum
+        self.entries = [
+            self.EnumEntry(
+                name=pythonize_enum_member_name(
+                    entry_proto_value.name, self.proto_obj.name
+                ),
+                value=entry_proto_value.number,
+                comment=get_comment(
+                    proto_file=self.source_file, path=self.path + [2, entry_number]
+                ),
+            )
+            for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
+        ]
+        super().__post_init__()  # call MessageCompiler __post_init__
+
+    @property
+    def default_value_string(self) -> str:
+        """Python representation of the default value for Enums.
+
+        As per the spec, this is the first value of the Enum.
+        """
+        return str(self.entries[0].value)  # ideally, should ALWAYS be int(0)!
+
+
+@dataclass
+class ServiceCompiler(ProtoContentBase):
+    parent: OutputTemplate = PLACEHOLDER
+    proto_obj: DescriptorProto = PLACEHOLDER
+    path: List[int] = PLACEHOLDER
+    methods: List["ServiceMethodCompiler"] = field(default_factory=list)
+
+    def __post_init__(self) -> None:
+        # Add service to output file
+        self.output_file.services.append(self)
+        self.output_file.typing_imports.add("Dict")
+        super().__post_init__()  # check for unset fields
+
+    @property
+    def proto_name(self) -> str:
+        return self.proto_obj.name
+
+    @property
+    def py_name(self) -> str:
+        return pythonize_class_name(self.proto_name)
+
+
+@dataclass
+class ServiceMethodCompiler(ProtoContentBase):
+    parent: ServiceCompiler
+    proto_obj: MethodDescriptorProto
+    path: List[int] = PLACEHOLDER
+    comment_indent: int = 8
+
+    def __post_init__(self) -> None:
+        # Add method to service
+        self.parent.methods.append(self)
+
+        # Check for imports
+        if "Optional" in self.py_output_message_type:
+            self.output_file.typing_imports.add("Optional")
+
+        # Check for Async imports
+        if self.client_streaming:
+            self.output_file.typing_imports.add("AsyncIterable")
+            self.output_file.typing_imports.add("Iterable")
+            self.output_file.typing_imports.add("Union")
+
+        # Required by both client and server
+        if self.client_streaming or self.server_streaming:
+            self.output_file.typing_imports.add("AsyncIterator")
+
+        # add imports required for request arguments timeout, deadline and metadata
+        self.output_file.typing_imports.add("Optional")
+        self.output_file.imports_type_checking_only.add("import grpclib.server")
+        self.output_file.imports_type_checking_only.add(
+            "from aristaproto.grpc.grpclib_client import MetadataLike"
+        )
+        self.output_file.imports_type_checking_only.add(
+            "from grpclib.metadata import Deadline"
+        )
+
+        super().__post_init__()  # check for unset fields
+
+    @property
+    def py_name(self) -> str:
+        """Pythonized method name."""
+        return pythonize_method_name(self.proto_obj.name)
+
+    @property
+    def proto_name(self) -> str:
+        """Original protobuf name."""
+        return self.proto_obj.name
+
+    @property
+    def route(self) -> str:
+        package_part = (
+            f"{self.output_file.package}." if self.output_file.package else ""
+        )
+        return f"/{package_part}{self.parent.proto_name}/{self.proto_name}"
+
+    @property
+    def py_input_message(self) -> Optional[MessageCompiler]:
+        """Find the input message object.
+
+        Returns
+        -------
+        Optional[MessageCompiler]
+            MessageCompiler instance representing the input message.
+            If no input message could be found or there are no
+            input messages, None is returned.
+        """
+        package, name = parse_source_type_name(self.proto_obj.input_type)
+
+        # Nested types are currently flattened without dots.
+        # Todo: keep a fully qualified name in types, that is
+        # comparable with method.input_type
+        for msg in self.request.all_messages:
+            if (
+                msg.py_name == pythonize_class_name(name.replace(".", ""))
+                and msg.output_file.package == package
+            ):
+                return msg
+        return None
+
+    @property
+    def py_input_message_type(self) -> str:
+        """String representation of the Python type corresponding to the
+        input message.
+
+        Returns
+        -------
+        str
+            String representation of the Python type corresponding to the input message.
+        """
+        return get_type_reference(
+            package=self.output_file.package,
+            imports=self.output_file.imports,
+            source_type=self.proto_obj.input_type,
+            unwrap=False,
+            pydantic=self.output_file.pydantic_dataclasses,
+        ).strip('"')
+
+    @property
+    def py_input_message_param(self) -> str:
+        """Param name corresponding to py_input_message_type.
+
+        Returns
+        -------
+        str
+            Param name corresponding to py_input_message_type.
+        """
+        return pythonize_field_name(self.py_input_message_type)
+
+    @property
+    def py_output_message_type(self) -> str:
+        """String representation of the Python type corresponding to the
+        output message.
+
+        Returns
+        -------
+        str
+            String representation of the Python type corresponding to the output message.
+        """
+        return get_type_reference(
+            package=self.output_file.package,
+            imports=self.output_file.imports,
+            source_type=self.proto_obj.output_type,
+            unwrap=False,
+            pydantic=self.output_file.pydantic_dataclasses,
+        ).strip('"')
+
+    @property
+    def client_streaming(self) -> bool:
+        return self.proto_obj.client_streaming
+
+    @property
+    def server_streaming(self) -> bool:
+        return self.proto_obj.server_streaming
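The route property yields the canonical gRPC path; a quick check of the format with illustrative names:

package, service, method = "example.v1", "Greeter", "SayHello"
package_part = f"{package}." if package else ""
route = f"/{package_part}{service}/{method}"
assert route == "/example.v1.Greeter/SayHello"
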
+ """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.input_type, + unwrap=False, + pydantic=self.output_file.pydantic_dataclasses, + ).strip('"') + + @property + def py_input_message_param(self) -> str: + """Param name corresponding to py_input_message_type. + + Returns + ------- + str + Param name corresponding to py_input_message_type. + """ + return pythonize_field_name(self.py_input_message_type) + + @property + def py_output_message_type(self) -> str: + """String representation of the Python type corresponding to the + output message. + + Returns + ------- + str + String representation of the Python type corresponding to the output message. + """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.output_type, + unwrap=False, + pydantic=self.output_file.pydantic_dataclasses, + ).strip('"') + + @property + def client_streaming(self) -> bool: + return self.proto_obj.client_streaming + + @property + def server_streaming(self) -> bool: + return self.proto_obj.server_streaming diff --git a/src/aristaproto/plugin/parser.py b/src/aristaproto/plugin/parser.py new file mode 100644 index 0000000..f761af6 --- /dev/null +++ b/src/aristaproto/plugin/parser.py @@ -0,0 +1,221 @@ +import pathlib +import sys +from typing import ( + Generator, + List, + Set, + Tuple, + Union, +) + +from aristaproto.lib.google.protobuf import ( + DescriptorProto, + EnumDescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, + ServiceDescriptorProto, +) +from aristaproto.lib.google.protobuf.compiler import ( + CodeGeneratorRequest, + CodeGeneratorResponse, + CodeGeneratorResponseFeature, + CodeGeneratorResponseFile, +) + +from .compiler import outputfile_compiler +from .models import ( + EnumDefinitionCompiler, + FieldCompiler, + MapEntryCompiler, + MessageCompiler, + OneOfFieldCompiler, + OutputTemplate, + PluginRequestCompiler, + PydanticOneOfFieldCompiler, + ServiceCompiler, + ServiceMethodCompiler, + is_map, + is_oneof, +) + + +def traverse( + proto_file: FileDescriptorProto, +) -> Generator[ + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None +]: + # Todo: Keep information about nested hierarchy + def _traverse( + path: List[int], + items: Union[List[EnumDescriptorProto], List[DescriptorProto]], + prefix: str = "", + ) -> Generator[ + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None + ]: + for i, item in enumerate(items): + # Adjust the name since we flatten the hierarchy. + # Todo: don't change the name, but include full name in returned tuple + item.name = next_prefix = f"{prefix}_{item.name}" + yield item, [*path, i] + + if isinstance(item, DescriptorProto): + # Get nested types. 
+
+
+def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
+    response = CodeGeneratorResponse()
+
+    plugin_options = request.parameter.split(",") if request.parameter else []
+    response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
+
+    request_data = PluginRequestCompiler(plugin_request_obj=request)
+    # Gather output packages
+    for proto_file in request.proto_file:
+        output_package_name = proto_file.package
+        if output_package_name not in request_data.output_packages:
+            # Create a new output if there is no output for this package
+            request_data.output_packages[output_package_name] = OutputTemplate(
+                parent_request=request_data, package_proto_obj=proto_file
+            )
+        # Add this input file to the output corresponding to this package
+        request_data.output_packages[output_package_name].input_files.append(proto_file)
+
+        if (
+            proto_file.package == "google.protobuf"
+            and "INCLUDE_GOOGLE" not in plugin_options
+        ):
+            # If not INCLUDE_GOOGLE,
+            # skip outputting Google's well-known types
+            request_data.output_packages[output_package_name].output = False
+
+        if "pydantic_dataclasses" in plugin_options:
+            request_data.output_packages[
+                output_package_name
+            ].pydantic_dataclasses = True
+
+    # Read Messages and Enums
+    # We need to read Messages before Services so that we can
+    # get the references to input/output messages for each service
+    for output_package_name, output_package in request_data.output_packages.items():
+        for proto_input_file in output_package.input_files:
+            for item, path in traverse(proto_input_file):
+                read_protobuf_type(
+                    source_file=proto_input_file,
+                    item=item,
+                    path=path,
+                    output_package=output_package,
+                )
+
+    # Read Services
+    for output_package_name, output_package in request_data.output_packages.items():
+        for proto_input_file in output_package.input_files:
+            for index, service in enumerate(proto_input_file.service):
+                read_protobuf_service(service, index, output_package)
+
+    # Generate output files
+    output_paths: Set[pathlib.Path] = set()
+    for output_package_name, output_package in request_data.output_packages.items():
+        if not output_package.output:
+            continue
+
+        # Add files to the response object
+        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
+        output_paths.add(output_path)
+
+        response.file.append(
+            CodeGeneratorResponseFile(
+                name=str(output_path),
+                # Render and then format the output file
+                content=outputfile_compiler(output_file=output_package),
+            )
+        )
+
+    # Make each output directory a package with __init__ file
+    init_files = {
+        directory.joinpath("__init__.py")
+        for path in output_paths
+        for directory in path.parents
+        if not directory.joinpath("__init__.py").exists()
+    } - output_paths
+
+    for init_file in init_files:
+        response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
+
+    for output_package_name in sorted(output_paths.union(init_files)):
+        print(f"Writing {output_package_name}", file=sys.stderr)
+
+    return response
+
+
+def _make_one_of_field_compiler(
+    output_package: OutputTemplate,
+    source_file: "FileDescriptorProto",
+    parent: MessageCompiler,
+    proto_obj: "FieldDescriptorProto",
+    path: List[int],
+) -> FieldCompiler:
+    pydantic = output_package.pydantic_dataclasses
+    Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
+    return Cls(
+        source_file=source_file,
+        parent=parent,
+        proto_obj=proto_obj,
+        path=path,
+    )
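The __init__.py bookkeeping near the end of generate_code can be illustrated standalone; this sketch drops the exists() check the real code applies (it also skips directories that already hold an __init__.py on disk) so the result is deterministic:

import pathlib

output_paths = {pathlib.Path("example", "v1", "__init__.py")}
# Every parent directory of an emitted module gets an empty __init__.py,
# except paths that already correspond to generated outputs.
init_files = {
    directory.joinpath("__init__.py")
    for path in output_paths
    for directory in path.parents
} - output_paths
print(sorted(str(p) for p in init_files))
# ['__init__.py', 'example/__init__.py']
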
+
+
+def read_protobuf_type(
+    item: DescriptorProto,
+    path: List[int],
+    source_file: "FileDescriptorProto",
+    output_package: OutputTemplate,
+) -> None:
+    if isinstance(item, DescriptorProto):
+        if item.options.map_entry:
+            # Skip generated map entry messages since we just use dicts
+            return
+        # Process Message
+        message_data = MessageCompiler(
+            source_file=source_file, parent=output_package, proto_obj=item, path=path
+        )
+        for index, field in enumerate(item.field):
+            if is_map(field, item):
+                MapEntryCompiler(
+                    source_file=source_file,
+                    parent=message_data,
+                    proto_obj=field,
+                    path=path + [2, index],
+                )
+            elif is_oneof(field):
+                _make_one_of_field_compiler(
+                    output_package, source_file, message_data, field, path + [2, index]
+                )
+            else:
+                FieldCompiler(
+                    source_file=source_file,
+                    parent=message_data,
+                    proto_obj=field,
+                    path=path + [2, index],
+                )
+    elif isinstance(item, EnumDescriptorProto):
+        # Enum
+        EnumDefinitionCompiler(
+            source_file=source_file, parent=output_package, proto_obj=item, path=path
+        )
+
+
+def read_protobuf_service(
+    service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
+) -> None:
+    service_data = ServiceCompiler(
+        parent=output_package, proto_obj=service, path=[6, index]
+    )
+    for j, method in enumerate(service.method):
+        ServiceMethodCompiler(
+            parent=service_data, proto_obj=method, path=[6, index, 2, j]
+        )
diff --git a/src/aristaproto/plugin/plugin.bat b/src/aristaproto/plugin/plugin.bat
new file mode 100644
index 0000000..2a4444d
--- /dev/null
+++ b/src/aristaproto/plugin/plugin.bat
@@ -0,0 +1,2 @@
+@SET plugin_dir=%~dp0
+@python -m %plugin_dir% %*
\ No newline at end of file