diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 05:02:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-15 05:02:18 +0000 |
commit | 41f1f5740d2140bfd3b2a282ca1087a4b576679a (patch) | |
tree | 0b1eb5ba5c759d08b05d56e50675784b6170f955 /sqlglot/dialects/duckdb.py | |
parent | Releasing debian version 23.7.0-1. (diff) | |
download | sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.tar.xz sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.zip |
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/duckdb.py')
-rw-r--r-- | sqlglot/dialects/duckdb.py | 49 |
1 files changed, 34 insertions, 15 deletions
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6a1d07a..6486dda 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, encode_decode_sql, build_formatted_time, - inline_array_sql, + inline_array_unless_query, no_comment_column_constraint_sql, no_safe_divide_sql, no_timestamp_sql, @@ -312,6 +312,15 @@ class DuckDB(Dialect): ), } + def _parse_bracket( + self, this: t.Optional[exp.Expression] = None + ) -> t.Optional[exp.Expression]: + bracket = super()._parse_bracket(this) + if isinstance(bracket, exp.Bracket): + bracket.set("returns_list_for_maps", True) + + return bracket + def _parse_map(self) -> exp.ToMap | exp.Map: if self._match(TokenType.L_BRACE, advance=False): return self.expression(exp.ToMap, this=self._parse_bracket()) @@ -370,11 +379,7 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: ( - self.func("ARRAY", e.expressions[0]) - if e.expressions and e.expressions[0].find(exp.Select) - else inline_array_sql(self, e) - ), + exp.Array: inline_array_unless_query, exp.ArrayFilter: rename_func("LIST_FILTER"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), @@ -416,8 +421,8 @@ class DuckDB(Dialect): exp.MonthsBetween: lambda self, e: self.func( "DATEDIFF", "'month'", - exp.cast(e.expression, "timestamp", copy=True), - exp.cast(e.this, "timestamp", copy=True), + exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True), + exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True), ), exp.ParseJSON: rename_func("JSON"), exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"), @@ -452,9 +457,11 @@ class DuckDB(Dialect): "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this ), exp.TimestampTrunc: timestamptrunc_sql, - exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")), + exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)), exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")), + exp.TimeStrToUnix: lambda self, e: self.func( + "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP) + ), exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)), exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, @@ -463,8 +470,8 @@ class DuckDB(Dialect): exp.TsOrDsDiff: lambda self, e: self.func( "DATE_DIFF", f"'{e.args.get('unit') or 'DAY'}'", - exp.cast(e.expression, "TIMESTAMP"), - exp.cast(e.this, "TIMESTAMP"), + exp.cast(e.expression, exp.DataType.Type.TIMESTAMP), + exp.cast(e.this, exp.DataType.Type.TIMESTAMP), ), exp.UnixToStr: lambda self, e: self.func( "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e) @@ -593,7 +600,19 @@ class DuckDB(Dialect): return super().generateseries_sql(expression) def bracket_sql(self, expression: exp.Bracket) -> str: - if isinstance(expression.this, exp.Array): - expression.this.replace(exp.paren(expression.this)) + this = expression.this + if isinstance(this, exp.Array): + this.replace(exp.paren(this)) + + bracket = super().bracket_sql(expression) + + if not expression.args.get("returns_list_for_maps"): + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + this = annotate_types(this) + + if this.is_type(exp.DataType.Type.MAP): + bracket = f"({bracket})[1]" - return super().bracket_sql(expression) + return bracket |