author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-01-31 05:44:41 +0000
commit     376de8b6892deca7dc5d83035c047f1e13eb67ea (patch)
tree       334a1753cd914294aa99128fac3fb59bf14dc10f /sqlglot/dialects
parent     Releasing debian version 20.9.0-1. (diff)
download   sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.tar.xz
           sqlglot-376de8b6892deca7dc5d83035c047f1e13eb67ea.zip

Merging upstream version 20.11.0.

Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--  sqlglot/dialects/bigquery.py    |  34
-rw-r--r--  sqlglot/dialects/clickhouse.py  |  20
-rw-r--r--  sqlglot/dialects/dialect.py     |  36
-rw-r--r--  sqlglot/dialects/duckdb.py      |  20
-rw-r--r--  sqlglot/dialects/hive.py        |  22
-rw-r--r--  sqlglot/dialects/oracle.py      |   3
-rw-r--r--  sqlglot/dialects/postgres.py    |  15
-rw-r--r--  sqlglot/dialects/presto.py      |   1
-rw-r--r--  sqlglot/dialects/snowflake.py   |  60
-rw-r--r--  sqlglot/dialects/spark.py       |   6
-rw-r--r--  sqlglot/dialects/spark2.py      |  27
-rw-r--r--  sqlglot/dialects/tableau.py     |   1
-rw-r--r--  sqlglot/dialects/tsql.py        |  37

13 files changed, 195 insertions, 87 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 0151e6c..771ae1a 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -5,7 +5,6 @@ import re
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType
if t.TYPE_CHECKING:
- from typing_extensions import Literal
+ from sqlglot._typing import E, Lit
logger = logging.getLogger("sqlglot")
@@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va
exp.alias_(value, column_name)
for value, column_name in zip(
t.expressions,
- alias.columns
- if alias and alias.columns
- else (f"_c{i}" for i in range(len(t.expressions))),
+ (
+ alias.columns
+ if alias and alias.columns
+ else (f"_c{i}" for i in range(len(t.expressions)))
+ ),
)
]
)
@@ -473,12 +474,10 @@ class BigQuery(Dialect):
return table
@t.overload
- def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject:
- ...
+ def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ...
@t.overload
- def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg:
- ...
+ def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ...
def _parse_json_object(self, agg=False):
json_object = super()._parse_json_object()
@@ -546,9 +545,11 @@ class BigQuery(Dialect):
exp.ArrayContains: _array_contains_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]),
- exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}"
- if e.args.get("default")
- else f"COLLATE {self.sql(e, 'this')}",
+ exp.CollateProperty: lambda self, e: (
+ f"DEFAULT COLLATE {self.sql(e, 'this')}"
+ if e.args.get("default")
+ else f"COLLATE {self.sql(e, 'this')}"
+ ),
exp.CountIf: rename_func("COUNTIF"),
exp.Create: _create_sql,
exp.CTE: transforms.preprocess([_pushdown_cte_column_names]),
@@ -560,6 +561,9 @@ class BigQuery(Dialect):
exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"),
exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"),
exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")),
+ exp.FromTimeZone: lambda self, e: self.func(
+ "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'"
+ ),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: rename_func("STRING_AGG"),
@@ -595,9 +599,9 @@ class BigQuery(Dialect):
exp.SHA2: lambda self, e: self.func(
f"SHA256" if e.text("length") == "256" else "SHA512", e.this
),
- exp.StabilityProperty: lambda self, e: f"DETERMINISTIC"
- if e.name == "IMMUTABLE"
- else "NOT DETERMINISTIC",
+ exp.StabilityProperty: lambda self, e: (
+ f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC"
+ ),
exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})",
exp.StrToTime: lambda self, e: self.func(
"PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone")
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f2e4fe1..1248edc 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -88,6 +88,8 @@ class ClickHouse(Dialect):
"UINT8": TokenType.UTINYINT,
"IPV4": TokenType.IPV4,
"IPV6": TokenType.IPV6,
+ "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION,
+ "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION,
}
SINGLE_TOKENS = {
@@ -548,6 +550,8 @@ class ClickHouse(Dialect):
exp.DataType.Type.UTINYINT: "UInt8",
exp.DataType.Type.IPV4: "IPv4",
exp.DataType.Type.IPV6: "IPv6",
+ exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction",
+ exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction",
}
TRANSFORMS = {
@@ -651,12 +655,16 @@ class ClickHouse(Dialect):
def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
return super().after_limit_modifiers(expression) + [
- self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
- if expression.args.get("settings")
- else "",
- self.seg("FORMAT ") + self.sql(expression, "format")
- if expression.args.get("format")
- else "",
+ (
+ self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True)
+ if expression.args.get("settings")
+ else ""
+ ),
+ (
+ self.seg("FORMAT ") + self.sql(expression, "format")
+ if expression.args.get("format")
+ else ""
+ ),
]
def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str:
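
The new AGGREGATEFUNCTION/SIMPLEAGGREGATEFUNCTION tokens and type mappings should make ClickHouse's aggregate state types round-trippable. A hedged sketch (table and column names are invented, and whether the parameterized inner signature survives a round trip is my assumption):

    from sqlglot import parse_one

    # The added token and TYPE_MAPPING entries pair up, so the type name
    # should tokenize on the way in and render back on the way out.
    ddl = "CREATE TABLE t (s SimpleAggregateFunction(sum, UInt64))"
    print(parse_one(ddl, read="clickhouse").sql(dialect="clickhouse"))
    # roughly: CREATE TABLE t (s SimpleAggregateFunction(sum, UInt64))
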
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 7664c40..6be991b 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -5,7 +5,6 @@ from enum import Enum, auto
from functools import reduce
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
@@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie
-B = t.TypeVar("B", bound=exp.Binary)
-
DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
+if t.TYPE_CHECKING:
+ from sqlglot._typing import B, E
+
class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
@@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect):
):
expression.set(
"this",
- expression.this.upper()
- if self.normalization_strategy is NormalizationStrategy.UPPERCASE
- else expression.this.lower(),
+ (
+ expression.this.upper()
+ if self.normalization_strategy is NormalizationStrategy.UPPERCASE
+ else expression.this.lower()
+ ),
)
return expression
@@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp
Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
"""
agg_all_unquoted = agg.transform(
- lambda node: exp.Identifier(this=node.name, quoted=False)
- if isinstance(node, exp.Identifier)
- else node
+ lambda node: (
+ exp.Identifier(this=node.name, quoted=False)
+ if isinstance(node, exp.Identifier)
+ else node
+ )
)
names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))
@@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
"""Remove table refs from columns in when statements."""
alias = expression.this.args.get("alias")
- normalize = (
- lambda identifier: self.dialect.normalize_identifier(identifier).name
- if identifier
- else None
+ normalize = lambda identifier: (
+ self.dialect.normalize_identifier(identifier).name if identifier else None
)
targets = {normalize(expression.this.this)}
@@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
for when in expression.expressions:
when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
+ lambda node: (
+ exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node
+ ),
copy=False,
)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 2343b35..f55ad70 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str
def _rename_unless_within_group(
a: str, b: str
) -> t.Callable[[DuckDB.Generator, exp.Expression], str]:
- return (
- lambda self, expression: self.func(a, *flatten(expression.args.values()))
+ return lambda self, expression: (
+ self.func(a, *flatten(expression.args.values()))
if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup)
else self.func(b, *flatten(expression.args.values()))
)
@@ -273,9 +273,11 @@ class DuckDB(Dialect):
PLACEHOLDER_PARSERS = {
**parser.Parser.PLACEHOLDER_PARSERS,
- TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text)
- if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
- else None,
+ TokenType.PARAMETER: lambda self: (
+ self.expression(exp.Placeholder, this=self._prev.text)
+ if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS)
+ else None
+ ),
}
def _parse_types(
@@ -321,9 +323,11 @@ class DuckDB(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
- exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0])
- if e.expressions and e.expressions[0].find(exp.Select)
- else inline_array_sql(self, e),
+ exp.Array: lambda self, e: (
+ self.func("ARRAY", e.expressions[0])
+ if e.expressions and e.expressions[0].find(exp.Select)
+ else inline_array_sql(self, e)
+ ),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"),
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index dffa41e..060f9bd 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -397,9 +397,11 @@ class Hive(Dialect):
if this and not schema:
return this.transform(
- lambda node: node.replace(exp.DataType.build("text"))
- if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
- else node,
+ lambda node: (
+ node.replace(exp.DataType.build("text"))
+ if isinstance(node, exp.DataType) and node.is_type("char", "varchar")
+ else node
+ ),
copy=False,
)
@@ -409,9 +411,11 @@ class Hive(Dialect):
self,
) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]:
return (
- self._parse_csv(self._parse_conjunction)
- if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
- else [],
+ (
+ self._parse_csv(self._parse_conjunction)
+ if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY})
+ else []
+ ),
super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)),
)
@@ -483,9 +487,9 @@ class Hive(Dialect):
exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)),
exp.Min: min_or_least,
exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression),
- exp.NotNullColumnConstraint: lambda self, e: ""
- if e.args.get("allow_null")
- else "NOT NULL",
+ exp.NotNullColumnConstraint: lambda self, e: (
+ "" if e.args.get("allow_null") else "NOT NULL"
+ ),
exp.VarMap: var_map_sql,
exp.Create: _create_sql,
exp.Quantile: rename_func("PERCENTILE"),
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 6ad3718..4591d59 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -166,6 +166,7 @@ class Oracle(Dialect):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_SELECT_INTO = True
+ TZ_TO_WITH_TIME_ZONE = True
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -179,6 +180,8 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.NCHAR: "NCHAR",
exp.DataType.Type.TEXT: "CLOB",
+ exp.DataType.Type.TIMETZ: "TIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
exp.DataType.Type.BINARY: "BLOB",
exp.DataType.Type.VARBINARY: "BLOB",
}
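
Together with the new TZ_TO_WITH_TIME_ZONE flag, the TIMETZ/TIMESTAMPTZ mappings should render tz-aware types with the ANSI WITH TIME ZONE suffix instead of a *TZ alias Oracle does not accept. A sketch; the output shown is my expectation, not a quoted test:

    import sqlglot

    # TIMESTAMPTZ maps to TIMESTAMP, and the generator flag should append
    # the WITH TIME ZONE suffix.
    print(sqlglot.transpile("CAST(x AS TIMESTAMPTZ)", write="oracle")[0])
    # roughly: CAST(x AS TIMESTAMP WITH TIME ZONE)
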
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 1ca0a78..87f6b02 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -282,6 +282,12 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
+ }
+ PROPERTY_PARSERS.pop("INPUT", None)
+
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
@@ -385,9 +391,11 @@ class Postgres(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
- exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
- if isinstance(seq_get(e.expressions, 0), exp.Select)
- else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
+ exp.Array: lambda self, e: (
+ f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ if isinstance(seq_get(e.expressions, 0), exp.Select)
+ else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]"
+ ),
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
@@ -396,6 +404,7 @@ class Postgres(Dialect):
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.CurrentUser: lambda *_: "CURRENT_USER",
exp.DateAdd: _date_add_sql("+"),
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: datestrtodate_sql,
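
Two small behavioral changes in this file: CREATE-level SET options now parse as exp.SetConfigProperty, and exp.CurrentUser renders without parentheses. A sketch of the latter:

    import sqlglot

    # CURRENT_USER is a no-paren function in Postgres, so the new transform
    # should drop the parentheses other dialects would emit.
    print(sqlglot.transpile("SELECT CURRENT_USER", write="postgres")[0])
    # expected: SELECT CURRENT_USER
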
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 9b421e7..6cc6030 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -356,6 +356,7 @@ class Presto(Dialect):
exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"),
exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'",
exp.First: _first_last_sql,
+ exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'",
exp.GetPath: path_to_jsonpath(),
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.GroupConcat: lambda self, e: self.func(
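
For Presto the same exp.FromTimeZone node targets WITH_TIMEZONE plus AT TIME ZONE. Sketch (the column name and exact casts in the output are assumed):

    import sqlglot

    # FromTimeZone renders as WITH_TIMEZONE(x, zone) AT TIME ZONE 'UTC'.
    sql = "SELECT TO_UTC_TIMESTAMP(ts, 'Asia/Tokyo')"
    print(sqlglot.transpile(sql, read="spark", write="presto")[0])
    # roughly: SELECT WITH_TIMEZONE(CAST(ts AS TIMESTAMP), 'Asia/Tokyo') AT TIME ZONE 'UTC'
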
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index a8e4a42..281167d 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -3,7 +3,6 @@ from __future__ import annotations
import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot._typing import E
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
@@ -25,6 +24,9 @@ from sqlglot.expressions import Literal
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
def _check_int(s: str) -> bool:
if s[0] in ("-", "+"):
@@ -297,10 +299,7 @@ def _parse_colon_get_path(
if not self._match(TokenType.COLON):
break
- if self._match_set(self.RANGE_PARSERS):
- this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this
-
- return this
+ return self._parse_range(this)
def _parse_timestamp_from_parts(args: t.List) -> exp.Func:
@@ -376,7 +375,7 @@ class Snowflake(Dialect):
and isinstance(expression.parent, exp.Table)
and expression.name.lower() == "dual"
):
- return t.cast(E, expression)
+ return expression # type: ignore
return super().quote_identifier(expression, identify=identify)
@@ -471,6 +470,10 @@ class Snowflake(Dialect):
}
SHOW_PARSERS = {
+ "SCHEMAS": _show_parser("SCHEMAS"),
+ "TERSE SCHEMAS": _show_parser("SCHEMAS"),
+ "OBJECTS": _show_parser("OBJECTS"),
+ "TERSE OBJECTS": _show_parser("OBJECTS"),
"PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"),
"COLUMNS": _show_parser("COLUMNS"),
@@ -580,21 +583,35 @@ class Snowflake(Dialect):
scope = None
scope_kind = None
+ # will identify SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS
+ # which is syntactically valid but has no effect on the output
+ terse = self._tokens[self._index - 2].text.upper() == "TERSE"
+
like = self._parse_string() if self._match(TokenType.LIKE) else None
if self._match(TokenType.IN):
if self._match_text_seq("ACCOUNT"):
scope_kind = "ACCOUNT"
elif self._match_set(self.DB_CREATABLES):
- scope_kind = self._prev.text
+ scope_kind = self._prev.text.upper()
if self._curr:
- scope = self._parse_table()
+ scope = self._parse_table_parts()
elif self._curr:
- scope_kind = "TABLE"
- scope = self._parse_table()
+ scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE"
+ scope = self._parse_table_parts()
return self.expression(
- exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind
+ exp.Show,
+ **{
+ "terse": terse,
+ "this": this,
+ "like": like,
+ "scope": scope,
+ "scope_kind": scope_kind,
+ "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(),
+ "limit": self._parse_limit(),
+ "from": self._parse_string() if self._match(TokenType.FROM) else None,
+ },
)
def _parse_alter_table_swap(self) -> exp.SwapTable:
@@ -690,6 +707,9 @@ class Snowflake(Dialect):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.Explode: rename_func("FLATTEN"),
exp.Extract: rename_func("DATE_PART"),
+ exp.FromTimeZone: lambda self, e: self.func(
+ "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this
+ ),
exp.GenerateSeries: lambda self, e: self.func(
"ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step")
),
@@ -820,6 +840,7 @@ class Snowflake(Dialect):
return f"{explode}{alias}"
def show_sql(self, expression: exp.Show) -> str:
+ terse = "TERSE " if expression.args.get("terse") else ""
like = self.sql(expression, "like")
like = f" LIKE {like}" if like else ""
@@ -830,7 +851,19 @@ class Snowflake(Dialect):
if scope_kind:
scope_kind = f" IN {scope_kind}"
- return f"SHOW {expression.name}{like}{scope_kind}{scope}"
+ starts_with = self.sql(expression, "starts_with")
+ if starts_with:
+ starts_with = f" STARTS WITH {starts_with}"
+
+ limit = self.sql(expression, "limit")
+
+ from_ = self.sql(expression, "from")
+ if from_:
+ from_ = f" FROM {from_}"
+
+ return (
+ f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}"
+ )
def regexpextract_sql(self, expression: exp.RegexpExtract) -> str:
# Other dialects don't support all of the following parameters, so we need to
@@ -884,3 +917,6 @@ class Snowflake(Dialect):
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ")
+
+ def cluster_sql(self, expression: exp.Cluster) -> str:
+ return f"CLUSTER BY ({self.expressions(expression, flat=True)})"
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index ba73ac0..624f76e 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -80,9 +80,9 @@ class Spark(Spark2):
exp.TimestampAdd: lambda self, e: self.func(
"DATEADD", e.args.get("unit") or "DAY", e.expression, e.this
),
- exp.TryCast: lambda self, e: self.trycast_sql(e)
- if e.args.get("safe")
- else self.cast_sql(e),
+ exp.TryCast: lambda self, e: (
+ self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e)
+ ),
}
TRANSFORMS.pop(exp.AnyValue)
TRANSFORMS.pop(exp.DateDiff)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index e27ba18..e4bb30e 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -129,10 +129,20 @@ class Spark2(Hive):
"SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
"STRING": _parse_as_cast("string"),
"TIMESTAMP": _parse_as_cast("timestamp"),
- "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
- if len(args) == 1
- else format_time_lambda(exp.StrToTime, "spark")(args),
+ "TO_TIMESTAMP": lambda args: (
+ _parse_as_cast("timestamp")(args)
+ if len(args) == 1
+ else format_time_lambda(exp.StrToTime, "spark")(args)
+ ),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
+ this=exp.cast_unless(
+ seq_get(args, 0) or exp.Var(this=""),
+ exp.DataType.build("timestamp"),
+ exp.DataType.build("timestamp"),
+ ),
+ zone=seq_get(args, 1),
+ ),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
"WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
@@ -188,6 +198,7 @@ class Spark2(Hive):
exp.DayOfYear: rename_func("DAYOFYEAR"),
exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}",
exp.From: transforms.preprocess([_unalias_pivot]),
+ exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})",
exp.LogicalAnd: rename_func("BOOL_AND"),
exp.LogicalOr: rename_func("BOOL_OR"),
exp.Map: _map_sql,
@@ -255,10 +266,12 @@ class Spark2(Hive):
def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
return super().columndef_sql(
expression,
- sep=": "
- if isinstance(expression.parent, exp.DataType)
- and expression.parent.is_type("struct")
- else sep,
+ sep=(
+ ": "
+ if isinstance(expression.parent, exp.DataType)
+ and expression.parent.is_type("struct")
+ else sep
+ ),
)
class Tokenizer(Hive.Tokenizer):
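
TO_UTC_TIMESTAMP is the source of the exp.FromTimeZone nodes used throughout this commit; Spark both parses it (with a defensive timestamp cast via cast_unless) and re-emits it. Sketch:

    import sqlglot

    # Parse into exp.FromTimeZone, then generate TO_UTC_TIMESTAMP again.
    sql = "SELECT TO_UTC_TIMESTAMP(CAST(ts AS TIMESTAMP), 'UTC')"
    print(sqlglot.transpile(sql, read="spark", write="spark")[0])
    # expected: the same expression back
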
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 33ec7e1..3795045 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -38,3 +38,4 @@ class Tableau(Dialect):
**parser.Parser.FUNCTIONS,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
+ NO_PAREN_IF_COMMANDS = False
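
NO_PAREN_IF_COMMANDS = False should mean a statement-leading IF(...) is parsed as Tableau's IF function rather than falling back to an opaque command. A sketch; that parse_one returns an exp.If node here is my expectation:

    import sqlglot

    # With the flag disabled, IF(...) at statement start parses as a function.
    node = sqlglot.parse_one("IF(1 = 1, 'a', 'b')", read="tableau")
    print(type(node).__name__)  # expected: If, not Command
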
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b9c347c..a5e04da 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -76,9 +76,11 @@ def _format_time_lambda(
format=exp.Literal.string(
format_time(
args[0].name.lower(),
- {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
- if full_format_mapping
- else TSQL.TIME_MAPPING,
+ (
+ {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING}
+ if full_format_mapping
+ else TSQL.TIME_MAPPING
+ ),
)
),
)
@@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
+def _parse_len(args: t.List) -> exp.Length:
+ this = seq_get(args, 0)
+
+ if this and not this.is_string:
+ this = exp.cast(this, exp.DataType.Type.TEXT)
+
+ return exp.Length(this=this)
+
+
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@@ -431,7 +442,7 @@ class TSQL(Dialect):
"IIF": exp.If.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
- "LEN": exp.Length.from_arg_list,
+ "LEN": _parse_len,
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
@@ -469,6 +480,7 @@ class TSQL(Dialect):
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
STRING_ALIASES = True
+ NO_PAREN_IF_COMMANDS = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -478,9 +490,11 @@ class TSQL(Dialect):
See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax
"""
return [
- exp.alias_(projection.expression, projection.this.this, copy=False)
- if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
- else projection
+ (
+ exp.alias_(projection.expression, projection.this.this, copy=False)
+ if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column)
+ else projection
+ )
for projection in super()._parse_projections()
]
@@ -702,7 +716,6 @@ class TSQL(Dialect):
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
- exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@@ -922,3 +935,11 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True, sep=" ")
return f"CONSTRAINT {this} {expressions}"
+
+ def length_sql(self, expression: exp.Length) -> str:
+ this = expression.this
+ if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT):
+ this_sql = self.sql(this, "this")
+ else:
+ this_sql = self.sql(this)
+ return self.func("LEN", this_sql)