summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
commit41f1f5740d2140bfd3b2a282ca1087a4b576679a (patch)
tree0b1eb5ba5c759d08b05d56e50675784b6170f955 /sqlglot/dialects
parentReleasing debian version 23.7.0-1. (diff)
downloadsqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.tar.xz
sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.zip
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--sqlglot/dialects/bigquery.py43
-rw-r--r--sqlglot/dialects/clickhouse.py4
-rw-r--r--sqlglot/dialects/dialect.py19
-rw-r--r--sqlglot/dialects/drill.py4
-rw-r--r--sqlglot/dialects/duckdb.py49
-rw-r--r--sqlglot/dialects/mysql.py4
-rw-r--r--sqlglot/dialects/postgres.py3
-rw-r--r--sqlglot/dialects/presto.py14
-rw-r--r--sqlglot/dialects/prql.py41
-rw-r--r--sqlglot/dialects/redshift.py8
-rw-r--r--sqlglot/dialects/snowflake.py2
-rw-r--r--sqlglot/dialects/spark.py12
-rw-r--r--sqlglot/dialects/spark2.py14
-rw-r--r--sqlglot/dialects/teradata.py50
-rw-r--r--sqlglot/dialects/tsql.py10
15 files changed, 189 insertions, 88 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 2167ba2..a7b4895 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import (
build_formatted_time,
filter_array_using_unnest,
if_sql,
- inline_array_sql,
+ inline_array_unless_query,
max_or_greatest,
min_or_least,
no_ilike_sql,
@@ -80,29 +80,6 @@ def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str:
return self.create_sql(expression)
-def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
- """Remove references to unnest table aliases since bigquery doesn't allow them.
-
- These are added by the optimizer's qualify_column step.
- """
- from sqlglot.optimizer.scope import find_all_in_scope
-
- if isinstance(expression, exp.Select):
- unnest_aliases = {
- unnest.alias
- for unnest in find_all_in_scope(expression, exp.Unnest)
- if isinstance(unnest.parent, (exp.From, exp.Join))
- }
- if unnest_aliases:
- for column in expression.find_all(exp.Column):
- if column.table in unnest_aliases:
- column.set("table", None)
- elif column.db in unnest_aliases:
- column.set("db", None)
-
- return expression
-
-
# https://issuetracker.google.com/issues/162294746
# workaround for bigquery bug when grouping by an expression and then ordering
# WITH x AS (SELECT 1 y)
@@ -197,8 +174,8 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st
def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
- expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
- expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
+ expression.this.replace(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
+ expression.expression.replace(exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP))
unit = unit_to_var(expression)
return self.func("DATE_DIFF", expression.this, expression.expression, unit)
@@ -214,7 +191,9 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
if scale == exp.UnixToTime.MICROS:
return self.func("TIMESTAMP_MICROS", timestamp)
- unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
+ unix_seconds = exp.cast(
+ exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT
+ )
return self.func("TIMESTAMP_SECONDS", unix_seconds)
@@ -576,6 +555,7 @@ class BigQuery(Dialect):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
+ exp.Array: inline_array_unless_query,
exp.ArrayContains: _array_contains_sql,
exp.ArrayFilter: filter_array_using_unnest,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -629,7 +609,7 @@ class BigQuery(Dialect):
exp.Select: transforms.preprocess(
[
transforms.explode_to_unnest(),
- _unqualify_unnest,
+ transforms.unqualify_unnest,
transforms.eliminate_distinct_on,
_alias_ordered_group,
transforms.eliminate_semi_and_anti_joins,
@@ -843,13 +823,6 @@ class BigQuery(Dialect):
def trycast_sql(self, expression: exp.TryCast) -> str:
return self.cast_sql(expression, safe_prefix="SAFE_")
- def array_sql(self, expression: exp.Array) -> str:
- first_arg = seq_get(expression.expressions, 0)
- if isinstance(first_arg, exp.Query):
- return f"ARRAY{self.wrap(self.sql(first_arg))}"
-
- return inline_array_sql(self, expression)
-
def bracket_sql(self, expression: exp.Bracket) -> str:
this = expression.this
expressions = expression.expressions
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 631dc30..34ee529 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -629,7 +629,8 @@ class ClickHouse(Dialect):
exp.CountIf: rename_func("countIf"),
exp.CompressColumnConstraint: lambda self,
e: f"CODEC({self.expressions(e, key='this', flat=True)})",
- exp.ComputedColumnConstraint: lambda self, e: f"ALIAS {self.sql(e, 'this')}",
+ exp.ComputedColumnConstraint: lambda self,
+ e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}",
exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"),
exp.DateAdd: date_delta_sql("DATE_ADD"),
exp.DateDiff: date_delta_sql("DATE_DIFF"),
@@ -667,6 +668,7 @@ class ClickHouse(Dialect):
TABLE_HINTS = False
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
+ OUTER_UNION_MODIFIERS = False
# there's no list in docs, but it can be found in Clickhouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 81057c2..5a47438 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -562,7 +562,7 @@ def if_sql(
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
this = expression.this
if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
- this.replace(exp.cast(this, "json"))
+ this.replace(exp.cast(this, exp.DataType.Type.JSON))
return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
@@ -571,6 +571,13 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str:
return f"[{self.expressions(expression, flat=True)}]"
+def inline_array_unless_query(self: Generator, expression: exp.Array) -> str:
+ elem = seq_get(expression.expressions, 0)
+ if isinstance(elem, exp.Expression) and elem.find(exp.Query):
+ return self.func("ARRAY", elem)
+ return inline_array_sql(self, expression)
+
+
def no_ilike_sql(self: Generator, expression: exp.ILike) -> str:
return self.like_sql(
exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression)
@@ -765,11 +772,11 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
from sqlglot.optimizer.annotate_types import annotate_types
target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
- return self.sql(exp.cast(expression.this, to=target_type))
+ return self.sql(exp.cast(expression.this, target_type))
if expression.text("expression").lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
- this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
+ this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
zone=expression.expression,
)
)
@@ -806,11 +813,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
- return self.sql(exp.cast(expression.this, "timestamp"))
+ return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
- return self.sql(exp.cast(expression.this, "date"))
+ return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))
# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
@@ -1023,7 +1030,7 @@ def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")
- return self.sql(exp.cast(minus_one_day, "date"))
+ return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
index 0a00d92..06f49d5 100644
--- a/sqlglot/dialects/drill.py
+++ b/sqlglot/dialects/drill.py
@@ -19,7 +19,7 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
- return self.sql(exp.cast(this, "date"))
+ return self.sql(exp.cast(this, exp.DataType.Type.DATE))
return self.func("TO_DATE", this, time_format)
@@ -134,7 +134,7 @@ class Drill(Dialect):
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
- exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
+ exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 6a1d07a..6486dda 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -15,7 +15,7 @@ from sqlglot.dialects.dialect import (
datestrtodate_sql,
encode_decode_sql,
build_formatted_time,
- inline_array_sql,
+ inline_array_unless_query,
no_comment_column_constraint_sql,
no_safe_divide_sql,
no_timestamp_sql,
@@ -312,6 +312,15 @@ class DuckDB(Dialect):
),
}
+ def _parse_bracket(
+ self, this: t.Optional[exp.Expression] = None
+ ) -> t.Optional[exp.Expression]:
+ bracket = super()._parse_bracket(this)
+ if isinstance(bracket, exp.Bracket):
+ bracket.set("returns_list_for_maps", True)
+
+ return bracket
+
def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())
@@ -370,11 +379,7 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: lambda self, e: (
- self.func("ARRAY", e.expressions[0])
- if e.expressions and e.expressions[0].find(exp.Select)
- else inline_array_sql(self, e)
- ),
+ exp.Array: inline_array_unless_query,
exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
@@ -416,8 +421,8 @@ class DuckDB(Dialect):
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
- exp.cast(e.expression, "timestamp", copy=True),
- exp.cast(e.this, "timestamp", copy=True),
+ exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
+ exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
@@ -452,9 +457,11 @@ class DuckDB(Dialect):
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
- exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
+ exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")),
+ exp.TimeStrToUnix: lambda self, e: self.func(
+ "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP)
+ ),
exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self,
@@ -463,8 +470,8 @@ class DuckDB(Dialect):
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'DAY'}'",
- exp.cast(e.expression, "TIMESTAMP"),
- exp.cast(e.this, "TIMESTAMP"),
+ exp.cast(e.expression, exp.DataType.Type.TIMESTAMP),
+ exp.cast(e.this, exp.DataType.Type.TIMESTAMP),
),
exp.UnixToStr: lambda self, e: self.func(
"STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
@@ -593,7 +600,19 @@ class DuckDB(Dialect):
return super().generateseries_sql(expression)
def bracket_sql(self, expression: exp.Bracket) -> str:
- if isinstance(expression.this, exp.Array):
- expression.this.replace(exp.paren(expression.this))
+ this = expression.this
+ if isinstance(this, exp.Array):
+ this.replace(exp.paren(this))
+
+ bracket = super().bracket_sql(expression)
+
+ if not expression.args.get("returns_list_for_maps"):
+ if not this.type:
+ from sqlglot.optimizer.annotate_types import annotate_types
+
+ this = annotate_types(this)
+
+ if this.is_type(exp.DataType.Type.MAP):
+ bracket = f"({bracket})[1]"
- return super().bracket_sql(expression)
+ return bracket
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 1d53346..611a155 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -710,7 +710,9 @@ class MySQL(Dialect):
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
- exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
+ exp.TimeStrToTime: lambda self, e: self.sql(
+ exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True)
+ ),
exp.TimeToStr: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
),
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 11398ed..7cbcc23 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -510,6 +510,9 @@ class Postgres(Dialect):
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this),
+ exp.TimeToUnix: lambda self, e: self.func(
+ "DATE_PART", exp.Literal.string("epoch"), e.this
+ ),
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 25bba96..6c23bdf 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -90,8 +90,10 @@ def _str_to_time_sql(
def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
- return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE"))
- return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE"))
+ return self.sql(exp.cast(_str_to_time_sql(self, expression), exp.DataType.Type.DATE))
+ return self.sql(
+ exp.cast(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), exp.DataType.Type.DATE)
+ )
def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
@@ -101,8 +103,8 @@ def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
- this = exp.cast(expression.this, "TIMESTAMP")
- expr = exp.cast(expression.expression, "TIMESTAMP")
+ this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)
+ expr = exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP)
unit = unit_to_str(expression)
return self.func("DATE_DIFF", unit, expr, this)
@@ -222,6 +224,8 @@ class Presto(Dialect):
"IPPREFIX": TokenType.IPPREFIX,
}
+ KEYWORDS.pop("QUALIFY")
+
class Parser(parser.Parser):
VALUES_FOLLOWED_BY_PAREN = False
@@ -445,7 +449,7 @@ class Presto(Dialect):
# timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
# which seems to be using the same time mapping as Hive, as per:
# https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
- value_as_text = exp.cast(expression.this, "text")
+ value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
parse_with_tz = self.func(
"PARSE_DATETIME",
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
index 3005753..3ee91a8 100644
--- a/sqlglot/dialects/prql.py
+++ b/sqlglot/dialects/prql.py
@@ -7,7 +7,13 @@ from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
+def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
+ return exp.select("*").from_(table, copy=False) if table else None
+
+
class PRQL(Dialect):
+ DPIPE_IS_STRING_CONCAT = False
+
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
QUOTES = ["'", '"']
@@ -26,10 +32,27 @@ class PRQL(Dialect):
}
class Parser(parser.Parser):
+ CONJUNCTION = {
+ **parser.Parser.CONJUNCTION,
+ TokenType.DAMP: exp.And,
+ TokenType.DPIPE: exp.Or,
+ }
+
TRANSFORM_PARSERS = {
"DERIVE": lambda self, query: self._parse_selection(query),
"SELECT": lambda self, query: self._parse_selection(query, append=False),
"TAKE": lambda self, query: self._parse_take(query),
+ "FILTER": lambda self, query: query.where(self._parse_conjunction()),
+ "APPEND": lambda self, query: query.union(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "REMOVE": lambda self, query: query.except_(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "INTERSECT": lambda self, query: query.intersect(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "SORT": lambda self, query: self._parse_order_by(query),
}
def _parse_statement(self) -> t.Optional[exp.Expression]:
@@ -81,6 +104,24 @@ class PRQL(Dialect):
num = self._parse_number() # TODO: TAKE for ranges a..b
return query.limit(num) if num else None
+ def _parse_ordered(
+ self, parse_method: t.Optional[t.Callable] = None
+ ) -> t.Optional[exp.Ordered]:
+ asc = self._match(TokenType.PLUS)
+ desc = self._match(TokenType.DASH) or (asc and False)
+ term = term = super()._parse_ordered(parse_method=parse_method)
+ if term and desc:
+ term.set("desc", True)
+ term.set("nulls_first", False)
+ return term
+
+ def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
+ l_brace = self._match(TokenType.L_BRACE)
+ expressions = self._parse_csv(self._parse_ordered)
+ if l_brace and not self._match(TokenType.R_BRACE):
+ self.raise_error("Expecting }")
+ return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)
+
def _parse_expression(self) -> t.Optional[exp.Expression]:
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 1f0c411..7a86c61 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -167,7 +167,11 @@ class Redshift(Postgres):
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.unqualify_unnest,
+ ]
),
exp.SortKeyProperty: lambda self,
e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
@@ -203,7 +207,7 @@ class Redshift(Postgres):
return ""
arg = self.sql(seq_get(args, 0))
- alias = self.expressions(expression.args.get("alias"), key="columns")
+ alias = self.expressions(expression.args.get("alias"), key="columns", flat=True)
return f"{arg} AS {alias}" if alias else arg
def with_properties(self, properties: exp.Properties) -> str:
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 73a9166..41d5b65 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -818,7 +818,7 @@ class Snowflake(Dialect):
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
- "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
+ "TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e)
),
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToArray: rename_func("TO_ARRAY"),
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 88b5ddc..9bb9a5c 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -6,7 +6,7 @@ from sqlglot import exp
from sqlglot.dialects.dialect import rename_func, unit_to_var
from sqlglot.dialects.hive import _build_with_ignore_nulls
from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider
-from sqlglot.helper import seq_get
+from sqlglot.helper import ensure_list, seq_get
from sqlglot.transforms import (
ctas_with_tmp_tables_to_create_tmp_view,
remove_unique_constraints,
@@ -63,6 +63,9 @@ class Spark(Spark2):
**Spark2.Parser.FUNCTIONS,
"ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue),
"DATEDIFF": _build_datediff,
+ "TRY_ELEMENT_AT": lambda args: exp.Bracket(
+ this=seq_get(args, 0), expressions=ensure_list(seq_get(args, 1)), safe=True
+ ),
}
def _parse_generated_as_identity(
@@ -112,6 +115,13 @@ class Spark(Spark2):
TRANSFORMS.pop(exp.DateDiff)
TRANSFORMS.pop(exp.Group)
+ def bracket_sql(self, expression: exp.Bracket) -> str:
+ if expression.args.get("safe"):
+ key = seq_get(self.bracket_offset_expressions(expression), 0)
+ return self.func("TRY_ELEMENT_AT", expression.this, key)
+
+ return super().bracket_sql(expression)
+
def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str:
return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})"
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 069916f..5264f39 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -48,7 +48,7 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
timestamp = expression.this
if scale is None:
- return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
+ return self.sql(exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP))
if scale == exp.UnixToTime.SECONDS:
return self.func("TIMESTAMP_SECONDS", timestamp)
if scale == exp.UnixToTime.MILLIS:
@@ -129,11 +129,7 @@ class Spark2(Hive):
"DOUBLE": _build_as_cast("double"),
"FLOAT": _build_as_cast("float"),
"FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
- this=exp.cast_unless(
- seq_get(args, 0) or exp.Var(this=""),
- exp.DataType.build("timestamp"),
- exp.DataType.build("timestamp"),
- ),
+ this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
zone=seq_get(args, 1),
),
"INT": _build_as_cast("int"),
@@ -150,11 +146,7 @@ class Spark2(Hive):
),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
- this=exp.cast_unless(
- seq_get(args, 0) or exp.Var(this=""),
- exp.DataType.build("timestamp"),
- exp.DataType.build("timestamp"),
- ),
+ this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
zone=seq_get(args, 1),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index a65e10e..feb2097 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -13,6 +13,29 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
+def _date_add_sql(
+ kind: t.Literal["+", "-"],
+) -> t.Callable[[Teradata.Generator, exp.DateAdd | exp.DateSub], str]:
+ def func(self: Teradata.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+ this = self.sql(expression, "this")
+ unit = expression.args.get("unit")
+ value = self._simplify_unless_literal(expression.expression)
+
+ if not isinstance(value, exp.Literal):
+ self.unsupported("Cannot add non literal")
+
+ if value.is_negative:
+ kind_to_op = {"+": "-", "-": "+"}
+ value = exp.Literal.string(value.name[1:])
+ else:
+ kind_to_op = {"+": "+", "-": "-"}
+ value.set("is_string", True)
+
+ return f"{this} {kind_to_op[kind]} {self.sql(exp.Interval(this=value, unit=unit))}"
+
+ return func
+
+
class Teradata(Dialect):
SUPPORTS_SEMI_ANTI_JOIN = False
TYPED_DIVISION = True
@@ -189,6 +212,7 @@ class Teradata(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.GEOMETRY: "ST_GEOMETRY",
+ exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
}
PROPERTIES_LOCATION = {
@@ -214,6 +238,10 @@ class Teradata(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.DateAdd: _date_add_sql("+"),
+ exp.DateSub: _date_add_sql("-"),
+ exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)),
}
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
@@ -276,3 +304,25 @@ class Teradata(Dialect):
return f"{this_name}{this_properties}{self.sep()}{this_schema}"
return super().createable_sql(expression, locations)
+
+ def extract_sql(self, expression: exp.Extract) -> str:
+ this = self.sql(expression, "this")
+ if this.upper() != "QUARTER":
+ return super().extract_sql(expression)
+
+ to_char = exp.func("to_char", expression.expression, exp.Literal.string("Q"))
+ return self.sql(exp.cast(to_char, exp.DataType.Type.INT))
+
+ def interval_sql(self, expression: exp.Interval) -> str:
+ multiplier = 0
+ unit = expression.text("unit")
+
+ if unit.startswith("WEEK"):
+ multiplier = 7
+ elif unit.startswith("QUARTER"):
+ multiplier = 90
+
+ if multiplier:
+ return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})"
+
+ return super().interval_sql(expression)
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 8e06be6..6eed46d 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -109,7 +109,7 @@ def _build_formatted_time(
assert len(args) == 2
return exp_class(
- this=exp.cast(args[1], "datetime"),
+ this=exp.cast(args[1], exp.DataType.Type.DATETIME),
format=exp.Literal.string(
format_time(
args[0].name.lower(),
@@ -726,6 +726,7 @@ class TSQL(Dialect):
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False
+ OUTER_UNION_MODIFIERS = False
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -882,13 +883,6 @@ class TSQL(Dialect):
return rename_func("DATETIMEFROMPARTS")(self, expression)
- def set_operations(self, expression: exp.Union) -> str:
- limit = expression.args.get("limit")
- if limit:
- return self.sql(expression.limit(limit.pop(), copy=False))
-
- return super().set_operations(expression)
-
def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this
if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):