summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/duckdb.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
commit41f1f5740d2140bfd3b2a282ca1087a4b576679a (patch)
tree0b1eb5ba5c759d08b05d56e50675784b6170f955 /sqlglot/dialects/duckdb.py
parentReleasing debian version 23.7.0-1. (diff)
downloadsqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.tar.xz
sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.zip
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/duckdb.py')
-rw-r--r--sqlglot/dialects/duckdb.py49
1 files changed, 34 insertions, 15 deletions
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 6a1d07a..6486dda 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
encode_decode_sql,
build_formatted_time,
- inline_array_sql,
+ inline_array_unless_query,
no_comment_column_constraint_sql,
no_safe_divide_sql,
no_timestamp_sql,
@@ -312,6 +312,15 @@ class DuckDB(Dialect):
),
}
+ def _parse_bracket(
+ self, this: t.Optional[exp.Expression] = None
+ ) -> t.Optional[exp.Expression]:
+ bracket = super()._parse_bracket(this)
+ if isinstance(bracket, exp.Bracket):
+ bracket.set("returns_list_for_maps", True)
+
+ return bracket
+
def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())
@@ -370,11 +379,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: lambda self, e: (
- self.func("ARRAY", e.expressions[0])
- if e.expressions and e.expressions[0].find(exp.Select)
- else inline_array_sql(self, e)
- ),
+ exp.Array: inline_array_unless_query,
exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
@@ -416,8 +421,8 @@ class DuckDB(Dialect):
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
- exp.cast(e.expression, "timestamp", copy=True),
- exp.cast(e.this, "timestamp", copy=True),
+ exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
+ exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
@@ -452,9 +457,11 @@ class DuckDB(Dialect):
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
- exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
+ exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")),
+ exp.TimeStrToUnix: lambda self, e: self.func(
+ "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP)
+ ),
exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self,
@@ -463,8 +470,8 @@ class DuckDB(Dialect):
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'DAY'}'",
- exp.cast(e.expression, "TIMESTAMP"),
- exp.cast(e.this, "TIMESTAMP"),
+ exp.cast(e.expression, exp.DataType.Type.TIMESTAMP),
+ exp.cast(e.this, exp.DataType.Type.TIMESTAMP),
),
exp.UnixToStr: lambda self, e: self.func(
"STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
@@ -593,7 +600,19 @@ class DuckDB(Dialect):
return super().generateseries_sql(expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
- if isinstance(expression.this, exp.Array):
- expression.this.replace(exp.paren(expression.this))
+ this = expression.this
+ if isinstance(this, exp.Array):
+ this.replace(exp.paren(this))
+
+ bracket = super().bracket_sql(expression)
+
+ if not expression.args.get("returns_list_for_maps"):
+ if not this.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ this = annotate_types(this)
+
+ if this.is_type(exp.DataType.Type.MAP):
+ bracket = f"({bracket})[1]"
- return super().bracket_sql(expression)
+ return bracket