From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sat, 3 Jun 2023 01:59:40 +0200
Subject: Merging upstream version 15.0.0.

Signed-off-by: Daniel Baumann
---
 sqlglot/dialects/bigquery.py   |  97 ++++++------
 sqlglot/dialects/clickhouse.py | 279 +++++++++++++++++++++++++++++------------
 sqlglot/dialects/databricks.py |   2 +-
 sqlglot/dialects/dialect.py    |  99 +++++++++++----
 sqlglot/dialects/drill.py      |   8 +-
 sqlglot/dialects/duckdb.py     |  33 +++--
 sqlglot/dialects/hive.py       |  38 +++---
 sqlglot/dialects/mysql.py      |  64 +++++-----
 sqlglot/dialects/oracle.py     |  10 +-
 sqlglot/dialects/postgres.py   |  45 ++++---
 sqlglot/dialects/presto.py     |  94 ++++++++------
 sqlglot/dialects/redshift.py   |  58 ++++-----
 sqlglot/dialects/snowflake.py  | 111 ++++++++--------
 sqlglot/dialects/spark.py      |   6 +-
 sqlglot/dialects/spark2.py     |  96 +++++++-----
 sqlglot/dialects/sqlite.py     |  19 +--
 sqlglot/dialects/starrocks.py  |  10 +-
 sqlglot/dialects/tableau.py    |  39 +++---
 sqlglot/dialects/teradata.py   |  14 ++-
 sqlglot/dialects/trino.py      |   2 +-
 sqlglot/dialects/tsql.py       |  90 +++++------
 21 files changed, 744 insertions(+), 470 deletions(-)

diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 9705b35..1a58337 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,5 +1,3 @@
-"""Supports BigQuery Standard SQL."""
-
 from __future__ import annotations
 
 import re
@@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import (
     timestrtotime_sql,
     ts_or_ds_to_date_sql,
 )
-from sqlglot.helper import seq_get
+from sqlglot.helper import seq_get, split_num_words
 from sqlglot.tokens import TokenType
 
-E = t.TypeVar("E", bound=exp.Expression)
-
 
 def _date_add_sql(
     data_type: str, kind: str
@@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression:
     These are added by the optimizer's qualify_column step.
""" if isinstance(expression, exp.Select): - unnests = { - unnest.alias - for unnest in expression.args.get("from", exp.From(expressions=[])).expressions - if isinstance(unnest, exp.Unnest) and unnest.alias - } - - if unnests: - expression = expression.copy() - - for select in expression.expressions: - for column in select.find_all(exp.Column): - if column.table in unnests: - column.set("table", None) + for unnest in expression.find_all(exp.Unnest): + if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: + for select in expression.selects: + for column in select.find_all(exp.Column): + if column.table == unnest.alias: + column.set("table", None) return expression @@ -127,16 +116,20 @@ class BigQuery(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = [ - (prefix + quote, quote) if prefix else quote - for quote in ["'", '"', '"""', "'''"] - for prefix in ["", "r", "R"] - ] + QUOTES = ["'", '"', '"""', "'''"] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] + HEX_STRINGS = [("0x", ""), ("0X", "")] - BYTE_STRINGS = [("b'", "'"), ("B'", "'")] + + BYTE_STRINGS = [ + (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B") + ] + + RAW_STRINGS = [ + (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R") + ] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -144,11 +137,11 @@ class BigQuery(Dialect): "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, + "BYTES": TokenType.BINARY, "DECLARE": TokenType.COMMAND, - "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, - "BYTES": TokenType.BINARY, + "RECORD": TokenType.STRUCT, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, } @@ -161,7 +154,7 @@ class BigQuery(Dialect): LOG_DEFAULTS_TO_LN = True FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(str(seq_get(args, 1))), this=seq_get(args, 0), @@ -191,28 +184,28 @@ class BigQuery(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore + **parser.Parser.NO_PAREN_FUNCTIONS, TokenType.CURRENT_DATETIME: exp.CurrentDatetime, } NESTED_TYPE_TOKENS = { - *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore + *parser.Parser.NESTED_TYPE_TOKENS, TokenType.TABLE, } ID_VAR_TOKENS = { - *parser.Parser.ID_VAR_TOKENS, # type: ignore + *parser.Parser.ID_VAR_TOKENS, TokenType.VALUES, } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore + **parser.Parser.PROPERTY_PARSERS, "NOT DETERMINISTIC": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("VOLATILE") ), @@ -220,19 +213,50 @@ class BigQuery(Dialect): } CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, # type: ignore + **parser.Parser.CONSTRAINT_PARSERS, "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()), } + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: + this = super()._parse_table_part(schema=schema) + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names + if isinstance(this, exp.Identifier): + table_name = this.name + while self._match(TokenType.DASH, 
advance=False) and self._next: + self._advance(2) + table_name += f"-{self._prev.text}" + + this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) + + return this + + def _parse_table_parts(self, schema: bool = False) -> exp.Table: + table = super()._parse_table_parts(schema=schema) + if isinstance(table.this, exp.Identifier) and "." in table.name: + catalog, db, this, *rest = ( + t.cast(t.Optional[exp.Expression], exp.to_identifier(x)) + for x in split_num_words(table.name, ".", 3) + ) + + if rest and this: + this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest])) + + table = exp.Table(this=this, db=db, catalog=catalog) + + return table + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False TABLE_HINTS = False LIMIT_FETCH = "LIMIT" + RENAME_TABLE_WITH_DB = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.AtTimeZone: lambda self, e: self.func( "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) @@ -259,6 +283,7 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -274,7 +299,7 @@ class BigQuery(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", exp.DataType.Type.BIGINT: "INT64", exp.DataType.Type.BINARY: "BYTES", @@ -297,7 +322,7 @@ class BigQuery(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 2a49066..c8a9525 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -3,11 +3,16 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql +from sqlglot.dialects.dialect import ( + Dialect, + inline_array_sql, + no_pivot_sql, + rename_func, + var_map_sql, +) from sqlglot.errors import ParseError -from sqlglot.helper import ensure_list, seq_get from sqlglot.parser import parse_var_map -from sqlglot.tokens import TokenType +from sqlglot.tokens import Token, TokenType def _lower_func(sql: str) -> str: @@ -28,65 +33,122 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ASOF": TokenType.ASOF, - "GLOBAL": TokenType.GLOBAL, - "DATETIME64": TokenType.DATETIME, + "ATTACH": TokenType.COMMAND, + "DATETIME64": TokenType.DATETIME64, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, - "INT8": TokenType.TINYINT, - "UINT8": TokenType.UTINYINT, + "GLOBAL": TokenType.GLOBAL, + "INT128": TokenType.INT128, "INT16": TokenType.SMALLINT, - "UINT16": TokenType.USMALLINT, + "INT256": TokenType.INT256, "INT32": TokenType.INT, - "UINT32": TokenType.UINT, 
"INT64": TokenType.BIGINT, - "UINT64": TokenType.UBIGINT, - "INT128": TokenType.INT128, + "INT8": TokenType.TINYINT, + "MAP": TokenType.MAP, + "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, - "INT256": TokenType.INT256, + "UINT16": TokenType.USMALLINT, "UINT256": TokenType.UINT256, - "TUPLE": TokenType.STRUCT, + "UINT32": TokenType.UINT, + "UINT64": TokenType.UBIGINT, + "UINT8": TokenType.UTINYINT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg( - this=seq_get(args, 0), - time=seq_get(args, 1), - decay=seq_get(params, 0), - ), - "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( - this=seq_get(args, 0), size=seq_get(params, 0) - ), - "HISTOGRAM": lambda params, args: exp.Histogram( - this=seq_get(args, 0), bins=seq_get(params, 0) - ), + **parser.Parser.FUNCTIONS, + "ANY": exp.AnyValue.from_arg_list, "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, - "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), - "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), - "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), + "UNIQ": exp.ApproxDistinct.from_arg_list, + } + + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "QUANTILE": lambda self: self._parse_quantile(), } - FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("MATCH") + NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy() + NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY) + RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) and self._parse_in(this, is_global=True), } - JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore + # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to + # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler. 
+ COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy() + COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER) + + JOIN_KINDS = { + *parser.Parser.JOIN_KINDS, + TokenType.ANY, + TokenType.ASOF, + TokenType.ANTI, + TokenType.SEMI, + } - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { + TokenType.ANY, + TokenType.ASOF, + TokenType.SEMI, + TokenType.ANTI, + TokenType.SETTINGS, + TokenType.FORMAT, + } LOG_DEFAULTS_TO_LN = True - def _parse_in( - self, this: t.Optional[exp.Expression], is_global: bool = False - ) -> exp.Expression: + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + "settings": lambda self: self._parse_csv(self._parse_conjunction) + if self._match(TokenType.SETTINGS) + else None, + "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None, + } + + def _parse_conjunction(self) -> t.Optional[exp.Expression]: + this = super()._parse_conjunction() + + if self._match(TokenType.PLACEHOLDER): + return self.expression( + exp.If, + this=this, + true=self._parse_conjunction(), + false=self._match(TokenType.COLON) and self._parse_conjunction(), + ) + + return this + + def _parse_placeholder(self) -> t.Optional[exp.Expression]: + """ + Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier} + https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters + """ + if not self._match(TokenType.L_BRACE): + return None + + this = self._parse_id_var() + self._match(TokenType.COLON) + kind = self._parse_types(check_func=False) or ( + self._match_text_seq("IDENTIFIER") and "Identifier" + ) + + if not kind: + self.raise_error("Expecting a placeholder type or 'Identifier' for tables") + elif not self._match(TokenType.R_BRACE): + self.raise_error("Expecting }") + + return self.expression(exp.Placeholder, this=this, kind=kind) + + def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In: this = super()._parse_in(this) this.set("is_global", is_global) return this @@ -120,81 +182,142 @@ class ClickHouse(Dialect): return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: + is_global = self._match(TokenType.GLOBAL) and self._prev + kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev + if kind_pre: + kind = self._match_set(self.JOIN_KINDS) and self._prev + side = self._match_set(self.JOIN_SIDES) and self._prev + return is_global, side, kind + return ( + is_global, + self._match_set(self.JOIN_SIDES) and self._prev, + self._match_set(self.JOIN_KINDS) and self._prev, + ) + + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + join = super()._parse_join(skip_join_token) + + if join: + join.set("global", join.args.pop("natural", None)) + return join + + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + ) -> t.Optional[exp.Expression]: + func = super()._parse_function(functions, anonymous) + + if isinstance(func, exp.Anonymous): + params = self._parse_func_params(func) + + if params: + return self.expression( + exp.ParameterizedAgg, + this=func.this, + expressions=func.expressions, + params=params, + ) + + return func + + def _parse_func_params( + self, this: t.Optional[exp.Func] = None + ) -> 
t.Optional[t.List[t.Optional[exp.Expression]]]: + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + return self._parse_csv(self._parse_lambda) + if self._match(TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + self._match_r_paren(this) + return params + return None + + def _parse_quantile(self) -> exp.Quantile: + this = self._parse_lambda() + params = self._parse_func_params() + if params: + return self.expression(exp.Quantile, this=params[0], quantile=this) + return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) + + def _parse_wrapped_id_vars( + self, optional: bool = False + ) -> t.List[t.Optional[exp.Expression]]: + return super()._parse_wrapped_id_vars(optional=True) + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.NULLABLE: "Nullable", - exp.DataType.Type.DATETIME: "DateTime64", - exp.DataType.Type.MAP: "Map", + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", + exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.DATETIME64: "DateTime64", + exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.FLOAT: "Float32", + exp.DataType.Type.INT: "Int32", + exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.INT256: "Int256", + exp.DataType.Type.MAP: "Map", + exp.DataType.Type.NULLABLE: "Nullable", + exp.DataType.Type.SMALLINT: "Int16", exp.DataType.Type.STRUCT: "Tuple", exp.DataType.Type.TINYINT: "Int8", - exp.DataType.Type.UTINYINT: "UInt8", - exp.DataType.Type.SMALLINT: "Int16", - exp.DataType.Type.USMALLINT: "UInt16", - exp.DataType.Type.INT: "Int32", - exp.DataType.Type.UINT: "UInt32", - exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.UBIGINT: "UInt64", - exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.UINT: "UInt32", exp.DataType.Type.UINT128: "UInt128", - exp.DataType.Type.INT256: "Int256", exp.DataType.Type.UINT256: "UInt256", - exp.DataType.Type.FLOAT: "Float32", - exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.USMALLINT: "UInt16", + exp.DataType.Type.UTINYINT: "UInt8", } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("any"), + exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, - exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}", + exp.CastToStrType: rename_func("CAST"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", - exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}", - exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), - exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", - exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", - exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Pivot: no_pivot_sql, + exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile")) + + f"({self.sql(e, 'this')})", exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: 
lambda self, e: _lower_func(var_map_sql(self, e)), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, } JOIN_HINTS = False TABLE_HINTS = False EXPLICIT_UNION = True - - def _param_args_sql( - self, - expression: exp.Expression, - param_names: str | t.List[str], - arg_names: str | t.List[str], - ) -> str: - params = self.format_args( - *( - arg - for name in ensure_list(param_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - args = self.format_args( - *( - arg - for name in ensure_list(arg_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - return f"({params})({args})" + GROUPINGS_SEP = "" def cte_sql(self, expression: exp.CTE) -> str: if isinstance(expression.this, exp.Alias): return self.sql(expression, "this") return super().cte_sql(expression) + + def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: + return super().after_limit_modifiers(expression) + [ + self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) + if expression.args.get("settings") + else "", + self.seg("FORMAT ") + self.sql(expression, "format") + if expression.args.get("format") + else "", + ] + + def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: + params = self.expressions(expression, "params", flat=True) + return self.func(expression.name, *expression.expressions) + f"({params})" + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 51112a0..2149aca 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -25,7 +25,7 @@ class Databricks(Spark): class Generator(Spark.Generator): TRANSFORMS = { - **Spark.Generator.TRANSFORMS, # type: ignore + **Spark.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 71269f2..890a3c3 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,10 +8,16 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Token, Tokenizer +from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -E = t.TypeVar("E", bound=exp.Expression) +if t.TYPE_CHECKING: + from sqlglot._typing import E + + +# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. 
+# https://docs.snowflake.com/en/sql-reference/identifiers-syntax +RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} class Dialects(str, Enum): @@ -42,6 +48,19 @@ class Dialects(str, Enum): class _Dialect(type): classes: t.Dict[str, t.Type[Dialect]] = {} + def __eq__(cls, other: t.Any) -> bool: + if cls is other: + return True + if isinstance(other, str): + return cls is cls.get(other) + if isinstance(other, Dialect): + return cls is type(other) + + return False + + def __hash__(cls) -> int: + return hash(cls.__name__.lower()) + @classmethod def __getitem__(cls, key: str) -> t.Type[Dialect]: return cls.classes[key] @@ -70,17 +89,20 @@ class _Dialect(type): klass.tokenizer_class._IDENTIFIERS.items() )[0] - klass.bit_start, klass.bit_end = seq_get( - list(klass.tokenizer_class._BIT_STRINGS.items()), 0 - ) or (None, None) - - klass.hex_start, klass.hex_end = seq_get( - list(klass.tokenizer_class._HEX_STRINGS.items()), 0 - ) or (None, None) + def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: + return next( + ( + (s, e) + for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() + if t == token_type + ), + (None, None), + ) - klass.byte_start, klass.byte_end = seq_get( - list(klass.tokenizer_class._BYTE_STRINGS.items()), 0 - ) or (None, None) + klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING) + klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) + klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) + klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) return klass @@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect): parser_class = None generator_class = None + def __eq__(self, other: t.Any) -> bool: + return type(self) == other + + def __hash__(self) -> int: + return hash(type(self)) + @classmethod def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: if not dialect: @@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect): "hex_end": self.hex_end, "byte_start": self.byte_start, "byte_end": self.byte_end, + "raw_start": self.raw_start, + "raw_end": self.raw_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, "string_escape": self.tokenizer_class.STRING_ESCAPES[0], @@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: self.unsupported("PIVOT unsupported") - return self.sql(expression) + return "" def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: @@ -328,7 +358,7 @@ def var_map_sql( def format_time_lambda( exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None -) -> t.Callable[[t.Sequence], E]: +) -> t.Callable[[t.List], E]: """Helper used for time expressions. Args: @@ -340,7 +370,7 @@ def format_time_lambda( A callable that can be used to return the appropriately formatted time expression. 
""" - def _format_time(args: t.Sequence): + def _format_time(args: t.List): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: def parse_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None -) -> t.Callable[[t.Sequence], E]: - def inner_func(args: t.Sequence) -> E: +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: unit_based = len(args) == 3 this = args[2] if unit_based else seq_get(args, 0) unit = args[0] if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit return exp_class(this=this, expression=seq_get(args, 1), unit=unit) return inner_func @@ -390,8 +420,8 @@ def parse_date_delta( def parse_date_delta_with_interval( expression_class: t.Type[E], -) -> t.Callable[[t.Sequence], t.Optional[E]]: - def func(args: t.Sequence) -> t.Optional[E]: +) -> t.Callable[[t.List], t.Optional[E]]: + def func(args: t.List) -> t.Optional[E]: if len(args) < 2: return None @@ -409,7 +439,7 @@ def parse_date_delta_with_interval( return func -def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: +def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) @@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: ) -def locate_to_strposition(args: t.Sequence) -> exp.Expression: +def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" -def str_to_time_sql(self, expression: exp.Expression) -> str: +def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) @@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: return f"CAST({self.sql(expression, 'this')} AS DATE)" return _ts_or_ds_to_date_sql + + +# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: + names = [] + for agg in aggregations: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
+ """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) + + return names diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 7ad555e..924b979 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -95,7 +95,7 @@ class Drill(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"), "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), @@ -108,7 +108,7 @@ class Drill(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", @@ -121,13 +121,13 @@ class Drill(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArraySize: rename_func("REPEATED_COUNT"), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index bce956e..662882d 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, no_comment_column_constraint_sql, - no_pivot_sql, no_properties_sql, no_safe_divide_sql, + pivot_column_names, rename_func, str_position_sql, str_to_time_sql, @@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" -def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" + op = "+" if isinstance(expression, exp.DateAdd) else "-" + return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: @@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str return f"ARRAY_SORT({this})" -def _sort_array_reverse(args: t.Sequence) -> exp.Expression: +def _sort_array_reverse(args: t.List) -> exp.Expression: return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) -def _parse_date_diff(args: t.Sequence) -> exp.Expression: +def _parse_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract class DuckDB(Dialect): + null_ordering = "nulls_are_last" + class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "~": TokenType.RLIKE, ":=": TokenType.EQ, + "//": 
TokenType.DIV, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, "BPCHAR": TokenType.TEXT, @@ -104,6 +108,7 @@ class DuckDB(Dialect): "INT1": TokenType.TINYINT, "LOGICAL": TokenType.BOOLEAN, "NUMERIC": TokenType.DOUBLE, + "PIVOT_WIDER": TokenType.PIVOT, "SIGNED": TokenType.INT, "STRING": TokenType.VARCHAR, "UBIGINT": TokenType.UBIGINT, @@ -114,8 +119,7 @@ class DuckDB(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + **parser.Parser.FUNCTIONS, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_REVERSE_SORT": _sort_array_reverse, @@ -152,11 +156,17 @@ class DuckDB(Dialect): TokenType.UTINYINT, } + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: + if len(aggregations) == 1: + return super()._pivot_column_names(aggregations) + return pivot_column_names(aggregations, dialect="duckdb") + class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") + RENAME_TABLE_WITH_DB = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -175,7 +185,8 @@ class DuckDB(Dialect): exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, - exp.DateAdd: _date_add_sql, + exp.DateAdd: _date_delta_sql, + exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this ), @@ -183,13 +194,13 @@ class DuckDB(Dialect): exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", exp.Explode: rename_func("UNNEST"), + exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), - exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpExtract: _regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -232,11 +243,11 @@ class DuckDB(Dialect): STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"} PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS " ) -> str: return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 871a180..fbd626a 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str return f"TO_DATE({this})" -def _index_sql(self: generator.Generator, expression: exp.Index) -> str: - this = self.sql(expression, "this") - table = self.sql(expression, "table") - columns = self.sql(expression, "columns") - return f"{this} ON TABLE {table} {columns}" - - class Hive(Dialect): alias_post_tablesample = True @@ -225,8 +218,7 @@ class Hive(Dialect): STRICT_CAST = 
False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( @@ -271,21 +263,29 @@ class Hive(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore + **parser.Parser.PROPERTY_PARSERS, "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties( expressions=self._parse_wrapped_csv(self._parse_property) ), } + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"), + "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"), + "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"), + } + class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False TABLESAMPLE_SIZE_IS_PERCENT = True JOIN_HINTS = False TABLE_HINTS = False + INDEX_ON = "ON TABLE" TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", @@ -294,7 +294,7 @@ class Hive(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Select: transforms.preprocess( [ @@ -319,7 +319,6 @@ class Hive(Dialect): exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, - exp.Index: _index_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), @@ -342,7 +341,6 @@ class Hive(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: _str_to_unix_sql, exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -363,14 +361,13 @@ class Hive(Dialect): exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.LastDateOfMonth: rename_func("LAST_DAY"), - exp.National: lambda self, e: self.sql(e, "this"), + exp.National: lambda self, e: self.national_sql(e, prefix=""), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -396,3 +393,10 @@ class Hive(Dialect): expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) + + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: + return super().after_having_modifiers(expression) + [ + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + ] diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5342624..2b41860 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,5 +1,7 @@ from __future__ import 
annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import ( min_or_least, no_ilike_sql, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, parse_date_delta_with_interval, @@ -21,14 +24,14 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _show_parser(*args, **kwargs): - def _parse(self): +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]: + def _parse(self: MySQL.Parser) -> exp.Show: return self._parse_show_mysql(*args, **kwargs) return _parse -def _date_trunc_sql(self, expression): +def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit") @@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression): return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args): +def _str_to_date(args: t.List) -> exp.StrToDate: date_format = MySQL.format_time(seq_get(args, 1)) return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self, expression): +def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" -def _trim_sql(self, expression): +def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") @@ -79,8 +82,8 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return ( @@ -175,10 +178,10 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -191,7 +194,7 @@ class MySQL(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -199,13 +202,8 @@ class MySQL(Dialect): ), } - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore - "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), - } - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.SHOW: lambda self: self._parse_show(), } @@ -286,7 +284,13 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True - def _parse_show_mysql(self, this, target=False, full=None, global_=None): + def _parse_show_mysql( + self, + this: str, + target: bool | str = False, + full: t.Optional[bool] = None, + global_: 
t.Optional[bool] = None, + ) -> exp.Show: if target: if isinstance(target, str): self._match_text_seq(target) @@ -342,10 +346,12 @@ class MySQL(Dialect): offset=offset, limit=limit, mutex=mutex, - **{"global": global_}, + **{"global": global_}, # type: ignore ) - def _parse_oldstyle_limit(self): + def _parse_oldstyle_limit( + self, + ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]: limit = None offset = None if self._match_text_seq("LIMIT"): @@ -355,23 +361,20 @@ class MySQL(Dialect): elif len(parts) == 2: limit = parts[1] offset = parts[0] + return offset, limit - def _parse_set_item_charset(self, kind): + def _parse_set_item_charset(self, kind: str) -> exp.Expression: this = self._parse_string() or self._parse_id_var() + return self.expression(exp.SetItem, this=this, kind=kind) - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) - - def _parse_set_item_names(self): + def _parse_set_item_names(self) -> exp.Expression: charset = self._parse_string() or self._parse_id_var() if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None + return self.expression( exp.SetItem, this=charset, @@ -386,7 +389,7 @@ class MySQL(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.DateAdd: _date_add_sql("ADD"), @@ -403,6 +406,7 @@ class MySQL(Dialect): exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -422,7 +426,7 @@ class MySQL(Dialect): TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index c8af1c6..7722753 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -8,7 +8,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _parse_xml_table(self) -> exp.XMLTable: +def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: this = self._parse_string() passing = None @@ -66,7 +66,7 @@ class Oracle(Dialect): WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), } @@ -107,7 +107,7 @@ class Oracle(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -122,7 +122,7 @@ class Oracle(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), @@ -143,7 +143,7 @@ class Oracle(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + 
**generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 2132778..ab61880 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, rename_func, @@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: from sqlglot.optimizer.simplify import simplify this = self.sql(expression, "this") @@ -51,7 +52,7 @@ def _date_add_sql(kind): return func -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) @@ -77,7 +78,7 @@ def _date_diff_sql(self, expression): return f"CAST({unit} AS BIGINT)" -def _substring_sql(self, expression): +def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: this = self.sql(expression, "this") start = self.sql(expression, "start") length = self.sql(expression, "length") @@ -88,7 +89,7 @@ def _substring_sql(self, expression): return f"SUBSTRING({this}{from_part}{for_part})" -def _string_agg_sql(self, expression): +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -102,13 +103,13 @@ def _string_agg_sql(self, expression): return f"STRING_AGG({self.format_args(this, separator)}{order})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) -def _auto_increment_to_serial(expression): +def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: auto = expression.find(exp.AutoIncrementColumnConstraint) if auto: @@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression): return expression -def _serial_to_generated(expression): +def _serial_to_generated(expression: exp.Expression) -> exp.Expression: kind = expression.args["kind"] if kind.this == exp.DataType.Type.SERIAL: @@ -144,6 +145,7 @@ def _serial_to_generated(expression): constraints = expression.args["constraints"] generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: constraints.insert(0, notnull) if generated not in constraints: @@ -152,7 +154,7 @@ def _serial_to_generated(expression): return expression -def _generate_series(args): +def _generate_series(args: t.List) -> exp.Expression: # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day step = seq_get(args, 2) @@ -168,11 +170,12 @@ def _generate_series(args): return exp.GenerateSeries.from_arg_list(args) -def _to_timestamp(args): +def _to_timestamp(args: t.List) -> exp.Expression: # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: # 
https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) + # https://www.postgresql.org/docs/current/functions-formatting.html return format_time_lambda(exp.StrToTime, "postgres")(args) @@ -255,7 +258,7 @@ class Postgres(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -271,7 +274,7 @@ class Postgres(Dialect): } BITWISE = { - **parser.Parser.BITWISE, # type: ignore + **parser.Parser.BITWISE, TokenType.HASH: exp.BitwiseXor, } @@ -280,7 +283,7 @@ class Postgres(Dialect): } RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), @@ -303,14 +306,14 @@ class Postgres(Dialect): return self.expression(exp.Extract, this=part, expression=value) class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False + SINGLE_STRING_INTERVAL = True LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -320,14 +323,9 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.BitwiseXor: lambda self, e: self.binary(e, "#"), - exp.ColumnDef: transforms.preprocess( - [ - _auto_increment_to_serial, - _serial_to_generated, - ], - ), + exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), @@ -348,6 +346,7 @@ class Postgres(Dialect): exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]), + exp.Pivot: no_pivot_sql, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -369,7 +368,7 @@ class Postgres(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6133a27..52a04a4 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, no_ilike_sql, + no_pivot_sql, no_safe_divide_sql, rename_func, struct_extract_sql, @@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: - start = expression.args["start"] - end = expression.args["end"] - step = expression.args.get("step") - - target_type = None - - if isinstance(start, exp.Cast): - target_type = start.to - elif isinstance(end, exp.Cast): - target_type = end.to - - if 
target_type and target_type.this == exp.DataType.Type.TIMESTAMP: - to = target_type.copy() - - if target_type is start.to: - end = exp.Cast(this=end, to=to) - else: - start = exp.Cast(this=start, to=to) - - sql = self.func("SEQUENCE", start, end, step) - if isinstance(expression.parent, exp.Table): - sql = f"UNNEST({sql})" - - return sql - - def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args: t.Sequence) -> exp.Expression: +def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression: return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args: t.Sequence) -> exp.Expression: +def _from_unixtime(args: t.List) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) +def _unnest_sequence(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Table): + if isinstance(expression.this, exp.GenerateSeries): + unnest = exp.Unnest(expressions=[expression.this]) + + if expression.alias: + return exp.alias_( + unnest, + alias="_u", + table=[expression.alias], + copy=False, + ) + return unnest + return expression + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" - time_format = MySQL.time_format # type: ignore - time_mapping = MySQL.time_mapping # type: ignore + time_format = MySQL.time_format + time_mapping = MySQL.time_mapping class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "START": TokenType.BEGIN, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "ROW": TokenType.STRUCT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "CARDINALITY": exp.ArraySize.from_arg_list, @@ -252,13 +243,13 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -268,8 +259,9 @@ class Presto(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.ApproxDistinct: _approx_distinct_sql, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), @@ -293,7 +285,7 @@ class Presto(Dialect): exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, - exp.GenerateSeries: _sequence_sql, + exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, @@ -301,10 +293,10 @@ class 
Presto(Dialect): exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.LogicalOr: rename_func("BOOL_OR"), + exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -320,8 +312,7 @@ class Presto(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", - exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, @@ -336,6 +327,7 @@ class Presto(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), + exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), exp.WithinGroup: transforms.preprocess( [transforms.remove_within_group_for_percentiles] ), @@ -351,3 +343,25 @@ class Presto(Dialect): modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" + + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + start = expression.args["start"] + end = expression.args["end"] + step = expression.args.get("step") + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + else: + target_type = None + + if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP): + to = target_type.copy() + + if target_type is start.to: + end = exp.Cast(this=end, to=to) + else: + start = exp.Cast(this=start, to=to) + + return self.func("SEQUENCE", start, end, step) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 1b7cf31..55e393a 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -8,21 +8,21 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _json_sql(self, e) -> str: - return f'{self.sql(e, "this")}."{e.expression.name}"' +def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: + return f'{self.sql(expression, "this")}."{expression.expression.name}"' class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { - **Postgres.time_mapping, # type: ignore + **Postgres.time_mapping, "MON": "%b", "HH": "%H", } class Parser(Postgres.Parser): FUNCTIONS = { - **Postgres.Parser.FUNCTIONS, # type: ignore + **Postgres.Parser.FUNCTIONS, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -45,7 +45,7 @@ class Redshift(Postgres): isinstance(this, exp.DataType) and this.this == exp.DataType.Type.VARCHAR and this.expressions - and this.expressions[0] == exp.column("MAX") + and this.expressions[0].this == exp.column("MAX") ): this.set("expressions", [exp.Var(this="MAX")]) @@ -57,9 +57,7 @@ class Redshift(Postgres): STRING_ESCAPES = ["\\"] KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, # type: ignore - "GEOMETRY": TokenType.GEOMETRY, - "GEOGRAPHY": TokenType.GEOGRAPHY, 
+            **Postgres.Tokenizer.KEYWORDS,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "SYSDATE": TokenType.CURRENT_TIMESTAMP,
@@ -76,22 +74,22 @@ class Redshift(Postgres):
 
     class Generator(Postgres.Generator):
         LOCKING_READS_SUPPORTED = False
-        SINGLE_STRING_INTERVAL = True
+        RENAME_TABLE_WITH_DB = False
 
         TYPE_MAPPING = {
-            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
+            **Postgres.Generator.TYPE_MAPPING,
            exp.DataType.Type.BINARY: "VARBYTE",
             exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
         }
 
         PROPERTIES_LOCATION = {
-            **Postgres.Generator.PROPERTIES_LOCATION,  # type: ignore
+            **Postgres.Generator.PROPERTIES_LOCATION,
             exp.LikeProperty: exp.Properties.Location.POST_WITH,
         }
 
         TRANSFORMS = {
-            **Postgres.Generator.TRANSFORMS,  # type: ignore
+            **Postgres.Generator.TRANSFORMS,
             exp.CurrentTimestamp: lambda self, e: "SYSDATE",
             exp.DateAdd: lambda self, e: self.func(
                 "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this
@@ -107,10 +105,13 @@ class Redshift(Postgres):
             exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
         }
 
+        # Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
+        TRANSFORMS.pop(exp.Pivot)
+
         # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres)
         TRANSFORMS.pop(exp.Pow)
 
-        RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"}
+        RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"}
 
         def values_sql(self, expression: exp.Values) -> str:
             """
@@ -120,37 +121,36 @@ class Redshift(Postgres):
             evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
             very slow.
             """
-            if not isinstance(expression.unnest().parent, exp.From):
+
+            # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example
+            if not expression.find_ancestor(exp.From, exp.Join):
                 return super().values_sql(expression)
-            rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
+            column_names = expression.alias and expression.args["alias"].columns
+
             selects = []
+            rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+
             for i, row in enumerate(rows):
-                if i == 0 and expression.alias:
+                if i == 0 and column_names:
                     row = [
                         exp.alias_(value, column_name)
-                        for value, column_name in zip(row, expression.args["alias"].args["columns"])
+                        for value, column_name in zip(row, column_names)
                     ]
+
                 selects.append(exp.Select(expressions=row))
-            subquery_expression = selects[0]
+
+            subquery_expression: exp.Select | exp.Union = selects[0]
+
             if len(selects) > 1:
                 for select in selects[1:]:
                     subquery_expression = exp.union(subquery_expression, select, distinct=False)
+
             return self.subquery_sql(subquery_expression.subquery(expression.alias))
 
         def with_properties(self, properties: exp.Properties) -> str:
             """Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
             return self.properties(properties, prefix=" ", suffix="")
 
-        def renametable_sql(self, expression: exp.RenameTable) -> str:
-            """Redshift only supports defining the table name itself (not the db) when renaming tables"""
-            expression = expression.copy()
-            target_table = expression.this
-            for arg in target_table.args:
-                if arg != "this":
-                    target_table.set(arg, None)
-            this = self.sql(expression, "this")
-            return f"RENAME TO {this}"
-
         def datatype_sql(self, expression: exp.DataType) -> str:
             """
             Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
@@ -162,6 +162,8 @@ class Redshift(Postgres):
             expression = expression.copy()
             expression.set("this", exp.DataType.Type.VARCHAR)
             precision = expression.args.get("expressions")
+
             if not precision:
                 expression.append("expressions", exp.Var(this="MAX"))
+
             return super().datatype_sql(expression)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 70dcaa9..756e8e9 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import (
     var_map_sql,
 )
 from sqlglot.expressions import Literal
-from sqlglot.helper import flatten, seq_get
+from sqlglot.helper import seq_get
 from sqlglot.parser import binary_range_parser
 from sqlglot.tokens import TokenType
 
@@ -30,7 +30,7 @@ def _check_int(s: str) -> bool:
 
 
 # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
-def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]:
+def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]:
     if len(args) == 2:
         first_arg, second_arg = args
         if second_arg.is_string:
@@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix
 
         return exp.UnixToTime(this=first_arg, scale=timescale)
 
+    from sqlglot.optimizer.simplify import simplify_literals
+
+    # The first argument might be an expression like 40 * 365 * 86400, so we try to
+    # reduce it using `simplify_literals` first and then check if it's a Literal.
     first_arg = seq_get(args, 0)
-    if not isinstance(first_arg, Literal):
+    if not isinstance(simplify_literals(first_arg, root=True), Literal):
         # case: <string_expr> or <string_expr, format>
         return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
 
@@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.Unix
     return exp.UnixToTime.from_arg_list(args)
 
 
+def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
+    expression = parser.parse_var_map(args)
+
+    if isinstance(expression, exp.StarMap):
+        return expression
+
+    return exp.Struct(
+        expressions=[
+            t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
+        ]
+    )
+
+
 def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
     scale = expression.args.get("scale")
     timestamp = self.sql(expression, "this")
@@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]:
 
 
 # https://docs.snowflake.com/en/sql-reference/functions/div0
-def _div0_to_if(args: t.Sequence) -> exp.Expression:
+def _div0_to_if(args: t.List) -> exp.Expression:
     cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
     true = exp.Literal.number(0)
     false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1))
@@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression:
 
 
 # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression:
+def _zeroifnull_to_if(args: t.List) -> exp.Expression:
     cond = exp.Is(this=seq_get(args, 0), expression=exp.Null())
     return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
 
 
-# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
-def _nullifzero_to_if(args: t.Sequence) -> exp.Expression:
+def _nullifzero_to_if(args: t.List) -> exp.Expression:
     cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0))
     return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0))
 
 
@@ -143,6 +160,12 @@ def _datatype_sql(self:
generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _parse_convert_timezone(args: t.List) -> exp.Expression: + if len(args) == 3: + return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) + return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -177,17 +200,14 @@ class Snowflake(Dialect): } class Parser(parser.Parser): - QUOTED_PIVOT_COLUMNS = True + IDENTIFY_PIVOT_STRINGS = True FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, - "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone( - this=seq_get(args, 1), - zone=seq_get(args, 0), - ), + "CONVERT_TIMEZONE": _parse_convert_timezone, "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), @@ -202,7 +222,7 @@ class Snowflake(Dialect): "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, - "OBJECT_CONSTRUCT": parser.parse_var_map, + "OBJECT_CONSTRUCT": _parse_object_construct, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_ARRAY": exp.Array.from_arg_list, @@ -224,7 +244,7 @@ class Snowflake(Dialect): } COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, # type: ignore + **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( exp.Bracket, this=this, @@ -232,14 +252,16 @@ class Snowflake(Dialect): ), } + TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME} + RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, # type: ignore + **parser.Parser.ALTER_PARSERS, "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), "SET": lambda self: self._parse_alter_table_set_tag(), } @@ -256,17 +278,20 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "CHAR VARYING": TokenType.VARCHAR, + "CHARACTER VARYING": TokenType.VARCHAR, "EXCLUDE": TokenType.EXCEPT, "ILIKE ANY": TokenType.ILIKE_ANY, "LIKE ANY": TokenType.LIKE_ANY, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, + "MINUS": TokenType.EXCEPT, + "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, - "MINUS": TokenType.EXCEPT, "SAMPLE": TokenType.TABLE_SAMPLE, } @@ -285,7 +310,7 @@ class Snowflake(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), @@ -299,6 +324,7 @@ class Snowflake(Dialect): exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Extract: rename_func("DATE_PART"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -312,6 +338,10 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: 
f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Struct: lambda self, e: self.func( + "OBJECT_CONSTRUCT", + *(arg for expression in e.expressions for arg in expression.flatten()), + ), exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( @@ -326,7 +356,7 @@ class Snowflake(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -336,7 +366,7 @@ class Snowflake(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.SetProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -351,53 +381,10 @@ class Snowflake(Dialect): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) - def values_sql(self, expression: exp.Values) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. - - We also want to make sure that after we find matches where we need to unquote a column that we prevent users - from adding quotes to the column by using the `identify` argument when generating the SQL. - """ - alias = expression.args.get("alias") - if alias and alias.args.get("columns"): - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) - and isinstance(node.parent, exp.TableAlias) - and node.arg_key == "columns" - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) - return super().values_sql(expression) - def settag_sql(self, expression: exp.SetTag) -> str: action = "UNSET" if expression.args.get("unset") else "SET" return f"{action} TAG {self.expressions(expression)}" - def select_sql(self, expression: exp.Select) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also - that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need - to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when - generating the SQL. - - Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the - expression. This might not be true in a case where the same column name can be sourced from another table that can - properly quote but should be true in most cases. 
- """ - values_identifiers = set( - flatten( - (v.args.get("alias") or exp.Alias()).args.get("columns", []) - for v in expression.find_all(exp.Values) - ) - ) - if values_identifiers: - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) and node in values_identifiers - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) - return super().select_sql(expression) - def describe_sql(self, expression: exp.Describe) -> str: # Default to table if kind is unknown kind_value = expression.args.get("kind") or "TABLE" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 939f2fd..b7d1641 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get -def _parse_datediff(args: t.Sequence) -> exp.Expression: +def _parse_datediff(args: t.List) -> exp.Expression: """ Although Spark docs don't mention the "unit" argument, Spark3 added support for - it at some point. Databricks also supports this variation (see below). + it at some point. Databricks also supports this variant (see below). For example, in spark-sql (v3.3.1): - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4 @@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression: class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { - **Spark2.Parser.FUNCTIONS, # type: ignore + **Spark2.Parser.FUNCTIONS, "DATEDIFF": _parse_datediff, } diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 584671f..912b86b 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -3,7 +3,12 @@ from __future__ import annotations import typing as t from sqlglot import exp, parser, transforms -from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql +from sqlglot.dialects.dialect import ( + create_with_partitions_sql, + pivot_column_names, + rename_func, + trim_sql, +) from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get @@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: return f"MAP_FROM_ARRAYS({keys}, {values})" -def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]: +def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) @@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: raise ValueError("Improper scale for timestamp") +def _unalias_pivot(expression: exp.Expression) -> exp.Expression: + """ + Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a + pivoted source in a subquery with the same alias to preserve the query's semantics. 
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv")
+        >>> print(_unalias_pivot(expr).sql(dialect="spark"))
+        SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv
+    """
+    if isinstance(expression, exp.From) and expression.this.args.get("pivots"):
+        pivot = expression.this.args["pivots"][0]
+        if pivot.alias:
+            alias = pivot.args["alias"].pop()
+            return exp.From(
+                this=expression.this.replace(
+                    exp.select("*").from_(expression.this.copy()).subquery(alias=alias)
+                )
+            )
+
+    return expression
+
+
+def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
+    """
+    Spark doesn't allow the column referenced in the PIVOT's field to be qualified,
+    so we need to unqualify it.
+
+    Example:
+        >>> from sqlglot import parse_one
+        >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))")
+        >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark"))
+        SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q2'))
+    """
+    if isinstance(expression, exp.Pivot):
+        expression.args["field"].transform(
+            lambda node: exp.column(node.output_name, quoted=node.this.quoted)
+            if isinstance(node, exp.Column)
+            else node,
+            copy=False,
+        )
+
+    return expression
+
+
 class Spark2(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,  # type: ignore
+            **Hive.Parser.FUNCTIONS,
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "LEFT": lambda args: exp.Substring(
@@ -110,7 +161,7 @@ class Spark2(Hive):
         }
 
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,  # type: ignore
+            **parser.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -131,43 +182,21 @@ class Spark2(Hive):
                 kind="COLUMNS",
             )
 
-        def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]:
-            # Spark doesn't add a suffix to the pivot columns when there's a single aggregation
-            if len(pivot_columns) == 1:
+        def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]:
+            if len(aggregations) == 1:
                 return [""]
-
-            names = []
-            for agg in pivot_columns:
-                if isinstance(agg, exp.Alias):
-                    names.append(agg.alias)
-                else:
-                    """
-                    This case corresponds to aggregations without aliases being used as suffixes
-                    (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
-                    be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
-                    Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
-
-                    Moreover, function names are lowercased in order to mimic Spark's naming scheme.
- """ - agg_all_unquoted = agg.transform( - lambda node: exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node - ) - names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) - - return names + return pivot_column_names(aggregations, dialect="spark") class Generator(Hive.Generator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, # type: ignore + **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", exp.DataType.Type.SMALLINT: "SHORT", exp.DataType.Type.BIGINT: "LONG", } PROPERTIES_LOCATION = { - **Hive.Generator.PROPERTIES_LOCATION, # type: ignore + **Hive.Generator.PROPERTIES_LOCATION, exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, @@ -175,7 +204,7 @@ class Spark2(Hive): } TRANSFORMS = { - **Hive.Generator.TRANSFORMS, # type: ignore + **Hive.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", @@ -188,11 +217,12 @@ class Spark2(Hive): exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", + exp.From: transforms.preprocess([_unalias_pivot]), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, - exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]), + exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]), exp.Reduce: rename_func("AGGREGATE"), exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index f2efe32..56e7773 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, count_if_to_sum, no_ilike_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, rename_func, @@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _date_add_sql(self, expression): +def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: modifier = expression.expression modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") @@ -67,7 +68,7 @@ class SQLite(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } @@ -76,7 +77,7 @@ class SQLite(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -98,7 +99,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CountIf: count_if_to_sum, exp.Create: transforms.preprocess([_transform_create]), exp.CurrentDate: lambda *_: "CURRENT_DATE", @@ -114,6 +115,7 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: 
rename_func("MIN"), + exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_qualify] ), @@ -163,12 +165,15 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" # https://www.sqlite.org/lang_aggfunc.html#group_concat - def groupconcat_sql(self, expression): + def groupconcat_sql(self, expression: exp.GroupConcat) -> str: this = expression.this distinct = expression.find(exp.Distinct) + if distinct: this = distinct.expressions[0] - distinct = "DISTINCT " + distinct_sql = "DISTINCT " + else: + distinct_sql = "" if isinstance(expression.this, exp.Order): self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") @@ -176,7 +181,7 @@ class SQLite(Dialect): this = expression.this.this separator = expression.args.get("separator") - return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})" def least_sql(self, expression: exp.Least) -> str: if len(expression.expressions) > 1: diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 895588a..0390113 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -11,25 +11,24 @@ from sqlglot.helper import seq_get class StarRocks(MySQL): - class Parser(MySQL.Parser): # type: ignore + class Parser(MySQL.Parser): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), } - class Generator(MySQL.Generator): # type: ignore + class Generator(MySQL.Generator): TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, # type: ignore + **MySQL.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, # type: ignore + **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, @@ -43,4 +42,5 @@ class StarRocks(MySQL): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), } + TRANSFORMS.pop(exp.DateTrunc) diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 51e685b..d5fba17 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms from sqlglot.dialects.dialect import Dialect -def _if_sql(self, expression): - return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END" - - -def _coalesce_sql(self, expression): - return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" - - -def _count_sql(self, expression): - this = expression.this - if isinstance(this, exp.Distinct): - return f"COUNTD({self.expressions(this, flat=True)})" - return f"COUNT({self.sql(expression, 'this')})" - - class Tableau(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore - exp.If: _if_sql, - exp.Coalesce: _coalesce_sql, - exp.Count: _count_sql, + **generator.Generator.TRANSFORMS, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + 
**generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def if_sql(self, expression: exp.If) -> str: + this = self.sql(expression, "this") + true = self.sql(expression, "true") + false = self.sql(expression, "false") + return f"IF {this} THEN {true} ELSE {false} END" + + def coalesce_sql(self, expression: exp.Coalesce) -> str: + return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" + + def count_sql(self, expression: exp.Count) -> str: + this = expression.this + if isinstance(this, exp.Distinct): + return f"COUNTD({self.expressions(this, flat=True)})" + return f"COUNT({self.sql(expression, 'this')})" + class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index a79eaeb..9b39178 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -75,12 +75,12 @@ class Teradata(Dialect): FUNC_TOKENS.remove(TokenType.REPLACE) STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.REPLACE: lambda self: self._parse_create(), } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "RANGE_N": lambda self: self._parse_rangen(), "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), } @@ -106,7 +106,7 @@ class Teradata(Dialect): exp.Update, **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(), + "from": self._parse_from(modifiers=True), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "where": self._parse_where(), @@ -135,13 +135,15 @@ class Teradata(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, + **generator.Generator.PROPERTIES_LOCATION, + exp.OnCommitProperty: exp.Properties.Location.POST_INDEX, + exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION, + exp.StabilityProperty: exp.Properties.Location.POST_CREATE, } TRANSFORMS = { diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index c7b34fe..af0f78d 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): class Generator(Presto.Generator): TRANSFORMS = { - **Presto.Generator.TRANSFORMS, # type: ignore + **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 03de99c..f6ad888 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -16,6 +16,9 @@ from sqlglot.helper import seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + FULL_FORMAT_TIME_MAPPING = { "weekday": "%A", "dw": "%A", @@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} -def _format_time_lambda(exp_class, full_format_mapping=None, 
default=None): - def _format_time(args): +def _format_time_lambda( + exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None +) -> t.Callable[[t.List], E]: + def _format_time(args: t.List) -> E: + assert len(args) == 2 + return exp_class( - this=seq_get(args, 1), + this=args[1], format=exp.Literal.string( format_time( - seq_get(args, 0).name or (TSQL.time_format if default is True else default), + args[0].name, {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, @@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def _parse_format(args): - fmt = seq_get(args, 1) - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) +def _parse_format(args: t.List) -> exp.Expression: + assert len(args) == 2 + + fmt = args[1] + number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + if number_fmt: - return exp.NumberToStr(this=seq_get(args, 0), format=fmt) + return exp.NumberToStr(this=args[0], format=fmt) + return exp.TimeToStr( - this=seq_get(args, 0), + this=args[0], format=exp.Literal.string( format_time(fmt.name, TSQL.format_time_mapping) if len(fmt.name) == 1 @@ -82,7 +93,7 @@ def _parse_format(args): ) -def _parse_eomonth(args): +def _parse_eomonth(args: t.List) -> exp.Expression: date = seq_get(args, 0) month_lag = seq_get(args, 1) unit = DATE_DELTA_INTERVAL.get("month") @@ -96,7 +107,7 @@ def _parse_eomonth(args): return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) -def _parse_hashbytes(args): +def _parse_hashbytes(args: t.List) -> exp.Expression: kind, data = args kind = kind.name.upper() if kind.is_string else "" @@ -110,40 +121,47 @@ def _parse_hashbytes(args): return exp.SHA2(this=data, length=exp.Literal.number(256)) if kind == "SHA2_512": return exp.SHA2(this=data, length=exp.Literal.number(512)) + return exp.func("HASHBYTES", *args) -def generate_date_delta_with_unit_sql(self, e): - func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" - return self.func(func, e.text("unit"), e.expression, e.this) +def generate_date_delta_with_unit_sql( + self: generator.Generator, expression: exp.DateAdd | exp.DateDiff +) -> str: + func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" + return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self, e): +def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( - e.args["format"] - if isinstance(e, exp.NumberToStr) - else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) + expression.args["format"] + if isinstance(expression, exp.NumberToStr) + else exp.Literal.string( + format_time( + expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping) + ) + ) ) - return self.func("FORMAT", e.this, fmt) + return self.func("FORMAT", expression.this, fmt) -def _string_agg_sql(self, e): - e = e.copy() +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: + expression = expression.copy() - this = e.this - distinct = e.find(exp.Distinct) + this = expression.this + distinct = expression.find(exp.Distinct) if distinct: # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") this = distinct.pop().expressions[0] order = "" - if isinstance(e.this, exp.Order): - if e.this.this: 
- this = e.this.this.pop() - order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space + if isinstance(expression.this, exp.Order): + if expression.this.this: + this = expression.this.this.pop() + order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space - separator = e.args.get("separator") or exp.Literal.string(",") + separator = expression.args.get("separator") or exp.Literal.string(",") return f"STRING_AGG({self.format_args(this, separator)}){order}" @@ -292,7 +310,7 @@ class TSQL(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -332,13 +350,13 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, - *parser.Parser.TYPE_TOKENS, # type: ignore + *parser.Parser.TYPE_TOKENS, } STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.END: lambda self: self._parse_command(), } @@ -377,7 +395,7 @@ class TSQL(Dialect): return system_time - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: table = super()._parse_table_parts(schema=schema) table.set("system_time", self._parse_system_time()) return table @@ -450,7 +468,7 @@ class TSQL(Dialect): LOCKING_READS_SUPPORTED = True TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", @@ -458,7 +476,7 @@ class TSQL(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -480,7 +498,7 @@ class TSQL(Dialect): TRANSFORMS.pop(exp.ReturnsProperty) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } -- cgit v1.2.3
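The Presto generateseries_sql change in this patch can be exercised end to end through sqlglot's public transpile API. The sketch below is illustrative only: the SQL in the comments is the approximate expected shape, not pinned output, and may differ across sqlglot versions.

import sqlglot

# GENERATE_SERIES with a single TIMESTAMP-cast endpoint: the Presto generator
# copies that cast onto the other endpoint so SEQUENCE sees matching types.
sql = "SELECT GENERATE_SERIES(CAST('2020-01-01' AS TIMESTAMP), '2020-01-05')"
print(sqlglot.transpile(sql, write="presto")[0])
# Roughly: SELECT SEQUENCE(CAST('2020-01-01' AS TIMESTAMP), CAST('2020-01-05' AS TIMESTAMP))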
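The Redshift values_sql rewrite is similar in spirit: a VALUES list used as a relation becomes a chain of SELECT ... UNION ALL subqueries, while VALUES inside an INSERT is left alone. A hedged sketch of the observable behaviour; the output shape is approximate:

import sqlglot

# A VALUES source under FROM is expanded; the first row carries the column
# aliases and the derived-table alias is preserved on the wrapping subquery.
sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two')) AS t(id, name)"
print(sqlglot.transpile(sql, read="presto", write="redshift")[0])
# Roughly: SELECT * FROM (SELECT 1 AS id, 'one' AS name UNION ALL SELECT 2, 'two') AS t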
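Likewise, _parse_object_construct and the new exp.Struct transform in the Snowflake generator are designed to round-trip: key/value arguments are paired up as key = value equalities on the way in and flattened back into a plain argument list on the way out. An illustrative sketch:

import sqlglot

# OBJECT_CONSTRUCT('a', 1, 'b', 2) parses into an exp.Struct of equalities;
# the Struct transform flattens each equality back into two function arguments.
sql = "SELECT OBJECT_CONSTRUCT('a', 1, 'b', 2)"
print(sqlglot.transpile(sql, read="snowflake", write="snowflake")[0])
# Roughly: SELECT OBJECT_CONSTRUCT('a', 1, 'b', 2)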
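Finally, _parse_convert_timezone maps only the two-argument CONVERT_TIMEZONE form onto exp.AtTimeZone, leaving the three-argument form as an opaque anonymous function, which lets it line up with the Spark2 AtTimeZone transform earlier in the patch. Another hedged sketch; the expected output is approximate and assumes the Spark mapping shown above:

import sqlglot

# Two-argument CONVERT_TIMEZONE(tz, ts) becomes exp.AtTimeZone and can be
# retargeted; Spark renders AtTimeZone as FROM_UTC_TIMESTAMP(ts, tz).
sql = "SELECT CONVERT_TIMEZONE('America/New_York', created_at) FROM t"
print(sqlglot.transpile(sql, read="snowflake", write="spark")[0])
# Roughly: SELECT FROM_UTC_TIMESTAMP(created_at, 'America/New_York') FROM t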