From 376de8b6892deca7dc5d83035c047f1e13eb67ea Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 31 Jan 2024 06:44:41 +0100 Subject: Merging upstream version 20.11.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 34 +++++++++++++----------- sqlglot/dialects/clickhouse.py | 20 +++++++++----- sqlglot/dialects/dialect.py | 36 ++++++++++++++----------- sqlglot/dialects/duckdb.py | 20 ++++++++------ sqlglot/dialects/hive.py | 22 +++++++++------- sqlglot/dialects/oracle.py | 3 +++ sqlglot/dialects/postgres.py | 15 ++++++++--- sqlglot/dialects/presto.py | 1 + sqlglot/dialects/snowflake.py | 60 +++++++++++++++++++++++++++++++++--------- sqlglot/dialects/spark.py | 6 ++--- sqlglot/dialects/spark2.py | 27 ++++++++++++++----- sqlglot/dialects/tableau.py | 1 + sqlglot/dialects/tsql.py | 37 ++++++++++++++++++++------ 13 files changed, 195 insertions(+), 87 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 0151e6c..771ae1a 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -5,7 +5,6 @@ import re import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, @@ -30,7 +29,7 @@ from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType if t.TYPE_CHECKING: - from typing_extensions import Literal + from sqlglot._typing import E, Lit logger = logging.getLogger("sqlglot") @@ -47,9 +46,11 @@ def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Va exp.alias_(value, column_name) for value, column_name in zip( t.expressions, - alias.columns - if alias and alias.columns - else (f"_c{i}" for i in range(len(t.expressions))), + ( + alias.columns + if alias and alias.columns + else (f"_c{i}" for i in range(len(t.expressions))) + ), ) ] ) @@ -473,12 +474,10 @@ class BigQuery(Dialect): return table @t.overload - def _parse_json_object(self, agg: Literal[False]) -> exp.JSONObject: - ... + def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @t.overload - def _parse_json_object(self, agg: Literal[True]) -> exp.JSONObjectAgg: - ... + def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... 
def _parse_json_object(self, agg=False): json_object = super()._parse_json_object() @@ -546,9 +545,11 @@ class BigQuery(Dialect): exp.ArrayContains: _array_contains_sql, exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), - exp.CollateProperty: lambda self, e: f"DEFAULT COLLATE {self.sql(e, 'this')}" - if e.args.get("default") - else f"COLLATE {self.sql(e, 'this')}", + exp.CollateProperty: lambda self, e: ( + f"DEFAULT COLLATE {self.sql(e, 'this')}" + if e.args.get("default") + else f"COLLATE {self.sql(e, 'this')}" + ), exp.CountIf: rename_func("COUNTIF"), exp.Create: _create_sql, exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), @@ -560,6 +561,9 @@ class BigQuery(Dialect): exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"), exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, e.text("unit")), + exp.FromTimeZone: lambda self, e: self.func( + "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'" + ), exp.GenerateSeries: rename_func("GENERATE_ARRAY"), exp.GetPath: path_to_jsonpath(), exp.GroupConcat: rename_func("STRING_AGG"), @@ -595,9 +599,9 @@ class BigQuery(Dialect): exp.SHA2: lambda self, e: self.func( f"SHA256" if e.text("length") == "256" else "SHA512", e.this ), - exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" - if e.name == "IMMUTABLE" - else "NOT DETERMINISTIC", + exp.StabilityProperty: lambda self, e: ( + f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" + ), exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: self.func( "PARSE_TIMESTAMP", self.format_time(e), e.this, e.args.get("zone") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index f2e4fe1..1248edc 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -88,6 +88,8 @@ class ClickHouse(Dialect): "UINT8": TokenType.UTINYINT, "IPV4": TokenType.IPV4, "IPV6": TokenType.IPV6, + "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION, + "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION, } SINGLE_TOKENS = { @@ -548,6 +550,8 @@ class ClickHouse(Dialect): exp.DataType.Type.UTINYINT: "UInt8", exp.DataType.Type.IPV4: "IPv4", exp.DataType.Type.IPV6: "IPv6", + exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction", + exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction", } TRANSFORMS = { @@ -651,12 +655,16 @@ class ClickHouse(Dialect): def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: return super().after_limit_modifiers(expression) + [ - self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) - if expression.args.get("settings") - else "", - self.seg("FORMAT ") + self.sql(expression, "format") - if expression.args.get("format") - else "", + ( + self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) + if expression.args.get("settings") + else "" + ), + ( + self.seg("FORMAT ") + self.sql(expression, "format") + if expression.args.get("format") + else "" + ), ] def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 7664c40..6be991b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -5,7 +5,6 @@ from enum import Enum, auto from functools import reduce from sqlglot import exp -from sqlglot._typing 
import E from sqlglot.errors import ParseError from sqlglot.generator import Generator from sqlglot.helper import AutoName, flatten, seq_get @@ -14,11 +13,12 @@ from sqlglot.time import TIMEZONES, format_time from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -B = t.TypeVar("B", bound=exp.Binary) - DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff] DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] +if t.TYPE_CHECKING: + from sqlglot._typing import B, E + class Dialects(str, Enum): """Dialects supported by SQLGLot.""" @@ -381,9 +381,11 @@ class Dialect(metaclass=_Dialect): ): expression.set( "this", - expression.this.upper() - if self.normalization_strategy is NormalizationStrategy.UPPERCASE - else expression.this.lower(), + ( + expression.this.upper() + if self.normalization_strategy is NormalizationStrategy.UPPERCASE + else expression.this.lower() + ), ) return expression @@ -877,9 +879,11 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). """ agg_all_unquoted = agg.transform( - lambda node: exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node + lambda node: ( + exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) ) names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) @@ -999,10 +1003,8 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: """Remove table refs from columns in when statements.""" alias = expression.this.args.get("alias") - normalize = ( - lambda identifier: self.dialect.normalize_identifier(identifier).name - if identifier - else None + normalize = lambda identifier: ( + self.dialect.normalize_identifier(identifier).name if identifier else None ) targets = {normalize(expression.this.this)} @@ -1012,9 +1014,11 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: for when in expression.expressions: when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, + lambda node: ( + exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node + ), copy=False, ) diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 2343b35..f55ad70 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -148,8 +148,8 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str def _rename_unless_within_group( a: str, b: str ) -> t.Callable[[DuckDB.Generator, exp.Expression], str]: - return ( - lambda self, expression: self.func(a, *flatten(expression.args.values())) + return lambda self, expression: ( + self.func(a, *flatten(expression.args.values())) if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup) else self.func(b, *flatten(expression.args.values())) ) @@ -273,9 +273,11 @@ class DuckDB(Dialect): PLACEHOLDER_PARSERS = { **parser.Parser.PLACEHOLDER_PARSERS, - TokenType.PARAMETER: lambda self: self.expression(exp.Placeholder, this=self._prev.text) - if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) - else None, + TokenType.PARAMETER: lambda self: ( + self.expression(exp.Placeholder, this=self._prev.text) + if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) + 
else None + ), } def _parse_types( @@ -321,9 +323,11 @@ class DuckDB(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) - if e.expressions and e.expressions[0].find(exp.Select) - else inline_array_sql(self, e), + exp.Array: lambda self, e: ( + self.func("ARRAY", e.expressions[0]) + if e.expressions and e.expressions[0].find(exp.Select) + else inline_array_sql(self, e) + ), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"), exp.ArgMin: arg_max_or_min_no_count("ARG_MIN"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index dffa41e..060f9bd 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -397,9 +397,11 @@ class Hive(Dialect): if this and not schema: return this.transform( - lambda node: node.replace(exp.DataType.build("text")) - if isinstance(node, exp.DataType) and node.is_type("char", "varchar") - else node, + lambda node: ( + node.replace(exp.DataType.build("text")) + if isinstance(node, exp.DataType) and node.is_type("char", "varchar") + else node + ), copy=False, ) @@ -409,9 +411,11 @@ class Hive(Dialect): self, ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: return ( - self._parse_csv(self._parse_conjunction) - if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) - else [], + ( + self._parse_csv(self._parse_conjunction) + if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) + else [] + ), super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)), ) @@ -483,9 +487,9 @@ class Hive(Dialect): exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), - exp.NotNullColumnConstraint: lambda self, e: "" - if e.args.get("allow_null") - else "NOT NULL", + exp.NotNullColumnConstraint: lambda self, e: ( + "" if e.args.get("allow_null") else "NOT NULL" + ), exp.VarMap: var_map_sql, exp.Create: _create_sql, exp.Quantile: rename_func("PERCENTILE"), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 6ad3718..4591d59 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -166,6 +166,7 @@ class Oracle(Dialect): TABLESAMPLE_KEYWORDS = "SAMPLE" LAST_DAY_SUPPORTS_DATE_PART = False SUPPORTS_SELECT_INTO = True + TZ_TO_WITH_TIME_ZONE = True TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -179,6 +180,8 @@ class Oracle(Dialect): exp.DataType.Type.NVARCHAR: "NVARCHAR2", exp.DataType.Type.NCHAR: "NCHAR", exp.DataType.Type.TEXT: "CLOB", + exp.DataType.Type.TIMETZ: "TIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.VARBINARY: "BLOB", } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 1ca0a78..87f6b02 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -282,6 +282,12 @@ class Postgres(Dialect): VAR_SINGLE_TOKENS = {"$"} class Parser(parser.Parser): + PROPERTY_PARSERS = { + **parser.Parser.PROPERTY_PARSERS, + "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()), + } + PROPERTY_PARSERS.pop("INPUT", None) + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, @@ -385,9 +391,11 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, - exp.Array: lambda self, 
e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" - if isinstance(seq_get(e.expressions, 0), exp.Select) - else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", + exp.Array: lambda self, e: ( + f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]" + ), exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -396,6 +404,7 @@ class Postgres(Dialect): exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.CurrentUser: lambda *_: "CURRENT_USER", exp.DateAdd: _date_add_sql("+"), exp.DateDiff: _date_diff_sql, exp.DateStrToDate: datestrtodate_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 9b421e7..6cc6030 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -356,6 +356,7 @@ class Presto(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, + exp.FromTimeZone: lambda self, e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'", exp.GetPath: path_to_jsonpath(), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.GroupConcat: lambda self, e: self.func( diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index a8e4a42..281167d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,7 +3,6 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot._typing import E from sqlglot.dialects.dialect import ( Dialect, NormalizationStrategy, @@ -25,6 +24,9 @@ from sqlglot.expressions import Literal from sqlglot.helper import seq_get from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + def _check_int(s: str) -> bool: if s[0] in ("-", "+"): @@ -297,10 +299,7 @@ def _parse_colon_get_path( if not self._match(TokenType.COLON): break - if self._match_set(self.RANGE_PARSERS): - this = self.RANGE_PARSERS[self._prev.token_type](self, this) or this - - return this + return self._parse_range(this) def _parse_timestamp_from_parts(args: t.List) -> exp.Func: @@ -376,7 +375,7 @@ class Snowflake(Dialect): and isinstance(expression.parent, exp.Table) and expression.name.lower() == "dual" ): - return t.cast(E, expression) + return expression # type: ignore return super().quote_identifier(expression, identify=identify) @@ -471,6 +470,10 @@ class Snowflake(Dialect): } SHOW_PARSERS = { + "SCHEMAS": _show_parser("SCHEMAS"), + "TERSE SCHEMAS": _show_parser("SCHEMAS"), + "OBJECTS": _show_parser("OBJECTS"), + "TERSE OBJECTS": _show_parser("OBJECTS"), "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), "COLUMNS": _show_parser("COLUMNS"), @@ -580,21 +583,35 @@ class Snowflake(Dialect): scope = None scope_kind = None + # will identity SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS + # which is syntactically valid but has no effect on the output + terse = self._tokens[self._index - 2].text.upper() == "TERSE" + like = self._parse_string() if self._match(TokenType.LIKE) else None if 
self._match(TokenType.IN): if self._match_text_seq("ACCOUNT"): scope_kind = "ACCOUNT" elif self._match_set(self.DB_CREATABLES): - scope_kind = self._prev.text + scope_kind = self._prev.text.upper() if self._curr: - scope = self._parse_table() + scope = self._parse_table_parts() elif self._curr: - scope_kind = "TABLE" - scope = self._parse_table() + scope_kind = "SCHEMA" if this == "OBJECTS" else "TABLE" + scope = self._parse_table_parts() return self.expression( - exp.Show, this=this, like=like, scope=scope, scope_kind=scope_kind + exp.Show, + **{ + "terse": terse, + "this": this, + "like": like, + "scope": scope, + "scope_kind": scope_kind, + "starts_with": self._match_text_seq("STARTS", "WITH") and self._parse_string(), + "limit": self._parse_limit(), + "from": self._parse_string() if self._match(TokenType.FROM) else None, + }, ) def _parse_alter_table_swap(self) -> exp.SwapTable: @@ -690,6 +707,9 @@ class Snowflake(Dialect): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.Explode: rename_func("FLATTEN"), exp.Extract: rename_func("DATE_PART"), + exp.FromTimeZone: lambda self, e: self.func( + "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this + ), exp.GenerateSeries: lambda self, e: self.func( "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step") ), @@ -820,6 +840,7 @@ class Snowflake(Dialect): return f"{explode}{alias}" def show_sql(self, expression: exp.Show) -> str: + terse = "TERSE " if expression.args.get("terse") else "" like = self.sql(expression, "like") like = f" LIKE {like}" if like else "" @@ -830,7 +851,19 @@ class Snowflake(Dialect): if scope_kind: scope_kind = f" IN {scope_kind}" - return f"SHOW {expression.name}{like}{scope_kind}{scope}" + starts_with = self.sql(expression, "starts_with") + if starts_with: + starts_with = f" STARTS WITH {starts_with}" + + limit = self.sql(expression, "limit") + + from_ = self.sql(expression, "from") + if from_: + from_ = f" FROM {from_}" + + return ( + f"SHOW {terse}{expression.name}{like}{scope_kind}{scope}{starts_with}{limit}{from_}" + ) def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to @@ -884,3 +917,6 @@ class Snowflake(Dialect): def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, wrapped=False, prefix=self.seg(""), sep=" ") + + def cluster_sql(self, expression: exp.Cluster) -> str: + return f"CLUSTER BY ({self.expressions(expression, flat=True)})" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index ba73ac0..624f76e 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -80,9 +80,9 @@ class Spark(Spark2): exp.TimestampAdd: lambda self, e: self.func( "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this ), - exp.TryCast: lambda self, e: self.trycast_sql(e) - if e.args.get("safe") - else self.cast_sql(e), + exp.TryCast: lambda self, e: ( + self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) + ), } TRANSFORMS.pop(exp.AnyValue) TRANSFORMS.pop(exp.DateDiff) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index e27ba18..e4bb30e 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -129,10 +129,20 @@ class Spark2(Hive): "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift), "STRING": _parse_as_cast("string"), "TIMESTAMP": _parse_as_cast("timestamp"), - "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args) - if len(args) == 1 - else format_time_lambda(exp.StrToTime, 
"spark")(args), + "TO_TIMESTAMP": lambda args: ( + _parse_as_cast("timestamp")(args) + if len(args) == 1 + else format_time_lambda(exp.StrToTime, "spark")(args) + ), "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone( + this=exp.cast_unless( + seq_get(args, 0) or exp.Var(this=""), + exp.DataType.build("timestamp"), + exp.DataType.build("timestamp"), + ), + zone=seq_get(args, 1), + ), "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } @@ -188,6 +198,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), + exp.FromTimeZone: lambda self, e: f"TO_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, @@ -255,10 +266,12 @@ class Spark2(Hive): def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: return super().columndef_sql( expression, - sep=": " - if isinstance(expression.parent, exp.DataType) - and expression.parent.is_type("struct") - else sep, + sep=( + ": " + if isinstance(expression.parent, exp.DataType) + and expression.parent.is_type("struct") + else sep + ), ) class Tokenizer(Hive.Tokenizer): diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 33ec7e1..3795045 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -38,3 +38,4 @@ class Tableau(Dialect): **parser.Parser.FUNCTIONS, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } + NO_PAREN_IF_COMMANDS = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b9c347c..a5e04da 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -76,9 +76,11 @@ def _format_time_lambda( format=exp.Literal.string( format_time( args[0].name.lower(), - {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} - if full_format_mapping - else TSQL.TIME_MAPPING, + ( + {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} + if full_format_mapping + else TSQL.TIME_MAPPING + ), ) ), ) @@ -264,6 +266,15 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: ) +def _parse_len(args: t.List) -> exp.Length: + this = seq_get(args, 0) + + if this and not this.is_string: + this = exp.cast(this, exp.DataType.Type.TEXT) + + return exp.Length(this=this) + + class TSQL(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" @@ -431,7 +442,7 @@ class TSQL(Dialect): "IIF": exp.If.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, - "LEN": exp.Length.from_arg_list, + "LEN": _parse_len, "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, @@ -469,6 +480,7 @@ class TSQL(Dialect): ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False STRING_ALIASES = True + NO_PAREN_IF_COMMANDS = False def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -478,9 +490,11 @@ class TSQL(Dialect): See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax """ return [ - exp.alias_(projection.expression, projection.this.this, copy=False) - if isinstance(projection, exp.EQ) and 
isinstance(projection.this, exp.Column) - else projection + ( + exp.alias_(projection.expression, projection.this.this, copy=False) + if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column) + else projection + ) for projection in super()._parse_projections() ] @@ -702,7 +716,6 @@ class TSQL(Dialect): exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), - exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -922,3 +935,11 @@ class TSQL(Dialect): this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True, sep=" ") return f"CONSTRAINT {this} {expressions}" + + def length_sql(self, expression: exp.Length) -> str: + this = expression.this + if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT): + this_sql = self.sql(this, "this") + else: + this_sql = self.sql(this) + return self.func("LEN", this_sql) -- cgit v1.2.3
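
The user-visible behavior merged above is easiest to check through sqlglot's
public transpile API. The snippet below is a minimal sketch, assuming sqlglot
20.11.0 as merged here; the identifiers (t, col, ts) are hypothetical and the
expected outputs in the comments are illustrative rather than exact:

    import sqlglot

    # Spark's TO_UTC_TIMESTAMP now parses into exp.FromTimeZone, so it can be
    # rewritten for other dialects. The Presto output should resemble:
    #   WITH_TIMEZONE(CAST(ts AS TIMESTAMP), 'America/New_York') AT TIME ZONE 'UTC'
    print(sqlglot.transpile(
        "SELECT TO_UTC_TIMESTAMP(ts, 'America/New_York')",
        read="spark",
        write="presto",
    )[0])

    # BigQuery renders the same node as DATETIME(TIMESTAMP(..., zone), 'UTC')
    # and Snowflake as CONVERT_TIMEZONE(zone, 'UTC', ...), per the transforms
    # added in this commit.
    print(sqlglot.transpile(
        "SELECT TO_UTC_TIMESTAMP(ts, 'America/New_York')",
        read="spark",
        write="bigquery",
    )[0])

    # The extended Snowflake SHOW grammar should now round-trip TERSE, the
    # IN scope, STARTS WITH, LIMIT and FROM instead of dropping them.
    show = "SHOW TERSE SCHEMAS LIKE '%foo%' IN DATABASE db STARTS WITH 'a' LIMIT 10"
    print(sqlglot.transpile(show, read="snowflake", write="snowflake")[0])

    # T-SQL's LEN over a non-string argument is now parsed as LENGTH over a
    # text cast, so other dialects surface the cast (something like
    # LENGTH(CAST(col AS TEXT)) in DuckDB), while the new TSQL length_sql
    # strips the cast again and keeps plain LEN(col) on the round trip.
    print(sqlglot.transpile("SELECT LEN(col) FROM t", read="tsql", write="duckdb")[0])
    print(sqlglot.transpile("SELECT LEN(col) FROM t", read="tsql", write="tsql")[0])

Routing these rewrites through a single exp.FromTimeZone node is what lets the
same Spark input land on WITH_TIMEZONE, DATETIME(TIMESTAMP(...)) and
CONVERT_TIMEZONE without any per-dialect parsing logic.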