diff options
Diffstat (limited to 'src/aristaproto/compile')
-rw-r--r-- | src/aristaproto/compile/__init__.py | 0 | ||||
-rw-r--r-- | src/aristaproto/compile/importing.py | 176 | ||||
-rw-r--r-- | src/aristaproto/compile/naming.py | 21 |
3 files changed, 197 insertions, 0 deletions
diff --git a/src/aristaproto/compile/__init__.py b/src/aristaproto/compile/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/aristaproto/compile/__init__.py diff --git a/src/aristaproto/compile/importing.py b/src/aristaproto/compile/importing.py new file mode 100644 index 0000000..8486ddd --- /dev/null +++ b/src/aristaproto/compile/importing.py @@ -0,0 +1,176 @@ +import os +import re +from typing import ( + Dict, + List, + Set, + Tuple, + Type, +) + +from ..casing import safe_snake_case +from ..lib.google import protobuf as google_protobuf +from .naming import pythonize_class_name + + +WRAPPER_TYPES: Dict[str, Type] = { + ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, + ".google.protobuf.FloatValue": google_protobuf.FloatValue, + ".google.protobuf.Int32Value": google_protobuf.Int32Value, + ".google.protobuf.Int64Value": google_protobuf.Int64Value, + ".google.protobuf.UInt32Value": google_protobuf.UInt32Value, + ".google.protobuf.UInt64Value": google_protobuf.UInt64Value, + ".google.protobuf.BoolValue": google_protobuf.BoolValue, + ".google.protobuf.StringValue": google_protobuf.StringValue, + ".google.protobuf.BytesValue": google_protobuf.BytesValue, +} + + +def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: + """ + Split full source type name into package and type name. + E.g. 'root.package.Message' -> ('root.package', 'Message') + 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') + """ + package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name) + if package_match: + package = package_match.group(1) + name = package_match.group(2) + else: + package = "" + name = field_type_name.lstrip(".") + return package, name + + +def get_type_reference( + *, + package: str, + imports: set, + source_type: str, + unwrap: bool = True, + pydantic: bool = False, +) -> str: + """ + Return a Python type name for a proto type reference. Adds the import if + necessary. Unwraps well known type if required. + """ + if unwrap: + if source_type in WRAPPER_TYPES: + wrapped_type = type(WRAPPER_TYPES[source_type]().value) + return f"Optional[{wrapped_type.__name__}]" + + if source_type == ".google.protobuf.Duration": + return "timedelta" + + elif source_type == ".google.protobuf.Timestamp": + return "datetime" + + source_package, source_type = parse_source_type_name(source_type) + + current_package: List[str] = package.split(".") if package else [] + py_package: List[str] = source_package.split(".") if source_package else [] + py_type: str = pythonize_class_name(source_type) + + compiling_google_protobuf = current_package == ["google", "protobuf"] + importing_google_protobuf = py_package == ["google", "protobuf"] + if importing_google_protobuf and not compiling_google_protobuf: + py_package = ( + ["aristaproto", "lib"] + (["pydantic"] if pydantic else []) + py_package + ) + + if py_package[:1] == ["aristaproto"]: + return reference_absolute(imports, py_package, py_type) + + if py_package == current_package: + return reference_sibling(py_type) + + if py_package[: len(current_package)] == current_package: + return reference_descendent(current_package, imports, py_package, py_type) + + if current_package[: len(py_package)] == py_package: + return reference_ancestor(current_package, imports, py_package, py_type) + + return reference_cousin(current_package, imports, py_package, py_type) + + +def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str: + """ + Returns a reference to a python type located in the root, i.e. sys.path. + """ + string_import = ".".join(py_package) + string_alias = safe_snake_case(string_import) + imports.add(f"import {string_import} as {string_alias}") + return f'"{string_alias}.{py_type}"' + + +def reference_sibling(py_type: str) -> str: + """ + Returns a reference to a python type within the same package as the current package. + """ + return f'"{py_type}"' + + +def reference_descendent( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package that is a descendent of the + current package, and adds the required import that is aliased to avoid name + conflicts. + """ + importing_descendent = py_package[len(current_package) :] + string_from = ".".join(importing_descendent[:-1]) + string_import = importing_descendent[-1] + if string_from: + string_alias = "_".join(importing_descendent) + imports.add(f"from .{string_from} import {string_import} as {string_alias}") + return f'"{string_alias}.{py_type}"' + else: + imports.add(f"from . import {string_import}") + return f'"{string_import}.{py_type}"' + + +def reference_ancestor( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package which is an ancestor to the + current package, and adds the required import that is aliased (if possible) to avoid + name conflicts. + + Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34). + """ + distance_up = len(current_package) - len(py_package) + if py_package: + string_import = py_package[-1] + string_alias = f"_{'_' * distance_up}{string_import}__" + string_from = f"..{'.' * distance_up}" + imports.add(f"from {string_from} import {string_import} as {string_alias}") + return f'"{string_alias}.{py_type}"' + else: + string_alias = f"{'_' * distance_up}{py_type}__" + imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") + return f'"{string_alias}"' + + +def reference_cousin( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package that is not descendent, ancestor + or sibling, and adds the required import that is aliased to avoid name conflicts. + """ + shared_ancestry = os.path.commonprefix([current_package, py_package]) # type: ignore + distance_up = len(current_package) - len(shared_ancestry) + string_from = f".{'.' * distance_up}" + ".".join( + py_package[len(shared_ancestry) : -1] + ) + string_import = py_package[-1] + # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34) + string_alias = ( + f"{'_' * distance_up}" + + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + + "__" + ) + imports.add(f"from {string_from} import {string_import} as {string_alias}") + return f'"{string_alias}.{py_type}"' diff --git a/src/aristaproto/compile/naming.py b/src/aristaproto/compile/naming.py new file mode 100644 index 0000000..0c45dde --- /dev/null +++ b/src/aristaproto/compile/naming.py @@ -0,0 +1,21 @@ +from aristaproto import casing + + +def pythonize_class_name(name: str) -> str: + return casing.pascal_case(name) + + +def pythonize_field_name(name: str) -> str: + return casing.safe_snake_case(name) + + +def pythonize_method_name(name: str) -> str: + return casing.safe_snake_case(name) + + +def pythonize_enum_member_name(name: str, enum_name: str) -> str: + enum_name = casing.snake_case(enum_name).upper() + find = name.find(enum_name) + if find != -1: + name = name[find + len(enum_name) :].strip("_") + return casing.sanitize_name(name) |