diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
commit | 42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch) | |
tree | 5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/dialects/presto.py | |
parent | Releasing debian version 21.1.2-1. (diff) | |
download | sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip |
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r-- | sqlglot/dialects/presto.py | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8429547..3649bd2 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -453,11 +453,32 @@ class Presto(Dialect): return super().bracket_sql(expression) def struct_sql(self, expression: exp.Struct) -> str: - if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions): - self.unsupported("Struct with key-value definitions is unsupported.") - return self.function_fallback_sql(expression) + from sqlglot.optimizer.annotate_types import annotate_types + + expression = annotate_types(expression) + values: t.List[str] = [] + schema: t.List[str] = [] + unknown_type = False + + for e in expression.expressions: + if isinstance(e, exp.PropertyEQ): + if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN): + unknown_type = True + else: + schema.append(f"{self.sql(e, 'this')} {self.sql(e.type)}") + values.append(self.sql(e, "expression")) + else: + values.append(self.sql(e)) + + size = len(expression.expressions) - return rename_func("ROW")(self, expression) + if not size or len(schema) != size: + if unknown_type: + self.unsupported( + "Cannot convert untyped key-value definitions (try annotate_types)." + ) + return self.func("ROW", *values) + return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))" def interval_sql(self, expression: exp.Interval) -> str: unit = self.sql(expression, "unit") |