summaryrefslogtreecommitdiffstats
path: root/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py71
1 files changed, 64 insertions, 7 deletions
diff --git a/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py b/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py
index 84d089baf9..d07dd1c44a 100644
--- a/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py
+++ b/third_party/rust/uniffi_bindgen/src/bindings/python/templates/EnumTemplate.py
@@ -7,31 +7,59 @@
{% if e.is_flat() %}
class {{ type_name }}(enum.Enum):
- {% for variant in e.variants() -%}
- {{ variant.name()|enum_variant_py }} = {{ loop.index }}
+ {%- call py::docstring(e, 4) %}
+ {%- for variant in e.variants() %}
+ {{ variant.name()|enum_variant_py }} = {{ e|variant_discr_literal(loop.index0) }}
+ {%- call py::docstring(variant, 4) %}
{% endfor %}
{% else %}
class {{ type_name }}:
+ {%- call py::docstring(e, 4) %}
def __init__(self):
raise RuntimeError("{{ type_name }} cannot be instantiated directly")
# Each enum variant is a nested class of the enum itself.
{% for variant in e.variants() -%}
class {{ variant.name()|enum_variant_py }}:
- {% for field in variant.fields() %}
- {{- field.name()|var_name }}: "{{- field|type_name }}";
+ {%- call py::docstring(variant, 8) %}
+
+ {%- if variant.has_nameless_fields() %}
+ def __init__(self, *values):
+ if len(values) != {{ variant.fields().len() }}:
+ raise TypeError(f"Expected a tuple of len {{ variant.fields().len() }}, found len {len(values)}")
+ {%- for field in variant.fields() %}
+ if not isinstance(values[{{ loop.index0 }}], {{ field|type_name }}):
+ raise TypeError(f"unexpected type for tuple element {{ loop.index0 }} - expected '{{ field|type_name }}', got '{type(values[{{ loop.index0 }}])}'")
+ {%- endfor %}
+ self._values = values
+
+ def __getitem__(self, index):
+ return self._values[index]
+
+ def __str__(self):
+ return f"{{ type_name }}.{{ variant.name()|enum_variant_py }}{self._values!r}"
+
+ def __eq__(self, other):
+ if not other.is_{{ variant.name()|var_name }}():
+ return False
+ return self._values == other._values
+
+ {%- else -%}
+ {%- for field in variant.fields() %}
+ {{ field.name()|var_name }}: "{{ field|type_name }}"
+ {%- call py::docstring(field, 8) %}
{%- endfor %}
@typing.no_type_check
def __init__(self,{% for field in variant.fields() %}{{ field.name()|var_name }}: "{{- field|type_name }}"{% if loop.last %}{% else %}, {% endif %}{% endfor %}):
- {% if variant.has_fields() %}
+ {%- if variant.has_fields() %}
{%- for field in variant.fields() %}
self.{{ field.name()|var_name }} = {{ field.name()|var_name }}
{%- endfor %}
- {% else %}
+ {%- else %}
pass
- {% endif %}
+ {%- endif %}
def __str__(self):
return "{{ type_name }}.{{ variant.name()|enum_variant_py }}({% for field in variant.fields() %}{{ field.name()|var_name }}={}{% if loop.last %}{% else %}, {% endif %}{% endfor %})".format({% for field in variant.fields() %}self.{{ field.name()|var_name }}{% if loop.last %}{% else %}, {% endif %}{% endfor %})
@@ -44,6 +72,7 @@ class {{ type_name }}:
return False
{%- endfor %}
return True
+ {% endif %}
{% endfor %}
# For each variant, we have an `is_NAME` method for easily checking
@@ -81,6 +110,30 @@ class {{ ffi_converter_name }}(_UniffiConverterRustBuffer):
{%- endfor %}
raise InternalError("Raw enum value doesn't match any cases")
+ @staticmethod
+ def check_lower(value):
+ {%- if e.variants().is_empty() %}
+ pass
+ {%- else %}
+ {%- for variant in e.variants() %}
+ {%- if e.is_flat() %}
+ if value == {{ type_name }}.{{ variant.name()|enum_variant_py }}:
+ {%- else %}
+ if value.is_{{ variant.name()|var_name }}():
+ {%- endif %}
+ {%- for field in variant.fields() %}
+ {%- if variant.has_nameless_fields() %}
+ {{ field|check_lower_fn }}(value._values[{{ loop.index0 }}])
+ {%- else %}
+ {{ field|check_lower_fn }}(value.{{ field.name()|var_name }})
+ {%- endif %}
+ {%- endfor %}
+ return
+ {%- endfor %}
+ raise ValueError(value)
+ {%- endif %}
+
+ @staticmethod
def write(value, buf):
{%- for variant in e.variants() %}
{%- if e.is_flat() %}
@@ -90,7 +143,11 @@ class {{ ffi_converter_name }}(_UniffiConverterRustBuffer):
if value.is_{{ variant.name()|var_name }}():
buf.write_i32({{ loop.index }})
{%- for field in variant.fields() %}
+ {%- if variant.has_nameless_fields() %}
+ {{ field|write_fn }}(value._values[{{ loop.index0 }}], buf)
+ {%- else %}
{{ field|write_fn }}(value.{{ field.name()|var_name }}, buf)
+ {%- endif %}
{%- endfor %}
{%- endif %}
{%- endfor %}