diff options
Diffstat (limited to '')
-rw-r--r-- | third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py | 71 |
1 files changed, 64 insertions, 7 deletions
diff --git a/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py b/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py index 84d089baf9..d07dd1c44a 100644 --- a/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py +++ b/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py @@ -7,31 +7,59 @@ {% if e.is_flat() %} class {{ type_name }}(enum.Enum): - {% for variant in e.variants() -%} - {{ variant.name()|enum_variant_py }} = {{ loop.index }} + {%- call py::docstring(e, 4) %} + {%- for variant in e.variants() %} + {{ variant.name()|enum_variant_py }} = {{ e|variant_discr_literal(loop.index0) }} + {%- call py::docstring(variant, 4) %} {% endfor %} {% else %} class {{ type_name }}: + {%- call py::docstring(e, 4) %} def __init__(self): raise RuntimeError("{{ type_name }} cannot be instantiated directly") # Each enum variant is a nested class of the enum itself. {% for variant in e.variants() -%} class {{ variant.name()|enum_variant_py }}: - {% for field in variant.fields() %} - {{- field.name()|var_name }}: "{{- field|type_name }}"; + {%- call py::docstring(variant, 8) %} + + {%- if variant.has_nameless_fields() %} + def __init__(self, *values): + if len(values) != {{ variant.fields().len() }}: + raise TypeError(f"Expected a tuple of len {{ variant.fields().len() }}, found len {len(values)}") + {%- for field in variant.fields() %} + if not isinstance(values[{{ loop.index0 }}], {{ field|type_name }}): + raise TypeError(f"unexpected type for tuple element {{ loop.index0 }} - expected '{{ field|type_name }}', got '{type(values[{{ loop.index0 }}])}'") + {%- endfor %} + self._values = values + + def __getitem__(self, index): + return self._values[index] + + def __str__(self): + return f"{{ type_name }}.{{ variant.name()|enum_variant_py }}{self._values!r}" + + def __eq__(self, other): + if not other.is_{{ variant.name()|var_name }}(): + return False + return self._values == other._values + + {%- else -%} + {%- for field in variant.fields() %} + {{ field.name()|var_name }}: "{{ field|type_name }}" + {%- call py::docstring(field, 8) %} {%- endfor %} @typing.no_type_check def __init__(self,{% for field in variant.fields() %}{{ field.name()|var_name }}: "{{- field|type_name }}"{% if loop.last %}{% else %}, {% endif %}{% endfor %}): - {% if variant.has_fields() %} + {%- if variant.has_fields() %} {%- for field in variant.fields() %} self.{{ field.name()|var_name }} = {{ field.name()|var_name }} {%- endfor %} - {% else %} + {%- else %} pass - {% endif %} + {%- endif %} def __str__(self): return "{{ type_name }}.{{ variant.name()|enum_variant_py }}({% for field in variant.fields() %}{{ field.name()|var_name }}={}{% if loop.last %}{% else %}, {% endif %}{% endfor %})".format({% for field in variant.fields() %}self.{{ field.name()|var_name }}{% if loop.last %}{% else %}, {% endif %}{% endfor %}) @@ -44,6 +72,7 @@ class {{ type_name }}: return False {%- endfor %} return True + {% endif %} {% endfor %} # For each variant, we have an `is_NAME` method for easily checking @@ -81,6 +110,30 @@ class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): {%- endfor %} raise InternalError("Raw enum value doesn't match any cases") + @staticmethod + def check_lower(value): + {%- if e.variants().is_empty() %} + pass + {%- else %} + {%- for variant in e.variants() %} + {%- if e.is_flat() %} + if value == {{ type_name }}.{{ variant.name()|enum_variant_py }}: + {%- else %} + if value.is_{{ variant.name()|var_name }}(): + {%- endif %} + {%- for field in variant.fields() %} + {%- if variant.has_nameless_fields() %} + {{ field|check_lower_fn }}(value._values[{{ loop.index0 }}]) + {%- else %} + {{ field|check_lower_fn }}(value.{{ field.name()|var_name }}) + {%- endif %} + {%- endfor %} + return + {%- endfor %} + raise ValueError(value) + {%- endif %} + + @staticmethod def write(value, buf): {%- for variant in e.variants() %} {%- if e.is_flat() %} @@ -90,7 +143,11 @@ class {{ ffi_converter_name }}(_UniffiConverterRustBuffer): if value.is_{{ variant.name()|var_name }}(): buf.write_i32({{ loop.index }}) {%- for field in variant.fields() %} + {%- if variant.has_nameless_fields() %} + {{ field|write_fn }}(value._values[{{ loop.index0 }}], buf) + {%- else %} {{ field|write_fn }}(value.{{ field.name()|var_name }}, buf) + {%- endif %} {%- endfor %} {%- endif %} {%- endfor %} |