summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/presto.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
commit42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch)
tree5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/dialects/presto.py
parentReleasing debian version 21.1.2-1. (diff)
downloadsqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz
sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/presto.py')
-rw-r--r--sqlglot/dialects/presto.py29
1 files changed, 25 insertions, 4 deletions
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 8429547..3649bd2 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -453,11 +453,32 @@ class Presto(Dialect):
return super().bracket_sql(expression)
def struct_sql(self, expression: exp.Struct) -> str:
- if any(isinstance(arg, self.KEY_VALUE_DEFINITIONS) for arg in expression.expressions):
- self.unsupported("Struct with key-value definitions is unsupported.")
- return self.function_fallback_sql(expression)
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ expression = annotate_types(expression)
+ values: t.List[str] = []
+ schema: t.List[str] = []
+ unknown_type = False
+
+ for e in expression.expressions:
+ if isinstance(e, exp.PropertyEQ):
+ if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN):
+ unknown_type = True
+ else:
+ schema.append(f"{self.sql(e, 'this')} {self.sql(e.type)}")
+ values.append(self.sql(e, "expression"))
+ else:
+ values.append(self.sql(e))
+
+ size = len(expression.expressions)
- return rename_func("ROW")(self, expression)
+ if not size or len(schema) != size:
+ if unknown_type:
+ self.unsupported(
+ "Cannot convert untyped key-value definitions (try annotate_types)."
+ )
+ return self.func("ROW", *values)
+ return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))"
def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")