diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:46 +0000 |
commit | 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch) | |
tree | c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot | |
parent | Releasing debian version 12.2.0-1. (diff) | |
download | sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip |
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
61 files changed, 3221 insertions, 2172 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index f7440e0..8fb623a 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -6,6 +6,7 @@ from __future__ import annotations +import logging import typing as t from sqlglot import expressions as exp @@ -45,12 +46,19 @@ from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema from sqlglot.tokens import Tokenizer as Tokenizer, TokenType as TokenType if t.TYPE_CHECKING: + from sqlglot._typing import E from sqlglot.dialects.dialect import DialectType as DialectType - T = t.TypeVar("T", bound=Expression) +logger = logging.getLogger("sqlglot") -__version__ = "12.2.0" +try: + from sqlglot._version import __version__, __version_tuple__ +except ImportError: + logger.error( + "Unable to set __version__, run `pip install -e .` or `python setup.py develop` first." + ) + pretty = False """Whether to format generated SQL by default.""" @@ -79,9 +87,9 @@ def parse(sql: str, read: DialectType = None, **opts) -> t.List[t.Optional[Expre def parse_one( sql: str, read: None = None, - into: t.Type[T] = ..., + into: t.Type[E] = ..., **opts, -) -> T: +) -> E: ... @@ -89,9 +97,9 @@ def parse_one( def parse_one( sql: str, read: DialectType, - into: t.Type[T], + into: t.Type[E], **opts, -) -> T: +) -> E: ... diff --git a/sqlglot/_typing.py b/sqlglot/_typing.py new file mode 100644 index 0000000..2acbbf7 --- /dev/null +++ b/sqlglot/_typing.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +import typing as t + +import sqlglot + +E = t.TypeVar("E", bound="sqlglot.exp.Expression") +T = t.TypeVar("T") diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.py index 1682ec1..fb46026 100644 --- a/sqlglot/dataframe/sql/_typing.pyi +++ b/sqlglot/dataframe/sql/_typing.py @@ -11,6 +11,8 @@ if t.TYPE_CHECKING: ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] ColumnOrName = t.Union[Column, str] -ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +ColumnOrLiteral = t.Union[ + Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime +] SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index f3a6f6f..3fc9232 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -127,7 +127,7 @@ class DataFrame: sequence_id: t.Optional[str] = None, **kwargs, ) -> t.Tuple[exp.CTE, str]: - name = self.spark._random_name + name = self._create_hash_from_expression(expression) expression_to_cte = expression.copy() expression_to_cte.set("with", None) cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] @@ -263,7 +263,7 @@ class DataFrame: return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] @classmethod - def _create_hash_from_expression(cls, expression: exp.Select): + def _create_hash_from_expression(cls, expression: exp.Expression) -> str: value = expression.sql(dialect="spark").encode("utf-8") return f"t{zlib.crc32(value)}"[:6] @@ -299,7 +299,7 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression, identify="always") + select_expression = t.cast(exp.Select, optimize_func(select_expression)) select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: @@ -570,9 +570,9 @@ class DataFrame: r_expressions.append(l_column) r_columns_unused.remove(l_column) else: - r_expressions.append(exp.alias_(exp.Null(), l_column)) + r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) for r_column in r_columns_unused: - l_expressions.append(exp.alias_(exp.Null(), r_column)) + l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) r_expressions.append(r_column) r_df = ( other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) @@ -761,7 +761,7 @@ class DataFrame: raise ValueError("Tried to rename a column that doesn't exist") for existing_column in existing_columns: if isinstance(existing_column, exp.Column): - existing_column.replace(exp.alias_(existing_column.copy(), new)) + existing_column.replace(exp.alias_(existing_column, new)) else: existing_column.set("alias", exp.to_identifier(new)) return self.copy(expression=expression) diff --git a/sqlglot/dataframe/sql/operations.py b/sqlglot/dataframe/sql/operations.py index d51335c..e4c106b 100644 --- a/sqlglot/dataframe/sql/operations.py +++ b/sqlglot/dataframe/sql/operations.py @@ -41,7 +41,7 @@ def operation(op: Operation): self.last_op = Operation.NO_OP last_op = self.last_op new_op = op if op != Operation.NO_OP else last_op - if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT): + if new_op < last_op or (last_op == new_op == Operation.SELECT): self = self._convert_leaf_to_cte() df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) df.last_op = new_op # type: ignore diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py index af589b0..b883359 100644 --- a/sqlglot/dataframe/sql/session.py +++ b/sqlglot/dataframe/sql/session.py @@ -87,15 +87,13 @@ class SparkSession: select_kwargs = { "expressions": sel_columns, "from": exp.From( - expressions=[ - exp.Values( - expressions=data_expressions, - alias=exp.TableAlias( - this=exp.to_identifier(self._auto_incrementing_name), - columns=[exp.to_identifier(col_name) for col_name in column_mapping], - ), + this=exp.Values( + expressions=data_expressions, + alias=exp.TableAlias( + this=exp.to_identifier(self._auto_incrementing_name), + columns=[exp.to_identifier(col_name) for col_name in column_mapping], ), - ], + ), ), } @@ -128,10 +126,6 @@ class SparkSession: return name @property - def _random_name(self) -> str: - return "r" + uuid.uuid4().hex - - @property def _random_branch_id(self) -> str: id = self._random_id self.known_branch_ids.add(id) @@ -145,7 +139,7 @@ class SparkSession: @property def _random_id(self) -> str: - id = self._random_name + id = "r" + uuid.uuid4().hex self.known_ids.add(id) return id diff --git a/sqlglot/dataframe/sql/util.py b/sqlglot/dataframe/sql/util.py index 575d18a..4b9fbb1 100644 --- a/sqlglot/dataframe/sql/util.py +++ b/sqlglot/dataframe/sql/util.py @@ -27,6 +27,6 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T if not expression.args.get("joins"): return [] - left_table = expression.args["from"].args["expressions"][0] + left_table = expression.args["from"].this other_tables = [join.this for join in expression.args["joins"]] return [left_table] + other_tables diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 9705b35..1a58337 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,5 +1,3 @@ -"""Supports BigQuery Standard SQL.""" - from __future__ import annotations import re @@ -18,11 +16,9 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, ts_or_ds_to_date_sql, ) -from sqlglot.helper import seq_get +from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType -E = t.TypeVar("E", bound=exp.Expression) - def _date_add_sql( data_type: str, kind: str @@ -96,19 +92,12 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: These are added by the optimizer's qualify_column step. """ if isinstance(expression, exp.Select): - unnests = { - unnest.alias - for unnest in expression.args.get("from", exp.From(expressions=[])).expressions - if isinstance(unnest, exp.Unnest) and unnest.alias - } - - if unnests: - expression = expression.copy() - - for select in expression.expressions: - for column in select.find_all(exp.Column): - if column.table in unnests: - column.set("table", None) + for unnest in expression.find_all(exp.Unnest): + if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: + for select in expression.selects: + for column in select.find_all(exp.Column): + if column.table == unnest.alias: + column.set("table", None) return expression @@ -127,16 +116,20 @@ class BigQuery(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = [ - (prefix + quote, quote) if prefix else quote - for quote in ["'", '"', '"""', "'''"] - for prefix in ["", "r", "R"] - ] + QUOTES = ["'", '"', '"""', "'''"] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] + HEX_STRINGS = [("0x", ""), ("0X", "")] - BYTE_STRINGS = [("b'", "'"), ("B'", "'")] + + BYTE_STRINGS = [ + (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B") + ] + + RAW_STRINGS = [ + (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R") + ] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -144,11 +137,11 @@ class BigQuery(Dialect): "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, + "BYTES": TokenType.BINARY, "DECLARE": TokenType.COMMAND, - "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, - "BYTES": TokenType.BINARY, + "RECORD": TokenType.STRUCT, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, } @@ -161,7 +154,7 @@ class BigQuery(Dialect): LOG_DEFAULTS_TO_LN = True FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_TRUNC": lambda args: exp.DateTrunc( unit=exp.Literal.string(str(seq_get(args, 1))), this=seq_get(args, 0), @@ -191,28 +184,28 @@ class BigQuery(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore + **parser.Parser.NO_PAREN_FUNCTIONS, TokenType.CURRENT_DATETIME: exp.CurrentDatetime, } NESTED_TYPE_TOKENS = { - *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore + *parser.Parser.NESTED_TYPE_TOKENS, TokenType.TABLE, } ID_VAR_TOKENS = { - *parser.Parser.ID_VAR_TOKENS, # type: ignore + *parser.Parser.ID_VAR_TOKENS, TokenType.VALUES, } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore + **parser.Parser.PROPERTY_PARSERS, "NOT DETERMINISTIC": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("VOLATILE") ), @@ -220,19 +213,50 @@ class BigQuery(Dialect): } CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, # type: ignore + **parser.Parser.CONSTRAINT_PARSERS, "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()), } + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: + this = super()._parse_table_part(schema=schema) + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names + if isinstance(this, exp.Identifier): + table_name = this.name + while self._match(TokenType.DASH, advance=False) and self._next: + self._advance(2) + table_name += f"-{self._prev.text}" + + this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) + + return this + + def _parse_table_parts(self, schema: bool = False) -> exp.Table: + table = super()._parse_table_parts(schema=schema) + if isinstance(table.this, exp.Identifier) and "." in table.name: + catalog, db, this, *rest = ( + t.cast(t.Optional[exp.Expression], exp.to_identifier(x)) + for x in split_num_words(table.name, ".", 3) + ) + + if rest and this: + this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest])) + + table = exp.Table(this=this, db=db, catalog=catalog) + + return table + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False JOIN_HINTS = False TABLE_HINTS = False LIMIT_FETCH = "LIMIT" + RENAME_TABLE_WITH_DB = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.AtTimeZone: lambda self, e: self.func( "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) @@ -259,6 +283,7 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, + exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -274,7 +299,7 @@ class BigQuery(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", exp.DataType.Type.BIGINT: "INT64", exp.DataType.Type.BINARY: "BYTES", @@ -297,7 +322,7 @@ class BigQuery(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 2a49066..c8a9525 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -3,11 +3,16 @@ from __future__ import annotations import typing as t from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql +from sqlglot.dialects.dialect import ( + Dialect, + inline_array_sql, + no_pivot_sql, + rename_func, + var_map_sql, +) from sqlglot.errors import ParseError -from sqlglot.helper import ensure_list, seq_get from sqlglot.parser import parse_var_map -from sqlglot.tokens import TokenType +from sqlglot.tokens import Token, TokenType def _lower_func(sql: str) -> str: @@ -28,65 +33,122 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ASOF": TokenType.ASOF, - "GLOBAL": TokenType.GLOBAL, - "DATETIME64": TokenType.DATETIME, + "ATTACH": TokenType.COMMAND, + "DATETIME64": TokenType.DATETIME64, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, - "INT8": TokenType.TINYINT, - "UINT8": TokenType.UTINYINT, + "GLOBAL": TokenType.GLOBAL, + "INT128": TokenType.INT128, "INT16": TokenType.SMALLINT, - "UINT16": TokenType.USMALLINT, + "INT256": TokenType.INT256, "INT32": TokenType.INT, - "UINT32": TokenType.UINT, "INT64": TokenType.BIGINT, - "UINT64": TokenType.UBIGINT, - "INT128": TokenType.INT128, + "INT8": TokenType.TINYINT, + "MAP": TokenType.MAP, + "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, - "INT256": TokenType.INT256, + "UINT16": TokenType.USMALLINT, "UINT256": TokenType.UINT256, - "TUPLE": TokenType.STRUCT, + "UINT32": TokenType.UINT, + "UINT64": TokenType.UBIGINT, + "UINT8": TokenType.UTINYINT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg( - this=seq_get(args, 0), - time=seq_get(args, 1), - decay=seq_get(params, 0), - ), - "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( - this=seq_get(args, 0), size=seq_get(params, 0) - ), - "HISTOGRAM": lambda params, args: exp.Histogram( - this=seq_get(args, 0), bins=seq_get(params, 0) - ), + **parser.Parser.FUNCTIONS, + "ANY": exp.AnyValue.from_arg_list, "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, - "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), - "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), - "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), + "UNIQ": exp.ApproxDistinct.from_arg_list, + } + + FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, + "QUANTILE": lambda self: self._parse_quantile(), } - FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("MATCH") + NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy() + NO_PAREN_FUNCTION_PARSERS.pop(TokenType.ANY) + RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) and self._parse_in(this, is_global=True), } - JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore + # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to + # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler. + COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy() + COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER) + + JOIN_KINDS = { + *parser.Parser.JOIN_KINDS, + TokenType.ANY, + TokenType.ASOF, + TokenType.ANTI, + TokenType.SEMI, + } - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - { + TokenType.ANY, + TokenType.ASOF, + TokenType.SEMI, + TokenType.ANTI, + TokenType.SETTINGS, + TokenType.FORMAT, + } LOG_DEFAULTS_TO_LN = True - def _parse_in( - self, this: t.Optional[exp.Expression], is_global: bool = False - ) -> exp.Expression: + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + "settings": lambda self: self._parse_csv(self._parse_conjunction) + if self._match(TokenType.SETTINGS) + else None, + "format": lambda self: self._parse_id_var() if self._match(TokenType.FORMAT) else None, + } + + def _parse_conjunction(self) -> t.Optional[exp.Expression]: + this = super()._parse_conjunction() + + if self._match(TokenType.PLACEHOLDER): + return self.expression( + exp.If, + this=this, + true=self._parse_conjunction(), + false=self._match(TokenType.COLON) and self._parse_conjunction(), + ) + + return this + + def _parse_placeholder(self) -> t.Optional[exp.Expression]: + """ + Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier} + https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters + """ + if not self._match(TokenType.L_BRACE): + return None + + this = self._parse_id_var() + self._match(TokenType.COLON) + kind = self._parse_types(check_func=False) or ( + self._match_text_seq("IDENTIFIER") and "Identifier" + ) + + if not kind: + self.raise_error("Expecting a placeholder type or 'Identifier' for tables") + elif not self._match(TokenType.R_BRACE): + self.raise_error("Expecting }") + + return self.expression(exp.Placeholder, this=this, kind=kind) + + def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In: this = super()._parse_in(this) this.set("is_global", is_global) return this @@ -120,81 +182,142 @@ class ClickHouse(Dialect): return self.expression(exp.CTE, this=statement, alias=statement and statement.this) + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: + is_global = self._match(TokenType.GLOBAL) and self._prev + kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev + if kind_pre: + kind = self._match_set(self.JOIN_KINDS) and self._prev + side = self._match_set(self.JOIN_SIDES) and self._prev + return is_global, side, kind + return ( + is_global, + self._match_set(self.JOIN_SIDES) and self._prev, + self._match_set(self.JOIN_KINDS) and self._prev, + ) + + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + join = super()._parse_join(skip_join_token) + + if join: + join.set("global", join.args.pop("natural", None)) + return join + + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False + ) -> t.Optional[exp.Expression]: + func = super()._parse_function(functions, anonymous) + + if isinstance(func, exp.Anonymous): + params = self._parse_func_params(func) + + if params: + return self.expression( + exp.ParameterizedAgg, + this=func.this, + expressions=func.expressions, + params=params, + ) + + return func + + def _parse_func_params( + self, this: t.Optional[exp.Func] = None + ) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + return self._parse_csv(self._parse_lambda) + if self._match(TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + self._match_r_paren(this) + return params + return None + + def _parse_quantile(self) -> exp.Quantile: + this = self._parse_lambda() + params = self._parse_func_params() + if params: + return self.expression(exp.Quantile, this=params[0], quantile=this) + return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) + + def _parse_wrapped_id_vars( + self, optional: bool = False + ) -> t.List[t.Optional[exp.Expression]]: + return super()._parse_wrapped_id_vars(optional=True) + class Generator(generator.Generator): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.NULLABLE: "Nullable", - exp.DataType.Type.DATETIME: "DateTime64", - exp.DataType.Type.MAP: "Map", + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", + exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.DATETIME64: "DateTime64", + exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.FLOAT: "Float32", + exp.DataType.Type.INT: "Int32", + exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.INT256: "Int256", + exp.DataType.Type.MAP: "Map", + exp.DataType.Type.NULLABLE: "Nullable", + exp.DataType.Type.SMALLINT: "Int16", exp.DataType.Type.STRUCT: "Tuple", exp.DataType.Type.TINYINT: "Int8", - exp.DataType.Type.UTINYINT: "UInt8", - exp.DataType.Type.SMALLINT: "Int16", - exp.DataType.Type.USMALLINT: "UInt16", - exp.DataType.Type.INT: "Int32", - exp.DataType.Type.UINT: "UInt32", - exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.UBIGINT: "UInt64", - exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.UINT: "UInt32", exp.DataType.Type.UINT128: "UInt128", - exp.DataType.Type.INT256: "Int256", exp.DataType.Type.UINT256: "UInt256", - exp.DataType.Type.FLOAT: "Float32", - exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.USMALLINT: "UInt16", + exp.DataType.Type.UTINYINT: "UInt8", } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("any"), + exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, - exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}", + exp.CastToStrType: rename_func("CAST"), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", - exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}", - exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), - exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", - exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", - exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Pivot: no_pivot_sql, + exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile")) + + f"({self.sql(e, 'this')})", exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, } JOIN_HINTS = False TABLE_HINTS = False EXPLICIT_UNION = True - - def _param_args_sql( - self, - expression: exp.Expression, - param_names: str | t.List[str], - arg_names: str | t.List[str], - ) -> str: - params = self.format_args( - *( - arg - for name in ensure_list(param_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - args = self.format_args( - *( - arg - for name in ensure_list(arg_names) - for arg in ensure_list(expression.args.get(name)) - ) - ) - return f"({params})({args})" + GROUPINGS_SEP = "" def cte_sql(self, expression: exp.CTE) -> str: if isinstance(expression.this, exp.Alias): return self.sql(expression, "this") return super().cte_sql(expression) + + def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: + return super().after_limit_modifiers(expression) + [ + self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) + if expression.args.get("settings") + else "", + self.seg("FORMAT ") + self.sql(expression, "format") + if expression.args.get("format") + else "", + ] + + def parameterizedagg_sql(self, expression: exp.Anonymous) -> str: + params = self.expressions(expression, "params", flat=True) + return self.func(expression.name, *expression.expressions) + f"({params})" + + def placeholder_sql(self, expression: exp.Placeholder) -> str: + return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 51112a0..2149aca 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -25,7 +25,7 @@ class Databricks(Spark): class Generator(Spark.Generator): TRANSFORMS = { - **Spark.Generator.TRANSFORMS, # type: ignore + **Spark.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 71269f2..890a3c3 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,10 +8,16 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Token, Tokenizer +from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import new_trie -E = t.TypeVar("E", bound=exp.Expression) +if t.TYPE_CHECKING: + from sqlglot._typing import E + + +# Only Snowflake is currently known to resolve unquoted identifiers as uppercase. +# https://docs.snowflake.com/en/sql-reference/identifiers-syntax +RESOLVES_IDENTIFIERS_AS_UPPERCASE = {"snowflake"} class Dialects(str, Enum): @@ -42,6 +48,19 @@ class Dialects(str, Enum): class _Dialect(type): classes: t.Dict[str, t.Type[Dialect]] = {} + def __eq__(cls, other: t.Any) -> bool: + if cls is other: + return True + if isinstance(other, str): + return cls is cls.get(other) + if isinstance(other, Dialect): + return cls is type(other) + + return False + + def __hash__(cls) -> int: + return hash(cls.__name__.lower()) + @classmethod def __getitem__(cls, key: str) -> t.Type[Dialect]: return cls.classes[key] @@ -70,17 +89,20 @@ class _Dialect(type): klass.tokenizer_class._IDENTIFIERS.items() )[0] - klass.bit_start, klass.bit_end = seq_get( - list(klass.tokenizer_class._BIT_STRINGS.items()), 0 - ) or (None, None) - - klass.hex_start, klass.hex_end = seq_get( - list(klass.tokenizer_class._HEX_STRINGS.items()), 0 - ) or (None, None) + def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: + return next( + ( + (s, e) + for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() + if t == token_type + ), + (None, None), + ) - klass.byte_start, klass.byte_end = seq_get( - list(klass.tokenizer_class._BYTE_STRINGS.items()), 0 - ) or (None, None) + klass.bit_start, klass.bit_end = get_start_end(TokenType.BIT_STRING) + klass.hex_start, klass.hex_end = get_start_end(TokenType.HEX_STRING) + klass.byte_start, klass.byte_end = get_start_end(TokenType.BYTE_STRING) + klass.raw_start, klass.raw_end = get_start_end(TokenType.RAW_STRING) return klass @@ -110,6 +132,12 @@ class Dialect(metaclass=_Dialect): parser_class = None generator_class = None + def __eq__(self, other: t.Any) -> bool: + return type(self) == other + + def __hash__(self) -> int: + return hash(type(self)) + @classmethod def get_or_raise(cls, dialect: DialectType) -> t.Type[Dialect]: if not dialect: @@ -192,6 +220,8 @@ class Dialect(metaclass=_Dialect): "hex_end": self.hex_end, "byte_start": self.byte_start, "byte_end": self.byte_end, + "raw_start": self.raw_start, + "raw_end": self.raw_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, "string_escape": self.tokenizer_class.STRING_ESCAPES[0], @@ -275,7 +305,7 @@ def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: self.unsupported("PIVOT unsupported") - return self.sql(expression) + return "" def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: @@ -328,7 +358,7 @@ def var_map_sql( def format_time_lambda( exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None -) -> t.Callable[[t.Sequence], E]: +) -> t.Callable[[t.List], E]: """Helper used for time expressions. Args: @@ -340,7 +370,7 @@ def format_time_lambda( A callable that can be used to return the appropriately formatted time expression. """ - def _format_time(args: t.Sequence): + def _format_time(args: t.List): return exp_class( this=seq_get(args, 0), format=Dialect[dialect].format_time( @@ -377,12 +407,12 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: def parse_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None -) -> t.Callable[[t.Sequence], E]: - def inner_func(args: t.Sequence) -> E: +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: unit_based = len(args) == 3 this = args[2] if unit_based else seq_get(args, 0) unit = args[0] if unit_based else exp.Literal.string("DAY") - unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit return exp_class(this=this, expression=seq_get(args, 1), unit=unit) return inner_func @@ -390,8 +420,8 @@ def parse_date_delta( def parse_date_delta_with_interval( expression_class: t.Type[E], -) -> t.Callable[[t.Sequence], t.Optional[E]]: - def func(args: t.Sequence) -> t.Optional[E]: +) -> t.Callable[[t.List], t.Optional[E]]: + def func(args: t.List) -> t.Optional[E]: if len(args) < 2: return None @@ -409,7 +439,7 @@ def parse_date_delta_with_interval( return func -def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: +def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: unit = seq_get(args, 0) this = seq_get(args, 1) @@ -424,7 +454,7 @@ def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: ) -def locate_to_strposition(args: t.Sequence) -> exp.Expression: +def locate_to_strposition(args: t.List) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -483,7 +513,7 @@ def trim_sql(self: Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" -def str_to_time_sql(self, expression: exp.Expression) -> str: +def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: return self.func("STRPTIME", expression.this, self.format_time(expression)) @@ -496,3 +526,26 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: return f"CAST({self.sql(expression, 'this')} AS DATE)" return _ts_or_ds_to_date_sql + + +# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: + names = [] + for agg in aggregations: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) + + return names diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 7ad555e..924b979 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -95,7 +95,7 @@ class Drill(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"), "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), @@ -108,7 +108,7 @@ class Drill(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", @@ -121,13 +121,13 @@ class Drill(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArraySize: rename_func("REPEATED_COUNT"), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index bce956e..662882d 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -11,9 +11,9 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, no_comment_column_constraint_sql, - no_pivot_sql, no_properties_sql, no_safe_divide_sql, + pivot_column_names, rename_func, str_position_sql, str_to_time_sql, @@ -31,10 +31,11 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" -def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" - return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" + op = "+" if isinstance(expression, exp.DateAdd) else "-" + return f"{this} {op} {self.sql(exp.Interval(this=expression.expression, unit=unit))}" def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: @@ -50,11 +51,11 @@ def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str return f"ARRAY_SORT({this})" -def _sort_array_reverse(args: t.Sequence) -> exp.Expression: +def _sort_array_reverse(args: t.List) -> exp.Expression: return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) -def _parse_date_diff(args: t.Sequence) -> exp.Expression: +def _parse_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -89,11 +90,14 @@ def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract class DuckDB(Dialect): + null_ordering = "nulls_are_last" + class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "~": TokenType.RLIKE, ":=": TokenType.EQ, + "//": TokenType.DIV, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, "BPCHAR": TokenType.TEXT, @@ -104,6 +108,7 @@ class DuckDB(Dialect): "INT1": TokenType.TINYINT, "LOGICAL": TokenType.BOOLEAN, "NUMERIC": TokenType.DOUBLE, + "PIVOT_WIDER": TokenType.PIVOT, "SIGNED": TokenType.INT, "STRING": TokenType.VARCHAR, "UBIGINT": TokenType.UBIGINT, @@ -114,8 +119,7 @@ class DuckDB(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + **parser.Parser.FUNCTIONS, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_REVERSE_SORT": _sort_array_reverse, @@ -152,11 +156,17 @@ class DuckDB(Dialect): TokenType.UTINYINT, } + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: + if len(aggregations) == 1: + return super()._pivot_column_names(aggregations) + return pivot_column_names(aggregations, dialect="duckdb") + class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") + RENAME_TABLE_WITH_DB = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -175,7 +185,8 @@ class DuckDB(Dialect): exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, - exp.DateAdd: _date_add_sql, + exp.DateAdd: _date_delta_sql, + exp.DateSub: _date_delta_sql, exp.DateDiff: lambda self, e: self.func( "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this ), @@ -183,13 +194,13 @@ class DuckDB(Dialect): exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", exp.Explode: rename_func("UNNEST"), + exp.IntDiv: lambda self, e: self.binary(e, "//"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), - exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpExtract: _regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -232,11 +243,11 @@ class DuckDB(Dialect): STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"} PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } def tablesample_sql( - self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + self, expression: exp.TableSample, seed_prefix: str = "SEED", sep: str = " AS " ) -> str: return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 871a180..fbd626a 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -147,13 +147,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str return f"TO_DATE({this})" -def _index_sql(self: generator.Generator, expression: exp.Index) -> str: - this = self.sql(expression, "this") - table = self.sql(expression, "table") - columns = self.sql(expression, "columns") - return f"{this} ON TABLE {table} {columns}" - - class Hive(Dialect): alias_post_tablesample = True @@ -225,8 +218,7 @@ class Hive(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( @@ -271,21 +263,29 @@ class Hive(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore + **parser.Parser.PROPERTY_PARSERS, "WITH SERDEPROPERTIES": lambda self: exp.SerdeProperties( expressions=self._parse_wrapped_csv(self._parse_property) ), } + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + "distribute": lambda self: self._parse_sort(exp.Distribute, "DISTRIBUTE", "BY"), + "sort": lambda self: self._parse_sort(exp.Sort, "SORT", "BY"), + "cluster": lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"), + } + class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False TABLESAMPLE_SIZE_IS_PERCENT = True JOIN_HINTS = False TABLE_HINTS = False + INDEX_ON = "ON TABLE" TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", @@ -294,7 +294,7 @@ class Hive(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Select: transforms.preprocess( [ @@ -319,7 +319,6 @@ class Hive(Dialect): exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, - exp.Index: _index_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), @@ -342,7 +341,6 @@ class Hive(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: _str_to_unix_sql, exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -363,14 +361,13 @@ class Hive(Dialect): exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), exp.LastDateOfMonth: rename_func("LAST_DAY"), - exp.National: lambda self, e: self.sql(e, "this"), + exp.National: lambda self, e: self.national_sql(e, prefix=""), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -396,3 +393,10 @@ class Hive(Dialect): expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) + + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: + return super().after_having_modifiers(expression) + [ + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + ] diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5342624..2b41860 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import ( min_or_least, no_ilike_sql, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, parse_date_delta_with_interval, @@ -21,14 +24,14 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _show_parser(*args, **kwargs): - def _parse(self): +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]: + def _parse(self: MySQL.Parser) -> exp.Show: return self._parse_show_mysql(*args, **kwargs) return _parse -def _date_trunc_sql(self, expression): +def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit") @@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression): return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args): +def _str_to_date(args: t.List) -> exp.StrToDate: date_format = MySQL.format_time(seq_get(args, 1)) return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self, expression): +def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" -def _trim_sql(self, expression): +def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") @@ -79,8 +82,8 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return ( @@ -175,10 +178,10 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -191,7 +194,7 @@ class MySQL(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -199,13 +202,8 @@ class MySQL(Dialect): ), } - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore - "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), - } - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.SHOW: lambda self: self._parse_show(), } @@ -286,7 +284,13 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True - def _parse_show_mysql(self, this, target=False, full=None, global_=None): + def _parse_show_mysql( + self, + this: str, + target: bool | str = False, + full: t.Optional[bool] = None, + global_: t.Optional[bool] = None, + ) -> exp.Show: if target: if isinstance(target, str): self._match_text_seq(target) @@ -342,10 +346,12 @@ class MySQL(Dialect): offset=offset, limit=limit, mutex=mutex, - **{"global": global_}, + **{"global": global_}, # type: ignore ) - def _parse_oldstyle_limit(self): + def _parse_oldstyle_limit( + self, + ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]: limit = None offset = None if self._match_text_seq("LIMIT"): @@ -355,23 +361,20 @@ class MySQL(Dialect): elif len(parts) == 2: limit = parts[1] offset = parts[0] + return offset, limit - def _parse_set_item_charset(self, kind): + def _parse_set_item_charset(self, kind: str) -> exp.Expression: this = self._parse_string() or self._parse_id_var() + return self.expression(exp.SetItem, this=this, kind=kind) - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) - - def _parse_set_item_names(self): + def _parse_set_item_names(self) -> exp.Expression: charset = self._parse_string() or self._parse_id_var() if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None + return self.expression( exp.SetItem, this=charset, @@ -386,7 +389,7 @@ class MySQL(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.DateAdd: _date_add_sql("ADD"), @@ -403,6 +406,7 @@ class MySQL(Dialect): exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -422,7 +426,7 @@ class MySQL(Dialect): TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index c8af1c6..7722753 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -8,7 +8,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _parse_xml_table(self) -> exp.XMLTable: +def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: this = self._parse_string() passing = None @@ -66,7 +66,7 @@ class Oracle(Dialect): WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), } @@ -107,7 +107,7 @@ class Oracle(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -122,7 +122,7 @@ class Oracle(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), @@ -143,7 +143,7 @@ class Oracle(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 2132778..ab61880 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, rename_func, @@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: from sqlglot.optimizer.simplify import simplify this = self.sql(expression, "this") @@ -51,7 +52,7 @@ def _date_add_sql(kind): return func -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) @@ -77,7 +78,7 @@ def _date_diff_sql(self, expression): return f"CAST({unit} AS BIGINT)" -def _substring_sql(self, expression): +def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: this = self.sql(expression, "this") start = self.sql(expression, "start") length = self.sql(expression, "length") @@ -88,7 +89,7 @@ def _substring_sql(self, expression): return f"SUBSTRING({this}{from_part}{for_part})" -def _string_agg_sql(self, expression): +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -102,13 +103,13 @@ def _string_agg_sql(self, expression): return f"STRING_AGG({self.format_args(this, separator)}{order})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) -def _auto_increment_to_serial(expression): +def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: auto = expression.find(exp.AutoIncrementColumnConstraint) if auto: @@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression): return expression -def _serial_to_generated(expression): +def _serial_to_generated(expression: exp.Expression) -> exp.Expression: kind = expression.args["kind"] if kind.this == exp.DataType.Type.SERIAL: @@ -144,6 +145,7 @@ def _serial_to_generated(expression): constraints = expression.args["constraints"] generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: constraints.insert(0, notnull) if generated not in constraints: @@ -152,7 +154,7 @@ def _serial_to_generated(expression): return expression -def _generate_series(args): +def _generate_series(args: t.List) -> exp.Expression: # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day step = seq_get(args, 2) @@ -168,11 +170,12 @@ def _generate_series(args): return exp.GenerateSeries.from_arg_list(args) -def _to_timestamp(args): +def _to_timestamp(args: t.List) -> exp.Expression: # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) + # https://www.postgresql.org/docs/current/functions-formatting.html return format_time_lambda(exp.StrToTime, "postgres")(args) @@ -255,7 +258,7 @@ class Postgres(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -271,7 +274,7 @@ class Postgres(Dialect): } BITWISE = { - **parser.Parser.BITWISE, # type: ignore + **parser.Parser.BITWISE, TokenType.HASH: exp.BitwiseXor, } @@ -280,7 +283,7 @@ class Postgres(Dialect): } RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), @@ -303,14 +306,14 @@ class Postgres(Dialect): return self.expression(exp.Extract, this=part, expression=value) class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False + SINGLE_STRING_INTERVAL = True LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -320,14 +323,9 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.BitwiseXor: lambda self, e: self.binary(e, "#"), - exp.ColumnDef: transforms.preprocess( - [ - _auto_increment_to_serial, - _serial_to_generated, - ], - ), + exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), @@ -348,6 +346,7 @@ class Postgres(Dialect): exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]), + exp.Pivot: no_pivot_sql, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -369,7 +368,7 @@ class Postgres(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6133a27..52a04a4 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, no_ilike_sql, + no_pivot_sql, no_safe_divide_sql, rename_func, struct_extract_sql, @@ -127,39 +128,12 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s ) -def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: - start = expression.args["start"] - end = expression.args["end"] - step = expression.args.get("step") - - target_type = None - - if isinstance(start, exp.Cast): - target_type = start.to - elif isinstance(end, exp.Cast): - target_type = end.to - - if target_type and target_type.this == exp.DataType.Type.TIMESTAMP: - to = target_type.copy() - - if target_type is start.to: - end = exp.Cast(this=end, to=to) - else: - start = exp.Cast(this=start, to=to) - - sql = self.func("SEQUENCE", start, end, step) - if isinstance(expression.parent, exp.Table): - sql = f"UNNEST({sql})" - - return sql - - def _ensure_utf8(charset: exp.Literal) -> None: if charset.name.lower() != "utf-8": raise UnsupportedError(f"Unsupported charset {charset}") -def _approx_percentile(args: t.Sequence) -> exp.Expression: +def _approx_percentile(args: t.List) -> exp.Expression: if len(args) == 4: return exp.ApproxQuantile( this=seq_get(args, 0), @@ -176,7 +150,7 @@ def _approx_percentile(args: t.Sequence) -> exp.Expression: return exp.ApproxQuantile.from_arg_list(args) -def _from_unixtime(args: t.Sequence) -> exp.Expression: +def _from_unixtime(args: t.List) -> exp.Expression: if len(args) == 3: return exp.UnixToTime( this=seq_get(args, 0), @@ -191,22 +165,39 @@ def _from_unixtime(args: t.Sequence) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) +def _unnest_sequence(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Table): + if isinstance(expression.this, exp.GenerateSeries): + unnest = exp.Unnest(expressions=[expression.this]) + + if expression.alias: + return exp.alias_( + unnest, + alias="_u", + table=[expression.alias], + copy=False, + ) + return unnest + return expression + + class Presto(Dialect): index_offset = 1 null_ordering = "nulls_are_last" - time_format = MySQL.time_format # type: ignore - time_mapping = MySQL.time_mapping # type: ignore + time_format = MySQL.time_format + time_mapping = MySQL.time_mapping class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "START": TokenType.BEGIN, + "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "ROW": TokenType.STRUCT, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "CARDINALITY": exp.ArraySize.from_arg_list, @@ -252,13 +243,13 @@ class Presto(Dialect): STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -268,8 +259,9 @@ class Presto(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.ApproxDistinct: _approx_distinct_sql, + exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayContains: rename_func("CONTAINS"), @@ -293,7 +285,7 @@ class Presto(Dialect): exp.Decode: _decode_sql, exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, - exp.GenerateSeries: _sequence_sql, + exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, @@ -301,10 +293,10 @@ class Presto(Dialect): exp.Initcap: _initcap_sql, exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), - exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.LogicalOr: rename_func("BOOL_OR"), + exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( @@ -320,8 +312,7 @@ class Presto(Dialect): exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", exp.StructExtract: struct_extract_sql, - exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", - exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, @@ -336,6 +327,7 @@ class Presto(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), + exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), exp.WithinGroup: transforms.preprocess( [transforms.remove_within_group_for_percentiles] ), @@ -351,3 +343,25 @@ class Presto(Dialect): modes = expression.args.get("modes") modes = f" {', '.join(modes)}" if modes else "" return f"START TRANSACTION{modes}" + + def generateseries_sql(self, expression: exp.GenerateSeries) -> str: + start = expression.args["start"] + end = expression.args["end"] + step = expression.args.get("step") + + if isinstance(start, exp.Cast): + target_type = start.to + elif isinstance(end, exp.Cast): + target_type = end.to + else: + target_type = None + + if target_type and target_type.is_type(exp.DataType.Type.TIMESTAMP): + to = target_type.copy() + + if target_type is start.to: + end = exp.Cast(this=end, to=to) + else: + start = exp.Cast(this=start, to=to) + + return self.func("SEQUENCE", start, end, step) diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 1b7cf31..55e393a 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -8,21 +8,21 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _json_sql(self, e) -> str: - return f'{self.sql(e, "this")}."{e.expression.name}"' +def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: + return f'{self.sql(expression, "this")}."{expression.expression.name}"' class Redshift(Postgres): time_format = "'YYYY-MM-DD HH:MI:SS'" time_mapping = { - **Postgres.time_mapping, # type: ignore + **Postgres.time_mapping, "MON": "%b", "HH": "%H", } class Parser(Postgres.Parser): FUNCTIONS = { - **Postgres.Parser.FUNCTIONS, # type: ignore + **Postgres.Parser.FUNCTIONS, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -45,7 +45,7 @@ class Redshift(Postgres): isinstance(this, exp.DataType) and this.this == exp.DataType.Type.VARCHAR and this.expressions - and this.expressions[0] == exp.column("MAX") + and this.expressions[0].this == exp.column("MAX") ): this.set("expressions", [exp.Var(this="MAX")]) @@ -57,9 +57,7 @@ class Redshift(Postgres): STRING_ESCAPES = ["\\"] KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, # type: ignore - "GEOMETRY": TokenType.GEOMETRY, - "GEOGRAPHY": TokenType.GEOGRAPHY, + **Postgres.Tokenizer.KEYWORDS, "HLLSKETCH": TokenType.HLLSKETCH, "SUPER": TokenType.SUPER, "SYSDATE": TokenType.CURRENT_TIMESTAMP, @@ -76,22 +74,22 @@ class Redshift(Postgres): class Generator(Postgres.Generator): LOCKING_READS_SUPPORTED = False - SINGLE_STRING_INTERVAL = True + RENAME_TABLE_WITH_DB = False TYPE_MAPPING = { - **Postgres.Generator.TYPE_MAPPING, # type: ignore + **Postgres.Generator.TYPE_MAPPING, exp.DataType.Type.BINARY: "VARBYTE", exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", } PROPERTIES_LOCATION = { - **Postgres.Generator.PROPERTIES_LOCATION, # type: ignore + **Postgres.Generator.PROPERTIES_LOCATION, exp.LikeProperty: exp.Properties.Location.POST_WITH, } TRANSFORMS = { - **Postgres.Generator.TRANSFORMS, # type: ignore + **Postgres.Generator.TRANSFORMS, exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this @@ -107,10 +105,13 @@ class Redshift(Postgres): exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } + # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots + TRANSFORMS.pop(exp.Pivot) + # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) TRANSFORMS.pop(exp.Pow) - RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot"} + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} def values_sql(self, expression: exp.Values) -> str: """ @@ -120,37 +121,36 @@ class Redshift(Postgres): evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be very slow. """ - if not isinstance(expression.unnest().parent, exp.From): + + # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example + if not expression.find_ancestor(exp.From, exp.Join): return super().values_sql(expression) - rows = [tuple_exp.expressions for tuple_exp in expression.expressions] + + column_names = expression.alias and expression.args["alias"].columns + selects = [] + rows = [tuple_exp.expressions for tuple_exp in expression.expressions] + for i, row in enumerate(rows): - if i == 0 and expression.alias: + if i == 0 and column_names: row = [ exp.alias_(value, column_name) - for value, column_name in zip(row, expression.args["alias"].args["columns"]) + for value, column_name in zip(row, column_names) ] + selects.append(exp.Select(expressions=row)) - subquery_expression = selects[0] + + subquery_expression: exp.Select | exp.Union = selects[0] if len(selects) > 1: for select in selects[1:]: subquery_expression = exp.union(subquery_expression, select, distinct=False) + return self.subquery_sql(subquery_expression.subquery(expression.alias)) def with_properties(self, properties: exp.Properties) -> str: """Redshift doesn't have `WITH` as part of their with_properties so we remove it""" return self.properties(properties, prefix=" ", suffix="") - def renametable_sql(self, expression: exp.RenameTable) -> str: - """Redshift only supports defining the table name itself (not the db) when renaming tables""" - expression = expression.copy() - target_table = expression.this - for arg in target_table.args: - if arg != "this": - target_table.set(arg, None) - this = self.sql(expression, "this") - return f"RENAME TO {this}" - def datatype_sql(self, expression: exp.DataType) -> str: """ Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean @@ -162,6 +162,8 @@ class Redshift(Postgres): expression = expression.copy() expression.set("this", exp.DataType.Type.VARCHAR) precision = expression.args.get("expressions") + if not precision: expression.append("expressions", exp.Var(this="MAX")) + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 70dcaa9..756e8e9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -18,7 +18,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import flatten, seq_get +from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType @@ -30,7 +30,7 @@ def _check_int(s: str) -> bool: # from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.UnixToTime]: +def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTime]: if len(args) == 2: first_arg, second_arg = args if second_arg.is_string: @@ -52,8 +52,12 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix return exp.UnixToTime(this=first_arg, scale=timescale) + from sqlglot.optimizer.simplify import simplify_literals + + # The first argument might be an expression like 40 * 365 * 86400, so we try to + # reduce it using `simplify_literals` first and then check if it's a Literal. first_arg = seq_get(args, 0) - if not isinstance(first_arg, Literal): + if not isinstance(simplify_literals(first_arg, root=True), Literal): # case: <variant_expr> return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) @@ -69,6 +73,19 @@ def _snowflake_to_timestamp(args: t.Sequence) -> t.Union[exp.StrToTime, exp.Unix return exp.UnixToTime.from_arg_list(args) +def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: + expression = parser.parse_var_map(args) + + if isinstance(expression, exp.StarMap): + return expression + + return exp.Struct( + expressions=[ + t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values) + ] + ) + + def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") @@ -116,7 +133,7 @@ def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: # https://docs.snowflake.com/en/sql-reference/functions/div0 -def _div0_to_if(args: t.Sequence) -> exp.Expression: +def _div0_to_if(args: t.List) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0)) true = exp.Literal.number(0) false = exp.Div(this=seq_get(args, 0), expression=seq_get(args, 1)) @@ -124,13 +141,13 @@ def _div0_to_if(args: t.Sequence) -> exp.Expression: # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _zeroifnull_to_if(args: t.Sequence) -> exp.Expression: +def _zeroifnull_to_if(args: t.List) -> exp.Expression: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _nullifzero_to_if(args: t.Sequence) -> exp.Expression: +def _nullifzero_to_if(args: t.List) -> exp.Expression: cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) @@ -143,6 +160,12 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _parse_convert_timezone(args: t.List) -> exp.Expression: + if len(args) == 3: + return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) + return exp.AtTimeZone(this=seq_get(args, 1), zone=seq_get(args, 0)) + + class Snowflake(Dialect): null_ordering = "nulls_are_large" time_format = "'yyyy-mm-dd hh24:mi:ss'" @@ -177,17 +200,14 @@ class Snowflake(Dialect): } class Parser(parser.Parser): - QUOTED_PIVOT_COLUMNS = True + IDENTIFY_PIVOT_STRINGS = True FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, - "CONVERT_TIMEZONE": lambda args: exp.AtTimeZone( - this=seq_get(args, 1), - zone=seq_get(args, 0), - ), + "CONVERT_TIMEZONE": _parse_convert_timezone, "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), @@ -202,7 +222,7 @@ class Snowflake(Dialect): "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, - "OBJECT_CONSTRUCT": parser.parse_var_map, + "OBJECT_CONSTRUCT": _parse_object_construct, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TO_ARRAY": exp.Array.from_arg_list, @@ -224,7 +244,7 @@ class Snowflake(Dialect): } COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, # type: ignore + **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( exp.Bracket, this=this, @@ -232,14 +252,16 @@ class Snowflake(Dialect): ), } + TIMESTAMPS = parser.Parser.TIMESTAMPS.copy() - {TokenType.TIME} + RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.LIKE_ANY: binary_range_parser(exp.LikeAny), TokenType.ILIKE_ANY: binary_range_parser(exp.ILikeAny), } ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, # type: ignore + **parser.Parser.ALTER_PARSERS, "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), "SET": lambda self: self._parse_alter_table_set_tag(), } @@ -256,17 +278,20 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "CHAR VARYING": TokenType.VARCHAR, + "CHARACTER VARYING": TokenType.VARCHAR, "EXCLUDE": TokenType.EXCEPT, "ILIKE ANY": TokenType.ILIKE_ANY, "LIKE ANY": TokenType.LIKE_ANY, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, + "MINUS": TokenType.EXCEPT, + "NCHAR VARYING": TokenType.VARCHAR, "PUT": TokenType.COMMAND, "RENAME": TokenType.REPLACE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, "TIMESTAMPNTZ": TokenType.TIMESTAMP, - "MINUS": TokenType.EXCEPT, "SAMPLE": TokenType.TABLE_SAMPLE, } @@ -285,7 +310,7 @@ class Snowflake(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), @@ -299,6 +324,7 @@ class Snowflake(Dialect): exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Extract: rename_func("DATE_PART"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -312,6 +338,10 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Struct: lambda self, e: self.func( + "OBJECT_CONSTRUCT", + *(arg for expression in e.expressions for arg in expression.flatten()), + ), exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( @@ -326,7 +356,7 @@ class Snowflake(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -336,7 +366,7 @@ class Snowflake(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.SetProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } @@ -351,53 +381,10 @@ class Snowflake(Dialect): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) - def values_sql(self, expression: exp.Values) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. - - We also want to make sure that after we find matches where we need to unquote a column that we prevent users - from adding quotes to the column by using the `identify` argument when generating the SQL. - """ - alias = expression.args.get("alias") - if alias and alias.args.get("columns"): - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) - and isinstance(node.parent, exp.TableAlias) - and node.arg_key == "columns" - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) - return super().values_sql(expression) - def settag_sql(self, expression: exp.SetTag) -> str: action = "UNSET" if expression.args.get("unset") else "SET" return f"{action} TAG {self.expressions(expression)}" - def select_sql(self, expression: exp.Select) -> str: - """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also - that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need - to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when - generating the SQL. - - Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the - expression. This might not be true in a case where the same column name can be sourced from another table that can - properly quote but should be true in most cases. - """ - values_identifiers = set( - flatten( - (v.args.get("alias") or exp.Alias()).args.get("columns", []) - for v in expression.find_all(exp.Values) - ) - ) - if values_identifiers: - expression = expression.transform( - lambda node: exp.Identifier(**{**node.args, "quoted": False}) - if isinstance(node, exp.Identifier) and node in values_identifiers - else node, - ) - return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) - return super().select_sql(expression) - def describe_sql(self, expression: exp.Describe) -> str: # Default to table if kind is unknown kind_value = expression.args.get("kind") or "TABLE" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 939f2fd..b7d1641 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -7,10 +7,10 @@ from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get -def _parse_datediff(args: t.Sequence) -> exp.Expression: +def _parse_datediff(args: t.List) -> exp.Expression: """ Although Spark docs don't mention the "unit" argument, Spark3 added support for - it at some point. Databricks also supports this variation (see below). + it at some point. Databricks also supports this variant (see below). For example, in spark-sql (v3.3.1): - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4 @@ -36,7 +36,7 @@ def _parse_datediff(args: t.Sequence) -> exp.Expression: class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { - **Spark2.Parser.FUNCTIONS, # type: ignore + **Spark2.Parser.FUNCTIONS, "DATEDIFF": _parse_datediff, } diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 584671f..912b86b 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -3,7 +3,12 @@ from __future__ import annotations import typing as t from sqlglot import exp, parser, transforms -from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql +from sqlglot.dialects.dialect import ( + create_with_partitions_sql, + pivot_column_names, + rename_func, + trim_sql, +) from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get @@ -26,7 +31,7 @@ def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: return f"MAP_FROM_ARRAYS({keys}, {values})" -def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]: +def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) @@ -53,10 +58,56 @@ def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: raise ValueError("Improper scale for timestamp") +def _unalias_pivot(expression: exp.Expression) -> exp.Expression: + """ + Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a + pivoted source in a subquery with the same alias to preserve the query's semantics. + + Example: + >>> from sqlglot import parse_one + >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv") + >>> print(_unalias_pivot(expr).sql(dialect="spark")) + SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv + """ + if isinstance(expression, exp.From) and expression.this.args.get("pivots"): + pivot = expression.this.args["pivots"][0] + if pivot.alias: + alias = pivot.args["alias"].pop() + return exp.From( + this=expression.this.replace( + exp.select("*").from_(expression.this.copy()).subquery(alias=alias) + ) + ) + + return expression + + +def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: + """ + Spark doesn't allow the column referenced in the PIVOT's field to be qualified, + so we need to unqualify it. + + Example: + >>> from sqlglot import parse_one + >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))") + >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark")) + SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1')) + """ + if isinstance(expression, exp.Pivot): + expression.args["field"].transform( + lambda node: exp.column(node.output_name, quoted=node.this.quoted) + if isinstance(node, exp.Column) + else node, + copy=False, + ) + + return expression + + class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, # type: ignore + **Hive.Parser.FUNCTIONS, "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "LEFT": lambda args: exp.Substring( @@ -110,7 +161,7 @@ class Spark2(Hive): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -131,43 +182,21 @@ class Spark2(Hive): kind="COLUMNS", ) - def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: - # Spark doesn't add a suffix to the pivot columns when there's a single aggregation - if len(pivot_columns) == 1: + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: + if len(aggregations) == 1: return [""] - - names = [] - for agg in pivot_columns: - if isinstance(agg, exp.Alias): - names.append(agg.alias) - else: - """ - This case corresponds to aggregations without aliases being used as suffixes - (e.g. col_avg(foo)). We need to unquote identifiers because they're going to - be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. - Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). - - Moreover, function names are lowercased in order to mimic Spark's naming scheme. - """ - agg_all_unquoted = agg.transform( - lambda node: exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node - ) - names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) - - return names + return pivot_column_names(aggregations, dialect="spark") class Generator(Hive.Generator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, # type: ignore + **Hive.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "BYTE", exp.DataType.Type.SMALLINT: "SHORT", exp.DataType.Type.BIGINT: "LONG", } PROPERTIES_LOCATION = { - **Hive.Generator.PROPERTIES_LOCATION, # type: ignore + **Hive.Generator.PROPERTIES_LOCATION, exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, @@ -175,7 +204,7 @@ class Spark2(Hive): } TRANSFORMS = { - **Hive.Generator.TRANSFORMS, # type: ignore + **Hive.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", @@ -188,11 +217,12 @@ class Spark2(Hive): exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", + exp.From: transforms.preprocess([_unalias_pivot]), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, - exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]), + exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]), exp.Reduce: rename_func("AGGREGATE"), exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index f2efe32..56e7773 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -7,6 +7,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, count_if_to_sum, no_ilike_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, rename_func, @@ -14,7 +15,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _date_add_sql(self, expression): +def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: modifier = expression.expression modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") @@ -67,7 +68,7 @@ class SQLite(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "EDITDIST3": exp.Levenshtein.from_arg_list, } @@ -76,7 +77,7 @@ class SQLite(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -98,7 +99,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CountIf: count_if_to_sum, exp.Create: transforms.preprocess([_transform_create]), exp.CurrentDate: lambda *_: "CURRENT_DATE", @@ -114,6 +115,7 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), + exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( [transforms.eliminate_distinct_on, transforms.eliminate_qualify] ), @@ -163,12 +165,15 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" # https://www.sqlite.org/lang_aggfunc.html#group_concat - def groupconcat_sql(self, expression): + def groupconcat_sql(self, expression: exp.GroupConcat) -> str: this = expression.this distinct = expression.find(exp.Distinct) + if distinct: this = distinct.expressions[0] - distinct = "DISTINCT " + distinct_sql = "DISTINCT " + else: + distinct_sql = "" if isinstance(expression.this, exp.Order): self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") @@ -176,7 +181,7 @@ class SQLite(Dialect): this = expression.this.this separator = expression.args.get("separator") - return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})" def least_sql(self, expression: exp.Least) -> str: if len(expression.expressions) > 1: diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 895588a..0390113 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -11,25 +11,24 @@ from sqlglot.helper import seq_get class StarRocks(MySQL): - class Parser(MySQL.Parser): # type: ignore + class Parser(MySQL.Parser): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, - "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), } - class Generator(MySQL.Generator): # type: ignore + class Generator(MySQL.Generator): TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, # type: ignore + **MySQL.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, # type: ignore + **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, @@ -43,4 +42,5 @@ class StarRocks(MySQL): exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), } + TRANSFORMS.pop(exp.DateTrunc) diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 51e685b..d5fba17 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -4,41 +4,38 @@ from sqlglot import exp, generator, parser, transforms from sqlglot.dialects.dialect import Dialect -def _if_sql(self, expression): - return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END" - - -def _coalesce_sql(self, expression): - return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" - - -def _count_sql(self, expression): - this = expression.this - if isinstance(this, exp.Distinct): - return f"COUNTD({self.expressions(this, flat=True)})" - return f"COUNT({self.sql(expression, 'this')})" - - class Tableau(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore - exp.If: _if_sql, - exp.Coalesce: _coalesce_sql, - exp.Count: _count_sql, + **generator.Generator.TRANSFORMS, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def if_sql(self, expression: exp.If) -> str: + this = self.sql(expression, "this") + true = self.sql(expression, "true") + false = self.sql(expression, "false") + return f"IF {this} THEN {true} ELSE {false} END" + + def coalesce_sql(self, expression: exp.Coalesce) -> str: + return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" + + def count_sql(self, expression: exp.Count) -> str: + this = expression.this + if isinstance(this, exp.Distinct): + return f"COUNTD({self.expressions(this, flat=True)})" + return f"COUNT({self.sql(expression, 'this')})" + class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index a79eaeb..9b39178 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -75,12 +75,12 @@ class Teradata(Dialect): FUNC_TOKENS.remove(TokenType.REPLACE) STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.REPLACE: lambda self: self._parse_create(), } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "RANGE_N": lambda self: self._parse_rangen(), "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST), } @@ -106,7 +106,7 @@ class Teradata(Dialect): exp.Update, **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(), + "from": self._parse_from(modifiers=True), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "where": self._parse_where(), @@ -135,13 +135,15 @@ class Teradata(Dialect): TABLE_HINTS = False TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX, + **generator.Generator.PROPERTIES_LOCATION, + exp.OnCommitProperty: exp.Properties.Location.POST_INDEX, + exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION, + exp.StabilityProperty: exp.Properties.Location.POST_CREATE, } TRANSFORMS = { diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index c7b34fe..af0f78d 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -7,7 +7,7 @@ from sqlglot.dialects.presto import Presto class Trino(Presto): class Generator(Presto.Generator): TRANSFORMS = { - **Presto.Generator.TRANSFORMS, # type: ignore + **Presto.Generator.TRANSFORMS, exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 03de99c..f6ad888 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -16,6 +16,9 @@ from sqlglot.helper import seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + FULL_FORMAT_TIME_MAPPING = { "weekday": "%A", "dw": "%A", @@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} -def _format_time_lambda(exp_class, full_format_mapping=None, default=None): - def _format_time(args): +def _format_time_lambda( + exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None +) -> t.Callable[[t.List], E]: + def _format_time(args: t.List) -> E: + assert len(args) == 2 + return exp_class( - this=seq_get(args, 1), + this=args[1], format=exp.Literal.string( format_time( - seq_get(args, 0).name or (TSQL.time_format if default is True else default), + args[0].name, {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, @@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def _parse_format(args): - fmt = seq_get(args, 1) - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) +def _parse_format(args: t.List) -> exp.Expression: + assert len(args) == 2 + + fmt = args[1] + number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + if number_fmt: - return exp.NumberToStr(this=seq_get(args, 0), format=fmt) + return exp.NumberToStr(this=args[0], format=fmt) + return exp.TimeToStr( - this=seq_get(args, 0), + this=args[0], format=exp.Literal.string( format_time(fmt.name, TSQL.format_time_mapping) if len(fmt.name) == 1 @@ -82,7 +93,7 @@ def _parse_format(args): ) -def _parse_eomonth(args): +def _parse_eomonth(args: t.List) -> exp.Expression: date = seq_get(args, 0) month_lag = seq_get(args, 1) unit = DATE_DELTA_INTERVAL.get("month") @@ -96,7 +107,7 @@ def _parse_eomonth(args): return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) -def _parse_hashbytes(args): +def _parse_hashbytes(args: t.List) -> exp.Expression: kind, data = args kind = kind.name.upper() if kind.is_string else "" @@ -110,40 +121,47 @@ def _parse_hashbytes(args): return exp.SHA2(this=data, length=exp.Literal.number(256)) if kind == "SHA2_512": return exp.SHA2(this=data, length=exp.Literal.number(512)) + return exp.func("HASHBYTES", *args) -def generate_date_delta_with_unit_sql(self, e): - func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" - return self.func(func, e.text("unit"), e.expression, e.this) +def generate_date_delta_with_unit_sql( + self: generator.Generator, expression: exp.DateAdd | exp.DateDiff +) -> str: + func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" + return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self, e): +def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( - e.args["format"] - if isinstance(e, exp.NumberToStr) - else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) + expression.args["format"] + if isinstance(expression, exp.NumberToStr) + else exp.Literal.string( + format_time( + expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping) + ) + ) ) - return self.func("FORMAT", e.this, fmt) + return self.func("FORMAT", expression.this, fmt) -def _string_agg_sql(self, e): - e = e.copy() +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: + expression = expression.copy() - this = e.this - distinct = e.find(exp.Distinct) + this = expression.this + distinct = expression.find(exp.Distinct) if distinct: # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") this = distinct.pop().expressions[0] order = "" - if isinstance(e.this, exp.Order): - if e.this.this: - this = e.this.this.pop() - order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space + if isinstance(expression.this, exp.Order): + if expression.this.this: + this = expression.this.this.pop() + order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space - separator = e.args.get("separator") or exp.Literal.string(",") + separator = expression.args.get("separator") or exp.Literal.string(",") return f"STRING_AGG({self.format_args(this, separator)}){order}" @@ -292,7 +310,7 @@ class TSQL(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -332,13 +350,13 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, - *parser.Parser.TYPE_TOKENS, # type: ignore + *parser.Parser.TYPE_TOKENS, } STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.END: lambda self: self._parse_command(), } @@ -377,7 +395,7 @@ class TSQL(Dialect): return system_time - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: table = super()._parse_table_parts(schema=schema) table.set("system_time", self._parse_system_time()) return table @@ -450,7 +468,7 @@ class TSQL(Dialect): LOCKING_READS_SUPPORTED = True TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", @@ -458,7 +476,7 @@ class TSQL(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -480,7 +498,7 @@ class TSQL(Dialect): TRANSFORMS.pop(exp.ReturnsProperty) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 86665e0..c10d640 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -53,7 +53,8 @@ class Keep: if t.TYPE_CHECKING: - T = t.TypeVar("T") + from sqlglot._typing import T + Edit = t.Union[Insert, Remove, Move, Update, Keep] @@ -240,7 +241,7 @@ class ChangeDistiller: return matching_set def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]: - candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = [] + candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = [] source_leaves = list(_get_leaves(self._source)) target_leaves = list(_get_leaves(self._target)) for source_leaf in source_leaves: @@ -252,6 +253,7 @@ class ChangeDistiller: candidate_matchings, ( -similarity_score, + -_parent_similarity_score(source_leaf, target_leaf), len(candidate_matchings), source_leaf, target_leaf, @@ -261,7 +263,7 @@ class ChangeDistiller: # Pick best matchings based on the highest score matching_set = set() while candidate_matchings: - _, _, source_leaf, target_leaf = heappop(candidate_matchings) + _, _, _, source_leaf, target_leaf = heappop(candidate_matchings) if ( id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes @@ -327,6 +329,15 @@ def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: return False +def _parent_similarity_score( + source: t.Optional[exp.Expression], target: t.Optional[exp.Expression] +) -> int: + if source is None or target is None or type(source) is not type(target): + return 0 + + return 1 + _parent_similarity_score(source.parent, target.parent) + + def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: args: t.List[t.Union[exp.Expression, t.List]] = [] if expression: diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index a67c155..017d5bc 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -14,9 +14,10 @@ from sqlglot import maybe_parse from sqlglot.errors import ExecuteError from sqlglot.executor.python import PythonExecutor from sqlglot.executor.table import Table, ensure_tables +from sqlglot.helper import dict_depth from sqlglot.optimizer import optimize from sqlglot.planner import Plan -from sqlglot.schema import ensure_schema +from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set logger = logging.getLogger("sqlglot") @@ -52,10 +53,15 @@ def execute( tables_ = ensure_tables(tables) if not schema: - schema = { - name: {column: type(table[0][column]).__name__ for column in table.columns} - for name, table in tables_.mapping.items() - } + schema = {} + flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping)) + + for keys in flattened_tables: + table = nested_get(tables_.mapping, *zip(keys, keys)) + assert table is not None + + for column in table.columns: + nested_set(schema, [*keys, column], type(table[0][column]).__name__) schema = ensure_schema(schema, dialect=read) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index 8f64cce..51cffbd 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -5,6 +5,7 @@ import statistics from functools import wraps from sqlglot import exp +from sqlglot.generator import Generator from sqlglot.helper import PYTHON_VERSION @@ -102,6 +103,8 @@ def cast(this, to): return datetime.date.fromisoformat(this) if to == exp.DataType.Type.DATETIME: return datetime.datetime.fromisoformat(this) + if to == exp.DataType.Type.BOOLEAN: + return bool(this) if to in exp.DataType.TEXT_TYPES: return str(this) if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: @@ -119,9 +122,11 @@ def ordered(this, desc, nulls_first): @null_if_any def interval(this, unit): - if unit == "DAY": - return datetime.timedelta(days=float(this)) - raise NotImplementedError + unit = unit.lower() + plural = unit + "s" + if plural in Generator.TIME_PART_SINGULARS: + unit = plural + return datetime.timedelta(**{unit: float(this)}) ENV = { @@ -147,7 +152,9 @@ ENV = { "COALESCE": lambda *args: next((a for a in args if a is not None), None), "CONCAT": null_if_any(lambda *args: "".join(args)), "CONCATWS": null_if_any(lambda this, *args: this.join(args)), + "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), "DIV": null_if_any(lambda e, this: e / this), + "DOT": null_if_any(lambda e, this: e[this]), "EQ": null_if_any(lambda this, e: this == e), "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), "GT": null_if_any(lambda this, e: this > e), @@ -162,6 +169,7 @@ ENV = { "LOWER": null_if_any(lambda arg: arg.lower()), "LT": null_if_any(lambda this, e: this < e), "LTE": null_if_any(lambda this, e: this <= e), + "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore "MOD": null_if_any(lambda e, this: e % this), "MUL": null_if_any(lambda e, this: e * this), "NEQ": null_if_any(lambda this, e: this != e), @@ -180,4 +188,5 @@ ENV = { "CURRENTTIMESTAMP": datetime.datetime.now, "CURRENTTIME": datetime.datetime.now, "CURRENTDATE": datetime.date.today, + "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), } diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index b71cc6a..f114e5c 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -360,11 +360,19 @@ def _ordered_py(self, expression): def _rename(self, e): try: - if "expressions" in e.args: - this = self.sql(e, "this") - this = f"{this}, " if this else "" - return f"{e.key.upper()}({this}{self.expressions(e)})" - return self.func(e.key, *e.args.values()) + values = list(e.args.values()) + + if len(values) == 1: + values = values[0] + if not isinstance(values, list): + return self.func(e.key, values) + return self.func(e.key, *values) + + if isinstance(e, exp.Func) and e.is_var_len_args: + *head, tail = values + return self.func(e.key, *head, *tail) + + return self.func(e.key, *values) except Exception as ex: raise Exception(f"Could not rename {repr(e)}") from ex @@ -413,6 +421,7 @@ class Python(Dialect): exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", + exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", exp.Is: lambda self, e: self.binary(e, "is"), exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 9e7379d..a4c4e95 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -21,6 +21,7 @@ from collections import deque from copy import deepcopy from enum import auto +from sqlglot._typing import E from sqlglot.errors import ParseError from sqlglot.helper import ( AutoName, @@ -28,7 +29,6 @@ from sqlglot.helper import ( ensure_collection, ensure_list, seq_get, - split_num_words, subclasses, ) from sqlglot.tokens import Token @@ -36,8 +36,6 @@ from sqlglot.tokens import Token if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType -E = t.TypeVar("E", bound="Expression") - class _Expression(type): def __new__(cls, clsname, bases, attrs): @@ -200,11 +198,11 @@ class Expression(metaclass=_Expression): return self.text("this") @property - def alias_or_name(self): + def alias_or_name(self) -> str: return self.alias or self.name @property - def output_name(self): + def output_name(self) -> str: """ Name of the output column if this expression is a selection. @@ -264,7 +262,7 @@ class Expression(metaclass=_Expression): if comments: self.comments.extend(comments) - def append(self, arg_key, value): + def append(self, arg_key: str, value: t.Any) -> None: """ Appends value to arg_key if it's a list or sets it as a new list. @@ -277,7 +275,7 @@ class Expression(metaclass=_Expression): self.args[arg_key].append(value) self._set_parent(arg_key, value) - def set(self, arg_key, value): + def set(self, arg_key: str, value: t.Any) -> None: """ Sets `arg_key` to `value`. @@ -288,7 +286,7 @@ class Expression(metaclass=_Expression): self.args[arg_key] = value self._set_parent(arg_key, value) - def _set_parent(self, arg_key, value): + def _set_parent(self, arg_key: str, value: t.Any) -> None: if hasattr(value, "parent"): value.parent = self value.arg_key = arg_key @@ -299,7 +297,7 @@ class Expression(metaclass=_Expression): v.arg_key = arg_key @property - def depth(self): + def depth(self) -> int: """ Returns the depth of this tree. """ @@ -318,26 +316,28 @@ class Expression(metaclass=_Expression): if hasattr(vs, "parent"): yield k, vs - def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: + def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]: """ Returns the first node in this tree which matches at least one of the specified types. Args: expression_types: the expression type(s) to match. + bfs: whether to search the AST using the BFS algorithm (DFS is used if false). Returns: The node which matches the criteria or None if no such node was found. """ return next(self.find_all(*expression_types, bfs=bfs), None) - def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]: + def find_all(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Iterator[E]: """ Returns a generator object which visits all nodes in this tree and only yields those that match at least one of the specified expression types. Args: expression_types: the expression type(s) to match. + bfs: whether to search the AST using the BFS algorithm (DFS is used if false). Returns: The generator object. @@ -346,7 +346,7 @@ class Expression(metaclass=_Expression): if isinstance(expression, expression_types): yield expression - def find_ancestor(self, *expression_types: t.Type[E]) -> E | None: + def find_ancestor(self, *expression_types: t.Type[E]) -> t.Optional[E]: """ Returns a nearest parent matching expression_types. @@ -362,14 +362,14 @@ class Expression(metaclass=_Expression): return t.cast(E, ancestor) @property - def parent_select(self): + def parent_select(self) -> t.Optional[Select]: """ Returns the parent select statement. """ return self.find_ancestor(Select) @property - def same_parent(self): + def same_parent(self) -> bool: """Returns if the parent is the same class as itself.""" return type(self.parent) is self.__class__ @@ -469,10 +469,10 @@ class Expression(metaclass=_Expression): if not type(node) is self.__class__: yield node.unnest() if unnest else node - def __str__(self): + def __str__(self) -> str: return self.sql() - def __repr__(self): + def __repr__(self) -> str: return self._to_s() def sql(self, dialect: DialectType = None, **opts) -> str: @@ -541,6 +541,14 @@ class Expression(metaclass=_Expression): replace_children(new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)) return new_node + @t.overload + def replace(self, expression: E) -> E: + ... + + @t.overload + def replace(self, expression: None) -> None: + ... + def replace(self, expression): """ Swap out this expression with a new expression. @@ -554,7 +562,7 @@ class Expression(metaclass=_Expression): 'SELECT y FROM tbl' Args: - expression (Expression|None): new node + expression: new node Returns: The new expression or expressions. @@ -568,7 +576,7 @@ class Expression(metaclass=_Expression): replace_children(parent, lambda child: expression if child is self else child) return expression - def pop(self): + def pop(self: E) -> E: """ Remove this expression from its AST. @@ -578,7 +586,7 @@ class Expression(metaclass=_Expression): self.replace(None) return self - def assert_is(self, type_): + def assert_is(self, type_: t.Type[E]) -> E: """ Assert that this `Expression` is an instance of `type_`. @@ -656,7 +664,13 @@ ExpOrStr = t.Union[str, Expression] class Condition(Expression): - def and_(self, *expressions, dialect=None, copy=True, **opts): + def and_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Condition: """ AND this condition with one or multiple expressions. @@ -665,18 +679,24 @@ class Condition(Expression): 'x = 1 AND y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy the involved expressions (only applies to Expressions). - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: whether or not to copy the involved expressions (only applies to Expressions). + opts: other options to use to parse the input expressions. Returns: - And: the new condition. + The new And condition. """ return and_(self, *expressions, dialect=dialect, copy=copy, **opts) - def or_(self, *expressions, dialect=None, copy=True, **opts): + def or_( + self, + *expressions: t.Optional[ExpOrStr], + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Condition: """ OR this condition with one or multiple expressions. @@ -685,18 +705,18 @@ class Condition(Expression): 'x = 1 OR y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy the involved expressions (only applies to Expressions). - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: whether or not to copy the involved expressions (only applies to Expressions). + opts: other options to use to parse the input expressions. Returns: - Or: the new condition. + The new Or condition. """ return or_(self, *expressions, dialect=dialect, copy=copy, **opts) - def not_(self, copy=True): + def not_(self, copy: bool = True): """ Wrap this condition with NOT. @@ -705,14 +725,24 @@ class Condition(Expression): 'NOT x = 1' Args: - copy (bool): whether or not to copy this object. + copy: whether or not to copy this object. Returns: - Not: the new condition. + The new Not instance. """ return not_(self, copy=copy) - def _binop(self, klass: t.Type[E], other: ExpOrStr, reverse=False) -> E: + def as_( + self, + alias: str | Identifier, + quoted: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Alias: + return alias_(self, alias, quoted=quoted, dialect=dialect, copy=copy, **opts) + + def _binop(self, klass: t.Type[E], other: t.Any, reverse: bool = False) -> E: this = self.copy() other = convert(other, copy=True) if not isinstance(this, klass) and not isinstance(other, klass): @@ -728,7 +758,7 @@ class Condition(Expression): ) def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts ) -> In: return In( this=_maybe_copy(self, copy), @@ -736,92 +766,95 @@ class Condition(Expression): query=maybe_parse(query, copy=copy, **opts) if query else None, ) - def between(self, low: t.Any, high: t.Any, copy=True, **opts) -> Between: + def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: return Between( this=_maybe_copy(self, copy), low=convert(low, copy=copy, **opts), high=convert(high, copy=copy, **opts), ) + def is_(self, other: ExpOrStr) -> Is: + return self._binop(Is, other) + def like(self, other: ExpOrStr) -> Like: return self._binop(Like, other) def ilike(self, other: ExpOrStr) -> ILike: return self._binop(ILike, other) - def eq(self, other: ExpOrStr) -> EQ: + def eq(self, other: t.Any) -> EQ: return self._binop(EQ, other) - def neq(self, other: ExpOrStr) -> NEQ: + def neq(self, other: t.Any) -> NEQ: return self._binop(NEQ, other) def rlike(self, other: ExpOrStr) -> RegexpLike: return self._binop(RegexpLike, other) - def __lt__(self, other: ExpOrStr) -> LT: + def __lt__(self, other: t.Any) -> LT: return self._binop(LT, other) - def __le__(self, other: ExpOrStr) -> LTE: + def __le__(self, other: t.Any) -> LTE: return self._binop(LTE, other) - def __gt__(self, other: ExpOrStr) -> GT: + def __gt__(self, other: t.Any) -> GT: return self._binop(GT, other) - def __ge__(self, other: ExpOrStr) -> GTE: + def __ge__(self, other: t.Any) -> GTE: return self._binop(GTE, other) - def __add__(self, other: ExpOrStr) -> Add: + def __add__(self, other: t.Any) -> Add: return self._binop(Add, other) - def __radd__(self, other: ExpOrStr) -> Add: + def __radd__(self, other: t.Any) -> Add: return self._binop(Add, other, reverse=True) - def __sub__(self, other: ExpOrStr) -> Sub: + def __sub__(self, other: t.Any) -> Sub: return self._binop(Sub, other) - def __rsub__(self, other: ExpOrStr) -> Sub: + def __rsub__(self, other: t.Any) -> Sub: return self._binop(Sub, other, reverse=True) - def __mul__(self, other: ExpOrStr) -> Mul: + def __mul__(self, other: t.Any) -> Mul: return self._binop(Mul, other) - def __rmul__(self, other: ExpOrStr) -> Mul: + def __rmul__(self, other: t.Any) -> Mul: return self._binop(Mul, other, reverse=True) - def __truediv__(self, other: ExpOrStr) -> Div: + def __truediv__(self, other: t.Any) -> Div: return self._binop(Div, other) - def __rtruediv__(self, other: ExpOrStr) -> Div: + def __rtruediv__(self, other: t.Any) -> Div: return self._binop(Div, other, reverse=True) - def __floordiv__(self, other: ExpOrStr) -> IntDiv: + def __floordiv__(self, other: t.Any) -> IntDiv: return self._binop(IntDiv, other) - def __rfloordiv__(self, other: ExpOrStr) -> IntDiv: + def __rfloordiv__(self, other: t.Any) -> IntDiv: return self._binop(IntDiv, other, reverse=True) - def __mod__(self, other: ExpOrStr) -> Mod: + def __mod__(self, other: t.Any) -> Mod: return self._binop(Mod, other) - def __rmod__(self, other: ExpOrStr) -> Mod: + def __rmod__(self, other: t.Any) -> Mod: return self._binop(Mod, other, reverse=True) - def __pow__(self, other: ExpOrStr) -> Pow: + def __pow__(self, other: t.Any) -> Pow: return self._binop(Pow, other) - def __rpow__(self, other: ExpOrStr) -> Pow: + def __rpow__(self, other: t.Any) -> Pow: return self._binop(Pow, other, reverse=True) - def __and__(self, other: ExpOrStr) -> And: + def __and__(self, other: t.Any) -> And: return self._binop(And, other) - def __rand__(self, other: ExpOrStr) -> And: + def __rand__(self, other: t.Any) -> And: return self._binop(And, other, reverse=True) - def __or__(self, other: ExpOrStr) -> Or: + def __or__(self, other: t.Any) -> Or: return self._binop(Or, other) - def __ror__(self, other: ExpOrStr) -> Or: + def __ror__(self, other: t.Any) -> Or: return self._binop(Or, other, reverse=True) def __neg__(self) -> Neg: @@ -837,12 +870,11 @@ class Predicate(Condition): class DerivedTable(Expression): @property - def alias_column_names(self): + def alias_column_names(self) -> t.List[str]: table_alias = self.args.get("alias") if not table_alias: return [] - column_list = table_alias.assert_is(TableAlias).args.get("columns") or [] - return [c.name for c in column_list] + return [c.name for c in table_alias.args.get("columns") or []] @property def selects(self): @@ -854,7 +886,9 @@ class DerivedTable(Expression): class Unionable(Expression): - def union(self, expression, distinct=True, dialect=None, **opts): + def union( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds a UNION expression. @@ -864,17 +898,20 @@ class Unionable(Expression): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Union: the Union expression. + The new Union expression. """ return union(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) - def intersect(self, expression, distinct=True, dialect=None, **opts): + def intersect( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds an INTERSECT expression. @@ -884,17 +921,20 @@ class Unionable(Expression): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Intersect: the Intersect expression + The new Intersect expression. """ return intersect(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) - def except_(self, expression, distinct=True, dialect=None, **opts): + def except_( + self, expression: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts + ) -> Unionable: """ Builds an EXCEPT expression. @@ -904,13 +944,14 @@ class Unionable(Expression): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - expression (str | Expression): the SQL code string. + expression: the SQL code string. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Except: the Except expression + The new Except expression. """ return except_(left=self, right=expression, distinct=distinct, dialect=dialect, **opts) @@ -949,6 +990,17 @@ class Create(Expression): "indexes": False, "no_schema_binding": False, "begin": False, + "clone": False, + } + + +# https://docs.snowflake.com/en/sql-reference/sql/create-clone +class Clone(Expression): + arg_types = { + "this": True, + "when": False, + "kind": False, + "expression": False, } @@ -1038,6 +1090,10 @@ class ByteString(Condition): pass +class RawString(Condition): + pass + + class Column(Condition): arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} @@ -1060,7 +1116,11 @@ class Column(Condition): @property def parts(self) -> t.List[Identifier]: """Return the parts of a column in order catalog, db, table, name.""" - return [part for part in reversed(list(self.args.values())) if part] + return [ + t.cast(Identifier, self.args[part]) + for part in ("catalog", "db", "table", "this") + if self.args.get(part) + ] def to_dot(self) -> Dot: """Converts the column into a dot expression.""" @@ -1116,6 +1176,27 @@ class Comment(Expression): arg_types = {"this": True, "kind": True, "expression": True, "exists": False} +# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl +class MergeTreeTTLAction(Expression): + arg_types = { + "this": True, + "delete": False, + "recompress": False, + "to_disk": False, + "to_volume": False, + } + + +# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl +class MergeTreeTTL(Expression): + arg_types = { + "expressions": True, + "where": False, + "group": False, + "aggregates": False, + } + + class ColumnConstraint(Expression): arg_types = {"this": False, "kind": True} @@ -1172,6 +1253,8 @@ class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): # this: True -> ALWAYS, this: False -> BY DEFAULT arg_types = { "this": False, + "expression": False, + "on_null": False, "start": False, "increment": False, "minvalue": False, @@ -1202,7 +1285,7 @@ class TitleColumnConstraint(ColumnConstraintKind): class UniqueColumnConstraint(ColumnConstraintKind): - arg_types: t.Dict[str, t.Any] = {} + arg_types = {"this": False} class UppercaseColumnConstraint(ColumnConstraintKind): @@ -1255,7 +1338,7 @@ class Delete(Expression): def where( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -1367,10 +1450,6 @@ class PrimaryKey(Expression): arg_types = {"expressions": True, "options": False} -class Unique(Expression): - arg_types = {"expressions": True} - - # https://www.postgresql.org/docs/9.1/sql-selectinto.html # https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples class Into(Expression): @@ -1378,7 +1457,13 @@ class Into(Expression): class From(Expression): - arg_types = {"expressions": True} + @property + def name(self) -> str: + return self.this.name + + @property + def alias_or_name(self) -> str: + return self.this.alias_or_name class Having(Expression): @@ -1397,7 +1482,7 @@ class Identifier(Expression): arg_types = {"this": True, "quoted": False} @property - def quoted(self): + def quoted(self) -> bool: return bool(self.args.get("quoted")) @property @@ -1407,7 +1492,7 @@ class Identifier(Expression): return self.this.lower() @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -1420,6 +1505,7 @@ class Index(Expression): "unique": False, "primary": False, "amp": False, # teradata + "partition_by": False, # teradata } @@ -1436,6 +1522,42 @@ class Insert(Expression): "alternative": False, } + def with_( + self, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Insert: + """ + Append to or set the common table expressions. + + Example: + >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql() + 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte' + + Args: + alias: the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_: the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + The modified expression. + """ + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts + ) + class OnConflict(Expression): arg_types = { @@ -1492,6 +1614,7 @@ class Group(Expression): "grouping_sets": False, "cube": False, "rollup": False, + "totals": False, } @@ -1519,7 +1642,7 @@ class Literal(Condition): return cls(this=str(string), is_string=True) @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -1531,26 +1654,34 @@ class Join(Expression): "kind": False, "using": False, "natural": False, + "global": False, "hint": False, } @property - def kind(self): + def kind(self) -> str: return self.text("kind").upper() @property - def side(self): + def side(self) -> str: return self.text("side").upper() @property - def hint(self): + def hint(self) -> str: return self.text("hint").upper() @property - def alias_or_name(self): + def alias_or_name(self) -> str: return self.this.alias_or_name - def on(self, *expressions, append=True, dialect=None, copy=True, **opts): + def on( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: """ Append to or set the ON expressions. @@ -1560,17 +1691,17 @@ class Join(Expression): 'JOIN x ON y = 1' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Join: the modified join expression. + The modified Join expression. """ join = _apply_conjunction_builder( *expressions, @@ -1587,7 +1718,14 @@ class Join(Expression): return join - def using(self, *expressions, append=True, dialect=None, copy=True, **opts): + def using( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Join: """ Append to or set the USING expressions. @@ -1597,16 +1735,16 @@ class Join(Expression): 'JOIN x USING (foo, bla)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, concatenate the new expressions to the existing "using" list. + append: if `True`, concatenate the new expressions to the existing "using" list. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Join: the modified join expression. + The modified Join expression. """ join = _apply_list_builder( *expressions, @@ -1677,10 +1815,6 @@ class Property(Expression): arg_types = {"this": True, "value": True} -class AfterJournalProperty(Property): - arg_types = {"no": True, "dual": False, "local": False} - - class AlgorithmProperty(Property): arg_types = {"this": True} @@ -1706,7 +1840,13 @@ class CollateProperty(Property): class DataBlocksizeProperty(Property): - arg_types = {"size": False, "units": False, "min": False, "default": False} + arg_types = { + "size": False, + "units": False, + "minimum": False, + "maximum": False, + "default": False, + } class DefinerProperty(Property): @@ -1760,7 +1900,13 @@ class IsolatedLoadingProperty(Property): class JournalProperty(Property): - arg_types = {"no": True, "dual": False, "before": False} + arg_types = { + "no": False, + "dual": False, + "before": False, + "local": False, + "after": False, + } class LanguageProperty(Property): @@ -1798,11 +1944,11 @@ class MergeBlockRatioProperty(Property): class NoPrimaryIndexProperty(Property): - arg_types = {"this": False} + arg_types = {} class OnCommitProperty(Property): - arg_type = {"this": False} + arg_type = {"delete": False} class PartitionedByProperty(Property): @@ -1846,6 +1992,10 @@ class SetProperty(Property): arg_types = {"multi": True} +class SettingsProperty(Property): + arg_types = {"expressions": True} + + class SortKeyProperty(Property): arg_types = {"this": True, "compound": False} @@ -1858,12 +2008,8 @@ class StabilityProperty(Property): arg_types = {"this": True} -class TableFormatProperty(Property): - arg_types = {"this": True} - - class TemporaryProperty(Property): - arg_types = {"global_": True} + arg_types = {} class TransientProperty(Property): @@ -1903,7 +2049,6 @@ class Properties(Expression): "RETURNS": ReturnsProperty, "ROW_FORMAT": RowFormatProperty, "SORTKEY": SortKeyProperty, - "TABLE_FORMAT": TableFormatProperty, } PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} @@ -1932,7 +2077,7 @@ class Properties(Expression): UNSUPPORTED = auto() @classmethod - def from_dict(cls, properties_dict) -> Properties: + def from_dict(cls, properties_dict: t.Dict) -> Properties: expressions = [] for key, value in properties_dict.items(): property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) @@ -1961,7 +2106,7 @@ class Tuple(Expression): arg_types = {"expressions": False} def isin( - self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy=True, **opts + self, *expressions: t.Any, query: t.Optional[ExpOrStr] = None, copy: bool = True, **opts ) -> In: return In( this=_maybe_copy(self, copy), @@ -1971,7 +2116,7 @@ class Tuple(Expression): class Subqueryable(Unionable): - def subquery(self, alias=None, copy=True) -> Subquery: + def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery: """ Convert this expression to an aliased expression that can be used as a Subquery. @@ -1988,12 +2133,14 @@ class Subqueryable(Unionable): Alias: the subquery """ instance = _maybe_copy(self, copy) - return Subquery( - this=instance, - alias=TableAlias(this=to_identifier(alias)) if alias else None, - ) + if not isinstance(alias, Expression): + alias = TableAlias(this=to_identifier(alias)) if alias else None + + return Subquery(this=instance, alias=alias) - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: raise NotImplementedError @property @@ -2013,14 +2160,14 @@ class Subqueryable(Unionable): def with_( self, - alias, - as_, - recursive=None, - append=True, - dialect=None, - copy=True, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, **opts, - ): + ) -> Subqueryable: """ Append to or set the common table expressions. @@ -2029,43 +2176,22 @@ class Subqueryable(Unionable): 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' Args: - alias (str | Expression): the SQL code string to parse as the table name. + alias: the SQL code string to parse as the table name. If an `Expression` instance is passed, this is used as-is. - as_ (str | Expression): the SQL code string to parse as the table expression. + as_: the SQL code string to parse as the table expression. If an `Expression` instance is passed, it will be used as-is. - recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. - append (bool): if `True`, add to any existing expressions. + recursive: set the RECURSIVE part of the expression. Defaults to `False`. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified expression. """ - alias_expression = maybe_parse( - alias, - dialect=dialect, - into=TableAlias, - **opts, - ) - as_expression = maybe_parse( - as_, - dialect=dialect, - **opts, - ) - cte = CTE( - this=as_expression, - alias=alias_expression, - ) - return _apply_child_list_builder( - cte, - instance=self, - arg="with", - append=append, - copy=copy, - into=With, - properties={"recursive": recursive or False}, + return _apply_cte_builder( + self, alias, as_, recursive=recursive, append=append, dialect=dialect, copy=copy, **opts ) @@ -2085,8 +2211,10 @@ QUERY_MODIFIERS = { "order": False, "limit": False, "offset": False, - "lock": False, + "locks": False, "sample": False, + "settings": False, + "format": False, } @@ -2111,6 +2239,15 @@ class Table(Expression): def catalog(self) -> str: return self.text("catalog") + @property + def parts(self) -> t.List[Identifier]: + """Return the parts of a table in order catalog, db, table.""" + return [ + t.cast(Identifier, self.args[part]) + for part in ("catalog", "db", "this") + if self.args.get(part) + ] + # See the TSQL "Querying data in a system-versioned temporal table" page class SystemTime(Expression): @@ -2130,7 +2267,9 @@ class Union(Subqueryable): **QUERY_MODIFIERS, } - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the LIMIT expression. @@ -2139,16 +2278,16 @@ class Union(Subqueryable): 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: The limited subqueryable. + The limited subqueryable. """ return ( select("*") @@ -2158,7 +2297,7 @@ class Union(Subqueryable): def select( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2255,10 +2394,10 @@ class Schema(Expression): arg_types = {"this": False, "expressions": False} -# Used to represent the FOR UPDATE and FOR SHARE locking read types. -# https://dev.mysql.com/doc/refman/8.0/en/innodb-locking-reads.html +# https://dev.mysql.com/doc/refman/8.0/en/select.html +# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/SELECT.html class Lock(Expression): - arg_types = {"update": True} + arg_types = {"update": True, "expressions": False, "wait": False} class Select(Subqueryable): @@ -2275,7 +2414,9 @@ class Select(Subqueryable): **QUERY_MODIFIERS, } - def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def from_( + self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the FROM expression. @@ -2284,31 +2425,35 @@ class Select(Subqueryable): 'SELECT x FROM tbl' Args: - *expressions (str | Expression): the SQL code strings to parse. + expression : the SQL code strings to parse. If a `From` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `From`. - append (bool): if `True`, add to any existing expressions. - Otherwise, this flattens all the `From` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ - return _apply_child_list_builder( - *expressions, + return _apply_builder( + expression=expression, instance=self, arg="from", - append=append, - copy=copy, - prefix="FROM", into=From, + prefix="FROM", dialect=dialect, + copy=copy, **opts, ) - def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def group_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the GROUP BY expression. @@ -2317,21 +2462,22 @@ class Select(Subqueryable): 'SELECT x, COUNT(1) FROM tbl GROUP BY x' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Group`. If nothing is passed in then a group by is not applied to the expression - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Group` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ if not expressions: return self if not copy else self.copy() + return _apply_child_list_builder( *expressions, instance=self, @@ -2344,7 +2490,14 @@ class Select(Subqueryable): **opts, ) - def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def order_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the ORDER BY expression. @@ -2353,17 +2506,17 @@ class Select(Subqueryable): 'SELECT x FROM tbl ORDER BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Order`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2377,26 +2530,33 @@ class Select(Subqueryable): **opts, ) - def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def sort_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the SORT BY expression. Example: - >>> Select().from_("tbl").select("x").sort_by("x DESC").sql() + >>> Select().from_("tbl").select("x").sort_by("x DESC").sql(dialect="hive") 'SELECT x FROM tbl SORT BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `SORT`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2410,26 +2570,33 @@ class Select(Subqueryable): **opts, ) - def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def cluster_by( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Set the CLUSTER BY expression. Example: - >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql() + >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql(dialect="hive") 'SELECT x FROM tbl CLUSTER BY x DESC' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Cluster`. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this flattens all the `Order` expression into a single expression. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_child_list_builder( *expressions, @@ -2443,7 +2610,9 @@ class Select(Subqueryable): **opts, ) - def limit(self, expression, dialect=None, copy=True, **opts) -> Select: + def limit( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the LIMIT expression. @@ -2452,13 +2621,13 @@ class Select(Subqueryable): 'SELECT x FROM tbl LIMIT 10' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2474,7 +2643,9 @@ class Select(Subqueryable): **opts, ) - def offset(self, expression, dialect=None, copy=True, **opts) -> Select: + def offset( + self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts + ) -> Select: """ Set the OFFSET expression. @@ -2483,16 +2654,16 @@ class Select(Subqueryable): 'SELECT x FROM tbl OFFSET 10' Args: - expression (str | int | Expression): the SQL code string to parse. + expression: the SQL code string to parse. This can also be an integer. If a `Offset` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Offset`. - dialect (str): the dialect used to parse the input expression. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_builder( expression=expression, @@ -2507,7 +2678,7 @@ class Select(Subqueryable): def select( self, - *expressions: ExpOrStr, + *expressions: t.Optional[ExpOrStr], append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2530,7 +2701,7 @@ class Select(Subqueryable): opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_list_builder( *expressions, @@ -2542,7 +2713,14 @@ class Select(Subqueryable): **opts, ) - def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def lateral( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the LATERAL expressions. @@ -2551,16 +2729,16 @@ class Select(Subqueryable): 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_list_builder( *expressions, @@ -2576,14 +2754,14 @@ class Select(Subqueryable): def join( self, - expression, - on=None, - using=None, - append=True, - join_type=None, - join_alias=None, - dialect=None, - copy=True, + expression: ExpOrStr, + on: t.Optional[ExpOrStr] = None, + using: t.Optional[ExpOrStr | t.List[ExpOrStr]] = None, + append: bool = True, + join_type: t.Optional[str] = None, + join_alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + copy: bool = True, **opts, ) -> Select: """ @@ -2602,18 +2780,19 @@ class Select(Subqueryable): 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' Args: - expression (str | Expression): the SQL code string to parse. + expression: the SQL code string to parse. If an `Expression` instance is passed, it will be used as-is. - on (str | Expression): optionally specify the join "on" criteria as a SQL string. + on: optionally specify the join "on" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. - using (str | Expression): optionally specify the join "using" criteria as a SQL string. + using: optionally specify the join "using" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. - append (bool): if `True`, add to any existing expressions. + append: if `True`, add to any existing expressions. Otherwise, this resets the expressions. - join_type (str): If set, alter the parsed join type - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + join_type: if set, alter the parsed join type. + join_alias: an optional alias for the joined source. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2621,9 +2800,9 @@ class Select(Subqueryable): parse_args = {"dialect": dialect, **opts} try: - expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) # type: ignore except ParseError: - expression = maybe_parse(expression, into=(Join, Expression), **parse_args) + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) # type: ignore join = expression if isinstance(expression, Join) else Join(this=expression) @@ -2645,12 +2824,12 @@ class Select(Subqueryable): join.set("kind", kind.text) if on: - on = and_(*ensure_collection(on), dialect=dialect, copy=copy, **opts) + on = and_(*ensure_list(on), dialect=dialect, copy=copy, **opts) join.set("on", on) if using: join = _apply_list_builder( - *ensure_collection(using), + *ensure_list(using), instance=join, arg="using", append=append, @@ -2660,6 +2839,7 @@ class Select(Subqueryable): if join_alias: join.set("this", alias_(join.this, join_alias, table=True)) + return _apply_list_builder( join, instance=self, @@ -2669,7 +2849,14 @@ class Select(Subqueryable): **opts, ) - def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def where( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the WHERE expressions. @@ -2678,14 +2865,14 @@ class Select(Subqueryable): "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: Select: the modified expression. @@ -2701,7 +2888,14 @@ class Select(Subqueryable): **opts, ) - def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def having( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: """ Append to or set the HAVING expressions. @@ -2710,17 +2904,17 @@ class Select(Subqueryable): 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. - append (bool): if `True`, AND the new expressions to any existing expression. + append: if `True`, AND the new expressions to any existing expression. Otherwise, this resets the expression. - dialect (str): the dialect used to parse the input expressions. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input expressions. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. Returns: - Select: the modified expression. + The modified Select expression. """ return _apply_conjunction_builder( *expressions, @@ -2733,7 +2927,14 @@ class Select(Subqueryable): **opts, ) - def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def window( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: return _apply_list_builder( *expressions, instance=self, @@ -2745,7 +2946,14 @@ class Select(Subqueryable): **opts, ) - def qualify(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + def qualify( + self, + *expressions: t.Optional[ExpOrStr], + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Select: return _apply_conjunction_builder( *expressions, instance=self, @@ -2757,7 +2965,9 @@ class Select(Subqueryable): **opts, ) - def distinct(self, *ons: ExpOrStr, distinct: bool = True, copy: bool = True) -> Select: + def distinct( + self, *ons: t.Optional[ExpOrStr], distinct: bool = True, copy: bool = True + ) -> Select: """ Set the OFFSET expression. @@ -2774,11 +2984,18 @@ class Select(Subqueryable): Select: the modified expression. """ instance = _maybe_copy(self, copy) - on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons]) if ons else None + on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None instance.set("distinct", Distinct(on=on) if distinct else None) return instance - def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create: + def ctas( + self, + table: ExpOrStr, + properties: t.Optional[t.Dict] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Create: """ Convert this expression to a CREATE TABLE AS statement. @@ -2787,15 +3004,15 @@ class Select(Subqueryable): 'CREATE TABLE x AS SELECT * FROM tbl' Args: - table (str | Expression): the SQL code string to parse as the table name. + table: the SQL code string to parse as the table name. If another `Expression` instance is passed, it will be used as-is. - properties (dict): an optional mapping of table properties - dialect (str): the dialect used to parse the input table. - copy (bool): if `False`, modify this expression instance in-place. - opts (kwargs): other options to use to parse the input table. + properties: an optional mapping of table properties + dialect: the dialect used to parse the input table. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input table. Returns: - Create: the CREATE TABLE AS expression + The new Create expression. """ instance = _maybe_copy(self, copy) table_expression = maybe_parse( @@ -2835,7 +3052,7 @@ class Select(Subqueryable): """ inst = _maybe_copy(self, copy) - inst.set("lock", Lock(update=update)) + inst.set("locks", [Lock(update=update)]) return inst @@ -2874,7 +3091,7 @@ class Subquery(DerivedTable, Unionable): return self.this.is_star @property - def output_name(self): + def output_name(self) -> str: return self.alias @@ -2903,13 +3120,17 @@ class Tag(Expression): } +# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax +# https://duckdb.org/docs/sql/statements/pivot class Pivot(Expression): arg_types = { "this": False, "alias": False, "expressions": True, - "field": True, - "unpivot": True, + "field": False, + "unpivot": False, + "using": False, + "group": False, "columns": False, } @@ -2948,7 +3169,7 @@ class Star(Expression): return "*" @property - def output_name(self): + def output_name(self) -> str: return self.name @@ -2961,7 +3182,7 @@ class SessionParameter(Expression): class Placeholder(Expression): - arg_types = {"this": False} + arg_types = {"this": False, "kind": False} class Null(Condition): @@ -2976,6 +3197,10 @@ class Boolean(Condition): pass +class DataTypeSize(Expression): + arg_types = {"this": True, "expression": False} + + class DataType(Expression): arg_types = { "this": True, @@ -2986,68 +3211,69 @@ class DataType(Expression): } class Type(AutoName): - CHAR = auto() - NCHAR = auto() - VARCHAR = auto() - NVARCHAR = auto() - TEXT = auto() - MEDIUMTEXT = auto() - LONGTEXT = auto() - MEDIUMBLOB = auto() - LONGBLOB = auto() - BINARY = auto() - VARBINARY = auto() - INT = auto() - UINT = auto() - TINYINT = auto() - UTINYINT = auto() - SMALLINT = auto() - USMALLINT = auto() - BIGINT = auto() - UBIGINT = auto() - INT128 = auto() - UINT128 = auto() - INT256 = auto() - UINT256 = auto() - FLOAT = auto() - DOUBLE = auto() - DECIMAL = auto() + ARRAY = auto() BIGDECIMAL = auto() + BIGINT = auto() + BIGSERIAL = auto() + BINARY = auto() BIT = auto() BOOLEAN = auto() - JSON = auto() - JSONB = auto() - INTERVAL = auto() - TIME = auto() - TIMESTAMP = auto() - TIMESTAMPTZ = auto() - TIMESTAMPLTZ = auto() + CHAR = auto() DATE = auto() DATETIME = auto() - ARRAY = auto() - MAP = auto() - UUID = auto() + DATETIME64 = auto() + DECIMAL = auto() + DOUBLE = auto() + FLOAT = auto() GEOGRAPHY = auto() GEOMETRY = auto() - STRUCT = auto() - NULLABLE = auto() HLLSKETCH = auto() HSTORE = auto() - SUPER = auto() - SERIAL = auto() - SMALLSERIAL = auto() - BIGSERIAL = auto() - XML = auto() - UNIQUEIDENTIFIER = auto() - MONEY = auto() - SMALLMONEY = auto() - ROWVERSION = auto() IMAGE = auto() - VARIANT = auto() - OBJECT = auto() INET = auto() + INT = auto() + INT128 = auto() + INT256 = auto() + INTERVAL = auto() + JSON = auto() + JSONB = auto() + LONGBLOB = auto() + LONGTEXT = auto() + MAP = auto() + MEDIUMBLOB = auto() + MEDIUMTEXT = auto() + MONEY = auto() + NCHAR = auto() NULL = auto() + NULLABLE = auto() + NVARCHAR = auto() + OBJECT = auto() + ROWVERSION = auto() + SERIAL = auto() + SMALLINT = auto() + SMALLMONEY = auto() + SMALLSERIAL = auto() + STRUCT = auto() + SUPER = auto() + TEXT = auto() + TIME = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + TIMESTAMPLTZ = auto() + TINYINT = auto() + UBIGINT = auto() + UINT = auto() + USMALLINT = auto() + UTINYINT = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation + UINT128 = auto() + UINT256 = auto() + UNIQUEIDENTIFIER = auto() + UUID = auto() + VARBINARY = auto() + VARCHAR = auto() + VARIANT = auto() + XML = auto() TEXT_TYPES = { Type.CHAR, @@ -3079,6 +3305,7 @@ class DataType(Expression): Type.TIMESTAMPLTZ, Type.DATE, Type.DATETIME, + Type.DATETIME64, } @classmethod @@ -3092,6 +3319,7 @@ class DataType(Expression): data_type_exp: t.Optional[Expression] = DataType(this=DataType.Type[dtype.upper()]) else: data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if data_type_exp is None: raise ValueError(f"Unparsable data type value: {dtype}") elif isinstance(dtype, DataType.Type): @@ -3100,6 +3328,7 @@ class DataType(Expression): return dtype else: raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") + return DataType(**{**data_type_exp.args, **kwargs}) def is_type(self, dtype: DataType.Type) -> bool: @@ -3361,7 +3590,7 @@ class Alias(Expression): arg_types = {"this": True, "alias": False} @property - def output_name(self): + def output_name(self) -> str: return self.alias @@ -3411,12 +3640,17 @@ class TimeUnit(Expression): args["unit"] = Var(this=unit.name) elif isinstance(unit, Week): unit.set("this", Var(this=unit.this.name)) + super().__init__(**args) class Interval(TimeUnit): arg_types = {"this": False, "unit": False} + @property + def unit(self) -> t.Optional[Var]: + return self.args.get("unit") + class IgnoreNulls(Expression): pass @@ -3480,6 +3714,10 @@ class AggFunc(Func): pass +class ParameterizedAgg(AggFunc): + arg_types = {"this": True, "expressions": True, "params": True} + + class Abs(Func): pass @@ -3498,6 +3736,7 @@ class Hll(AggFunc): class ApproxDistinct(AggFunc): arg_types = {"this": True, "accuracy": False} + _sql_names = ["APPROX_DISTINCT", "APPROX_COUNT_DISTINCT"] class Array(Func): @@ -3600,17 +3839,21 @@ class Cast(Func): return self.this.name @property - def to(self): + def to(self) -> DataType: return self.args["to"] @property - def output_name(self): + def output_name(self) -> str: return self.name def is_type(self, dtype: DataType.Type) -> bool: return self.to.is_type(dtype) +class CastToStrType(Func): + arg_types = {"this": True, "expression": True} + + class Collate(Binary): pass @@ -3796,10 +4039,6 @@ class Explode(Func): pass -class ExponentialTimeDecayedAvg(AggFunc): - arg_types = {"this": True, "time": False, "decay": False} - - class Floor(Func): arg_types = {"this": True, "decimals": False} @@ -3821,18 +4060,10 @@ class GroupConcat(Func): arg_types = {"this": True, "separator": False} -class GroupUniqArray(AggFunc): - arg_types = {"this": True, "size": False} - - class Hex(Func): pass -class Histogram(AggFunc): - arg_types = {"this": True, "bins": False} - - class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -3843,7 +4074,7 @@ class IfNull(Func): class Initcap(Func): - pass + arg_types = {"this": True, "expression": False} class JSONKeyValue(Expression): @@ -3861,6 +4092,14 @@ class JSONObject(Func): } +class OpenJSONColumnDef(Expression): + arg_types = {"this": True, "kind": True, "path": False, "as_json": False} + + +class OpenJSON(Func): + arg_types = {"this": True, "path": False, "expressions": False} + + class JSONBContains(Binary): _sql_names = ["JSONB_CONTAINS"] @@ -3945,6 +4184,14 @@ class VarMap(Func): arg_types = {"keys": True, "values": True} is_var_len_args = True + @property + def keys(self) -> t.List[Expression]: + return self.args["keys"].expressions + + @property + def values(self) -> t.List[Expression]: + return self.args["values"].expressions + # https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html class MatchAgainst(Func): @@ -3993,17 +4240,6 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} -# Clickhouse-specific: -# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles -class Quantiles(AggFunc): - arg_types = {"parameters": True, "expressions": True} - is_var_len_args = True - - -class QuantileIf(AggFunc): - arg_types = {"parameters": True, "expressions": True} - - class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} @@ -4089,6 +4325,10 @@ class Substring(Func): arg_types = {"this": True, "start": False, "length": False} +class StandardHash(Func): + arg_types = {"this": True, "expression": False} + + class StrPosition(Func): arg_types = { "this": True, @@ -4328,15 +4568,19 @@ def maybe_parse( return sql_or_expression.copy() return sql_or_expression + if sql_or_expression is None: + raise ParseError(f"SQL cannot be None") + import sqlglot sql = str(sql_or_expression) if prefix: sql = f"{prefix} {sql}" + return sqlglot.parse_one(sql, read=dialect, into=into, **opts) -def _maybe_copy(instance, copy=True): +def _maybe_copy(instance: E, copy: bool = True) -> E: return instance.copy() if copy else instance @@ -4383,16 +4627,18 @@ def _apply_child_list_builder( instance = _maybe_copy(instance, copy) parsed = [] for expression in expressions: - if _is_wrong_expression(expression, into): - expression = into(expressions=[expression]) - expression = maybe_parse( - expression, - into=into, - dialect=dialect, - prefix=prefix, - **opts, - ) - parsed.extend(expression.expressions) + if expression is not None: + if _is_wrong_expression(expression, into): + expression = into(expressions=[expression]) + + expression = maybe_parse( + expression, + into=into, + dialect=dialect, + prefix=prefix, + **opts, + ) + parsed.extend(expression.expressions) existing = instance.args.get(arg) if append and existing: @@ -4402,6 +4648,7 @@ def _apply_child_list_builder( for k, v in (properties or {}).items(): child.set(k, v) instance.set(arg, child) + return instance @@ -4427,6 +4674,7 @@ def _apply_list_builder( **opts, ) for expression in expressions + if expression is not None ] existing_expressions = inst.args.get(arg) @@ -4463,25 +4711,59 @@ def _apply_conjunction_builder( return inst -def _combine(expressions, operator, dialect=None, copy=True, **opts): - expressions = [ - condition(expression, dialect=dialect, copy=copy, **opts) for expression in expressions +def _apply_cte_builder( + instance: E, + alias: ExpOrStr, + as_: ExpOrStr, + recursive: t.Optional[bool] = None, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> E: + alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) + as_expression = maybe_parse(as_, dialect=dialect, **opts) + cte = CTE(this=as_expression, alias=alias_expression) + return _apply_child_list_builder( + cte, + instance=instance, + arg="with", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive or False}, + ) + + +def _combine( + expressions: t.Sequence[t.Optional[ExpOrStr]], + operator: t.Type[Connector], + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Expression: + conditions = [ + condition(expression, dialect=dialect, copy=copy, **opts) + for expression in expressions + if expression is not None ] - this = expressions[0] - if expressions[1:]: + + this, *rest = conditions + if rest: this = _wrap(this, Connector) - for expression in expressions[1:]: + for expression in rest: this = operator(this=this, expression=_wrap(expression, Connector)) + return this def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: - if isinstance(expression, kind): - return Paren(this=expression) - return expression + return Paren(this=expression) if isinstance(expression, kind) else expression -def union(left, right, distinct=True, dialect=None, **opts): +def union( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Union: """ Initializes a syntax tree from one UNION expression. @@ -4490,15 +4772,16 @@ def union(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Union: the syntax tree for the UNION expression. + The new Union instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4506,7 +4789,9 @@ def union(left, right, distinct=True, dialect=None, **opts): return Union(this=left, expression=right, distinct=distinct) -def intersect(left, right, distinct=True, dialect=None, **opts): +def intersect( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Intersect: """ Initializes a syntax tree from one INTERSECT expression. @@ -4515,15 +4800,16 @@ def intersect(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Intersect: the syntax tree for the INTERSECT expression. + The new Intersect instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4531,7 +4817,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts): return Intersect(this=left, expression=right, distinct=distinct) -def except_(left, right, distinct=True, dialect=None, **opts): +def except_( + left: ExpOrStr, right: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts +) -> Except: """ Initializes a syntax tree from one EXCEPT expression. @@ -4540,15 +4828,16 @@ def except_(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - left (str | Expression): the SQL code string corresponding to the left-hand side. + left: the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str | Expression): the SQL code string corresponding to the right-hand side. + right: the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. - distinct (bool): set the DISTINCT flag if and only if this is true. - dialect (str): the dialect used to parse the input expression. - opts (kwargs): other options to use to parse the input expressions. + distinct: set the DISTINCT flag if and only if this is true. + dialect: the dialect used to parse the input expression. + opts: other options to use to parse the input expressions. + Returns: - Except: the syntax tree for the EXCEPT statement. + The new Except instance. """ left = maybe_parse(sql_or_expression=left, dialect=dialect, **opts) right = maybe_parse(sql_or_expression=right, dialect=dialect, **opts) @@ -4578,7 +4867,7 @@ def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Selec return Select().select(*expressions, dialect=dialect, **opts) -def from_(*expressions, dialect=None, **opts) -> Select: +def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select: """ Initializes a syntax tree from a FROM expression. @@ -4587,9 +4876,9 @@ def from_(*expressions, dialect=None, **opts) -> Select: 'SELECT col1, col2 FROM tbl' Args: - *expressions (str | Expression): the SQL code string to parse as the FROM expressions of a + *expression: the SQL code string to parse as the FROM expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression (in the case that the + dialect: the dialect used to parse the input expression (in the case that the input expression is a SQL string). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). @@ -4597,7 +4886,7 @@ def from_(*expressions, dialect=None, **opts) -> Select: Returns: Select: the syntax tree for the SELECT statement. """ - return Select().from_(*expressions, dialect=dialect, **opts) + return Select().from_(expression, dialect=dialect, **opts) def update( @@ -4680,7 +4969,54 @@ def delete( return delete_expr -def condition(expression, dialect=None, copy=True, **opts) -> Condition: +def insert( + expression: ExpOrStr, + into: ExpOrStr, + columns: t.Optional[t.Sequence[ExpOrStr]] = None, + overwrite: t.Optional[bool] = None, + dialect: DialectType = None, + copy: bool = True, + **opts, +) -> Insert: + """ + Builds an INSERT statement. + + Example: + >>> insert("VALUES (1, 2, 3)", "tbl").sql() + 'INSERT INTO tbl VALUES (1, 2, 3)' + + Args: + expression: the sql string or expression of the INSERT statement + into: the tbl to insert data to. + columns: optionally the table's column names. + overwrite: whether to INSERT OVERWRITE or not. + dialect: the dialect used to parse the input expressions. + copy: whether or not to copy the expression. + **opts: other options to use to parse the input expressions. + + Returns: + Insert: the syntax tree for the INSERT statement. + """ + expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts) + this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts) + + if columns: + this = _apply_list_builder( + *columns, + instance=Schema(this=this), + arg="expressions", + into=Identifier, + copy=False, + dialect=dialect, + **opts, + ) + + return Insert(this=this, expression=expr, overwrite=overwrite) + + +def condition( + expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Initialize a logical condition expression. @@ -4695,18 +5031,18 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition: 'SELECT * FROM tbl WHERE x = 1 AND y = 1' Args: - *expression (str | Expression): the SQL code string to parse. + *expression: the SQL code string to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression (in the case that the + dialect: the dialect used to parse the input expression (in the case that the input expression is a SQL string). - copy (bool): Whether or not to copy `expression` (only applies to expressions). + copy: Whether or not to copy `expression` (only applies to expressions). **opts: other options to use to parse the input expressions (again, in the case that the input expression is a SQL string). Returns: - Condition: the expression + The new Condition instance """ - return maybe_parse( # type: ignore + return maybe_parse( expression, into=Condition, dialect=dialect, @@ -4715,7 +5051,9 @@ def condition(expression, dialect=None, copy=True, **opts) -> Condition: ) -def and_(*expressions, dialect=None, copy=True, **opts) -> And: +def and_( + *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Combine multiple conditions with an AND logical operator. @@ -4724,19 +5062,21 @@ def and_(*expressions, dialect=None, copy=True, **opts) -> And: 'x = 1 AND (y = 1 AND z = 1)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy `expressions` (only applies to Expressions). + dialect: the dialect used to parse the input expression. + copy: whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: And: the new condition """ - return _combine(expressions, And, dialect, copy=copy, **opts) + return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, **opts)) -def or_(*expressions, dialect=None, copy=True, **opts) -> Or: +def or_( + *expressions: t.Optional[ExpOrStr], dialect: DialectType = None, copy: bool = True, **opts +) -> Condition: """ Combine multiple conditions with an OR logical operator. @@ -4745,19 +5085,19 @@ def or_(*expressions, dialect=None, copy=True, **opts) -> Or: 'x = 1 OR (y = 1 OR z = 1)' Args: - *expressions (str | Expression): the SQL code strings to parse. + *expressions: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. - copy (bool): whether or not to copy `expressions` (only applies to Expressions). + dialect: the dialect used to parse the input expression. + copy: whether or not to copy `expressions` (only applies to Expressions). **opts: other options to use to parse the input expressions. Returns: Or: the new condition """ - return _combine(expressions, Or, dialect, copy=copy, **opts) + return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, **opts)) -def not_(expression, dialect=None, copy=True, **opts) -> Not: +def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not: """ Wrap a condition with a NOT operator. @@ -4766,13 +5106,14 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not: "NOT this_suit = 'black'" Args: - expression (str | Expression): the SQL code strings to parse. + expression: the SQL code string to parse. If an Expression instance is passed, this is used as-is. - dialect (str): the dialect used to parse the input expression. + dialect: the dialect used to parse the input expression. + copy: whether to copy the expression or not. **opts: other options to use to parse the input expressions. Returns: - Not: the new condition + The new condition. """ this = condition( expression, @@ -4783,29 +5124,47 @@ def not_(expression, dialect=None, copy=True, **opts) -> Not: return Not(this=_wrap(this, Connector)) -def paren(expression, copy=True) -> Paren: - return Paren(this=_maybe_copy(expression, copy)) +def paren(expression: ExpOrStr, copy: bool = True) -> Paren: + """ + Wrap an expression in parentheses. + + Example: + >>> paren("5 + 3").sql() + '(5 + 3)' + + Args: + expression: the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + copy: whether to copy the expression or not. + + Returns: + The wrapped expression. + """ + return Paren(this=maybe_parse(expression, copy=copy)) SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") @t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None) -> None: +def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... @t.overload -def to_identifier(name: str | Identifier, quoted: t.Optional[bool] = None) -> Identifier: +def to_identifier( + name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True +) -> Identifier: ... -def to_identifier(name, quoted=None): +def to_identifier(name, quoted=None, copy=True): """Builds an identifier. Args: name: The name to turn into an identifier. quoted: Whether or not force quote the identifier. + copy: Whether or not to copy a passed in Identefier node. Returns: The identifier ast node. @@ -4815,7 +5174,7 @@ def to_identifier(name, quoted=None): return None if isinstance(name, Identifier): - identifier = name + identifier = _maybe_copy(name, copy) elif isinstance(name, str): identifier = Identifier( this=name, @@ -4858,13 +5217,17 @@ def to_table(sql_path: None, **kwargs) -> None: ... -def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: +def to_table( + sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs +) -> t.Optional[Table]: """ Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. If a table is passed in then that table is returned. Args: sql_path: a `[catalog].[schema].[table]` string. + dialect: the source dialect according to which the table name will be parsed. + kwargs: the kwargs to instantiate the resulting `Table` expression with. Returns: A table expression. @@ -4874,8 +5237,12 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) - return Table(this=table_name, db=db, catalog=catalog, **kwargs) + table = maybe_parse(sql_path, into=Table, dialect=dialect) + if table: + for k, v in kwargs.items(): + table.set(k, v) + + return table def to_column(sql_path: str | Column, **kwargs) -> Column: @@ -4902,6 +5269,7 @@ def alias_( table: bool | t.Sequence[str | Identifier] = False, quoted: t.Optional[bool] = None, dialect: DialectType = None, + copy: bool = True, **opts, ): """Create an Alias expression. @@ -4921,18 +5289,17 @@ def alias_( table: Whether or not to create a table alias, can also be a list of columns. quoted: whether or not to quote the alias dialect: the dialect used to parse the input expression. + copy: Whether or not to copy the expression. **opts: other options to use to parse the input expressions. Returns: Alias: the aliased expression """ - exp = maybe_parse(expression, dialect=dialect, **opts) + exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts) alias = to_identifier(alias, quoted=quoted) if table: table_alias = TableAlias(this=alias) - - exp = exp.copy() if isinstance(expression, Expression) else exp exp.set("alias", table_alias) if not isinstance(table, bool): @@ -4948,13 +5315,17 @@ def alias_( # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls if "alias" in exp.arg_types and not isinstance(exp, Window): - exp = exp.copy() exp.set("alias", alias) return exp return Alias(this=exp, alias=alias) -def subquery(expression, alias=None, dialect=None, **opts): +def subquery( + expression: ExpOrStr, + alias: t.Optional[Identifier | str] = None, + dialect: DialectType = None, + **opts, +) -> Select: """ Build a subquery expression. @@ -4963,14 +5334,14 @@ def subquery(expression, alias=None, dialect=None, **opts): 'SELECT x FROM (SELECT x FROM tbl) AS bar' Args: - expression (str | Expression): the SQL code strings to parse. + expression: the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str | Expression): the alias name to use. - dialect (str): the dialect used to parse the input expression. + alias: the alias name to use. + dialect: the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. Returns: - Select: a new select with the subquery expression included + A new Select instance with the subquery expression included. """ expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) @@ -4988,13 +5359,14 @@ def column( Build a Column. Args: - col: column name - table: table name - db: db name - catalog: catalog name - quoted: whether or not to force quote each part + col: Column name. + table: Table name. + db: Database name. + catalog: Catalog name. + quoted: Whether to force quotes on the column's identifiers. + Returns: - Column: column instance + The new Column instance. """ return Column( this=to_identifier(col, quoted=quoted), @@ -5016,22 +5388,30 @@ def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Ca to: The datatype to cast to. Returns: - A cast node. + The new Cast instance. """ expression = maybe_parse(expression, **opts) return Cast(this=expression, to=DataType.build(to, **opts)) -def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: +def table_( + table: Identifier | str, + db: t.Optional[Identifier | str] = None, + catalog: t.Optional[Identifier | str] = None, + quoted: t.Optional[bool] = None, + alias: t.Optional[Identifier | str] = None, +) -> Table: """Build a Table. Args: - table (str | Expression): column name - db (str | Expression): db name - catalog (str | Expression): catalog name + table: Table name. + db: Database name. + catalog: Catalog name. + quote: Whether to force quotes on the table's identifiers. + alias: Table's alias. Returns: - Table: table instance + The new Table instance. """ return Table( this=to_identifier(table, quoted=quoted), @@ -5160,7 +5540,7 @@ def convert(value: t.Any, copy: bool = False) -> Expression: raise ValueError(f"Cannot convert {value}") -def replace_children(expression, fun, *args, **kwargs): +def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None: """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ @@ -5182,7 +5562,7 @@ def replace_children(expression, fun, *args, **kwargs): expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0) -def column_table_names(expression): +def column_table_names(expression: Expression) -> t.List[str]: """ Return all table names referenced through columns in an expression. @@ -5192,19 +5572,19 @@ def column_table_names(expression): ['c', 'a'] Args: - expression (sqlglot.Expression): expression to find table names + expression: expression to find table names. Returns: - list: A list of unique names + A list of unique names. """ return list(dict.fromkeys(column.table for column in expression.find_all(Column))) -def table_name(table) -> str: +def table_name(table: Table | str) -> str: """Get the full name of a table as a string. Args: - table (exp.Table | str): table expression node or string. + table: table expression node or string. Examples: >>> from sqlglot import exp, parse_one @@ -5220,23 +5600,15 @@ def table_name(table) -> str: if not table: raise ValueError(f"Cannot parse {table}") - return ".".join( - part - for part in ( - table.text("catalog"), - table.text("db"), - table.name, - ) - if part - ) + return ".".join(part for part in (table.text("catalog"), table.text("db"), table.name) if part) -def replace_tables(expression, mapping): +def replace_tables(expression: E, mapping: t.Dict[str, str]) -> E: """Replace all tables in expression according to the mapping. Args: - expression (sqlglot.Expression): expression node to be transformed and replaced. - mapping (Dict[str, str]): mapping of table names. + expression: expression node to be transformed and replaced. + mapping: mapping of table names. Examples: >>> from sqlglot import exp, parse_one @@ -5247,7 +5619,7 @@ def replace_tables(expression, mapping): The mapped expression. """ - def _replace_tables(node): + def _replace_tables(node: Expression) -> Expression: if isinstance(node, Table): new_name = mapping.get(table_name(node)) if new_name: @@ -5260,11 +5632,11 @@ def replace_tables(expression, mapping): return expression.transform(_replace_tables) -def replace_placeholders(expression, *args, **kwargs): +def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: """Replace placeholders in an expression. Args: - expression (sqlglot.Expression): expression node to be transformed and replaced. + expression: expression node to be transformed and replaced. args: positional names that will substitute unnamed placeholders in the given order. kwargs: keyword arguments that will substitute named placeholders. @@ -5280,7 +5652,7 @@ def replace_placeholders(expression, *args, **kwargs): The mapped expression. """ - def _replace_placeholders(node, args, **kwargs): + def _replace_placeholders(node: Expression, args, **kwargs) -> Expression: if isinstance(node, Placeholder): if node.name: new_name = kwargs.get(node.name) @@ -5378,21 +5750,21 @@ def func(name: str, *args, dialect: DialectType = None, **kwargs) -> Func: return function -def true(): +def true() -> Boolean: """ Returns a true Boolean expression. """ return Boolean(this=True) -def false(): +def false() -> Boolean: """ Returns a false Boolean expression. """ return Boolean(this=False) -def null(): +def null() -> Null: """ Returns a Null expression. """ diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d7dcea0..f1ec398 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -31,6 +31,8 @@ class Generator: hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. + raw_start (str): specifies which starting character to use to delimit raw literals. Default: None. + raw_end (str): specifies which ending character to use to delimit raw literals. Default: None. identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. @@ -76,11 +78,12 @@ class Generator: exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", exp.MaterializedProperty: lambda self, e: "MATERIALIZED", exp.NoPrimaryIndexProperty: lambda self, e: "NO PRIMARY INDEX", - exp.OnCommitProperty: lambda self, e: "ON COMMIT PRESERVE ROWS", + exp.OnCommitProperty: lambda self, e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", exp.ReturnsProperty: lambda self, e: self.naked_property(e), exp.SetProperty: lambda self, e: f"{'MULTI' if e.args.get('multi') else ''}SET", + exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", - exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY", + exp.TemporaryProperty: lambda self, e: f"TEMPORARY", exp.TransientProperty: lambda self, e: "TRANSIENT", exp.StabilityProperty: lambda self, e: e.name, exp.VolatileProperty: lambda self, e: "VOLATILE", @@ -133,6 +136,15 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" + # Whether a table is allowed to be renamed with a db + RENAME_TABLE_WITH_DB = True + + # The separator for grouping sets and rollups + GROUPINGS_SEP = "," + + # The string used for creating index on a table + INDEX_ON = "ON" + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -167,7 +179,6 @@ class Generator: PARAMETER_TOKEN = "@" PROPERTIES_LOCATION = { - exp.AfterJournalProperty: exp.Properties.Location.POST_NAME, exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, @@ -196,7 +207,9 @@ class Generator: exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, + exp.Order: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, + exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, exp.Property: exp.Properties.Location.POST_WITH, exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, @@ -204,13 +217,15 @@ class Generator: exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, + exp.Set: exp.Properties.Location.POST_SCHEMA, + exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, exp.SetProperty: exp.Properties.Location.POST_CREATE, exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, - exp.TableFormatProperty: exp.Properties.Location.POST_WITH, exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, exp.TransientProperty: exp.Properties.Location.POST_CREATE, + exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.POST_CREATE, exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, @@ -221,7 +236,7 @@ class Generator: RESERVED_KEYWORDS: t.Set[str] = set() WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With) - UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column) + UNWRAPPED_INTERVAL_VALUES = (exp.Column, exp.Literal, exp.Neg, exp.Paren) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -239,6 +254,8 @@ class Generator: "hex_end", "byte_start", "byte_end", + "raw_start", + "raw_end", "identify", "normalize", "string_escape", @@ -276,6 +293,8 @@ class Generator: hex_end=None, byte_start=None, byte_end=None, + raw_start=None, + raw_end=None, identify=False, normalize=False, string_escape=None, @@ -308,6 +327,8 @@ class Generator: self.hex_end = hex_end self.byte_start = byte_start self.byte_end = byte_end + self.raw_start = raw_start + self.raw_end = raw_end self.identify = identify self.normalize = normalize self.string_escape = string_escape or "'" @@ -399,7 +420,11 @@ class Generator: return sql if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"{comments_sql}{self.sep()}{sql}" + return ( + f"{self.sep()}{comments_sql}{sql}" + if sql[0].isspace() + else f"{comments_sql}{self.sep()}{sql}" + ) return f"{sql} {comments_sql}" @@ -567,7 +592,9 @@ class Generator: ) -> str: this = "" if expression.this is not None: - this = " ALWAYS " if expression.this else " BY DEFAULT " + on_null = "ON NULL " if expression.args.get("on_null") else "" + this = " ALWAYS " if expression.this else f" BY DEFAULT {on_null}" + start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") @@ -578,14 +605,20 @@ class Generator: maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" cycle = expression.args.get("cycle") cycle_sql = "" + if cycle is not None: cycle_sql = f"{' NO' if not cycle else ''} CYCLE" cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql + sequence_opts = "" if start or increment or cycle_sql: sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" sequence_opts = f" ({sequence_opts.strip()})" - return f"GENERATED{this}AS IDENTITY{sequence_opts}" + + expr = self.sql(expression, "expression") + expr = f"({expr})" if expr else "IDENTITY" + + return f"GENERATED{this}AS {expr}{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -596,8 +629,10 @@ class Generator: return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" return f"PRIMARY KEY" - def uniquecolumnconstraint_sql(self, _) -> str: - return "UNIQUE" + def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"UNIQUE{this}" def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() @@ -653,33 +688,9 @@ class Generator: prefix=" ", ) - indexes = expression.args.get("indexes") - if indexes: - indexes_sql: t.List[str] = [] - for index in indexes: - ind_unique = " UNIQUE" if index.args.get("unique") else "" - ind_primary = " PRIMARY" if index.args.get("primary") else "" - ind_amp = " AMP" if index.args.get("amp") else "" - ind_name = f" {index.name}" if index.name else "" - ind_columns = ( - f' ({self.expressions(index, key="columns", flat=True)})' - if index.args.get("columns") - else "" - ) - ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" - - if indexes_sql: - indexes_sql.append(ind_sql) - else: - indexes_sql.append( - f"{ind_sql}{postindex_props_sql}" - if index.args.get("primary") - else f"{postindex_props_sql}{ind_sql}" - ) - - index_sql = "".join(indexes_sql) - else: - index_sql = postindex_props_sql + indexes = self.expressions(expression, key="indexes", indent=False, sep=" ") + indexes = f" {indexes}" if indexes else "" + index_sql = indexes + postindex_props_sql replace = " OR REPLACE" if expression.args.get("replace") else "" unique = " UNIQUE" if expression.args.get("unique") else "" @@ -711,9 +722,23 @@ class Generator: " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" ) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}" + clone = self.sql(expression, "clone") + clone = f" {clone}" if clone else "" + + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties_sql}{expression_sql}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}" return self.prepend_ctes(expression, expression_sql) + def clone_sql(self, expression: exp.Clone) -> str: + this = self.sql(expression, "this") + when = self.sql(expression, "when") + + if when: + kind = self.sql(expression, "kind") + expr = self.sql(expression, "expression") + return f"CLONE {this} {when} ({kind} => {expr})" + + return f"CLONE {this}" + def describe_sql(self, expression: exp.Describe) -> str: return f"DESCRIBE {self.sql(expression, 'this')}" @@ -757,6 +782,17 @@ class Generator: return f"{self.byte_start}{this}{self.byte_end}" return this + def rawstring_sql(self, expression: exp.RawString) -> str: + if self.raw_start: + return f"{self.raw_start}{expression.name}{self.raw_end}" + return self.sql(exp.Literal.string(expression.name.replace("\\", "\\\\"))) + + def datatypesize_sql(self, expression: exp.DataTypeSize) -> str: + this = self.sql(expression, "this") + specifier = self.sql(expression, "expression") + specifier = f" {specifier}" if specifier else "" + return f"{this}{specifier}" + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) @@ -768,7 +804,8 @@ class Generator: nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" if expression.args.get("values") is not None: delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") - values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}" + values = self.expressions(expression, key="values", flat=True) + values = f"{delimiters[0]}{values}{delimiters[1]}" else: nested = f"({interior})" @@ -836,10 +873,17 @@ class Generator: return "" def index_sql(self, expression: exp.Index) -> str: - this = self.sql(expression, "this") + unique = "UNIQUE " if expression.args.get("unique") else "" + primary = "PRIMARY " if expression.args.get("primary") else "" + amp = "AMP " if expression.args.get("amp") else "" + name = f"{expression.name} " if expression.name else "" table = self.sql(expression, "table") - columns = self.sql(expression, "columns") - return f"{this} ON {table} {columns}" + table = f"{self.INDEX_ON} {table} " if table else "" + index = "INDEX " if not table else "" + columns = self.expressions(expression, key="columns", flat=True) + partition_by = self.expressions(expression, key="partition_by", flat=True) + partition_by = f" PARTITION BY {partition_by}" if partition_by else "" + return f"{unique}{primary}{amp}{index}{name}{table}({columns}){partition_by}" def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name @@ -861,8 +905,9 @@ class Generator: output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" return self.sep().join((input_format, output_format)) - def national_sql(self, expression: exp.National) -> str: - return f"N{self.sql(expression, 'this')}" + def national_sql(self, expression: exp.National, prefix: str = "N") -> str: + string = self.sql(exp.Literal.string(expression.name)) + return f"{prefix}{string}" def partition_sql(self, expression: exp.Partition) -> str: return f"PARTITION({self.expressions(expression)})" @@ -955,23 +1000,18 @@ class Generator: def journalproperty_sql(self, expression: exp.JournalProperty) -> str: no = "NO " if expression.args.get("no") else "" + local = expression.args.get("local") + local = f"{local} " if local else "" dual = "DUAL " if expression.args.get("dual") else "" before = "BEFORE " if expression.args.get("before") else "" - return f"{no}{dual}{before}JOURNAL" + after = "AFTER " if expression.args.get("after") else "" + return f"{no}{local}{dual}{before}{after}JOURNAL" def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str: freespace = self.sql(expression, "this") percent = " PERCENT" if expression.args.get("percent") else "" return f"FREESPACE={freespace}{percent}" - def afterjournalproperty_sql(self, expression: exp.AfterJournalProperty) -> str: - no = "NO " if expression.args.get("no") else "" - dual = "DUAL " if expression.args.get("dual") else "" - local = "" - if expression.args.get("local") is not None: - local = "LOCAL " if expression.args.get("local") else "NOT LOCAL " - return f"{no}{dual}{local}AFTER JOURNAL" - def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str: if expression.args.get("default"): property = "DEFAULT" @@ -992,19 +1032,19 @@ class Generator: def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str: default = expression.args.get("default") - min = expression.args.get("min") - if default is not None or min is not None: + minimum = expression.args.get("minimum") + maximum = expression.args.get("maximum") + if default or minimum or maximum: if default: - property = "DEFAULT" - elif min: - property = "MINIMUM" + prop = "DEFAULT" + elif minimum: + prop = "MINIMUM" else: - property = "MAXIMUM" - return f"{property} DATABLOCKSIZE" - else: - units = expression.args.get("units") - units = f" {units}" if units else "" - return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" + prop = "MAXIMUM" + return f"{prop} DATABLOCKSIZE" + units = expression.args.get("units") + units = f" {units}" if units else "" + return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" def blockcompressionproperty_sql(self, expression: exp.BlockCompressionProperty) -> str: autotemp = expression.args.get("autotemp") @@ -1014,16 +1054,16 @@ class Generator: never = expression.args.get("never") if autotemp is not None: - property = f"AUTOTEMP({self.expressions(autotemp)})" + prop = f"AUTOTEMP({self.expressions(autotemp)})" elif always: - property = "ALWAYS" + prop = "ALWAYS" elif default: - property = "DEFAULT" + prop = "DEFAULT" elif manual: - property = "MANUAL" + prop = "MANUAL" elif never: - property = "NEVER" - return f"BLOCKCOMPRESSION={property}" + prop = "NEVER" + return f"BLOCKCOMPRESSION={prop}" def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str: no = expression.args.get("no") @@ -1138,21 +1178,24 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" - hints = self.expressions(expression, key="hints", sep=", ", flat=True) + hints = self.expressions(expression, key="hints", flat=True) hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else "" - laterals = self.expressions(expression, key="laterals", sep="") + pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) + pivots = f" {pivots}" if pivots else "" joins = self.expressions(expression, key="joins", sep="") - pivots = self.expressions(expression, key="pivots", sep="") + laterals = self.expressions(expression, key="laterals", sep="") system_time = expression.args.get("system_time") system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" - return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" + return f"{table}{system_time}{alias}{hints}{pivots}{joins}{laterals}" def tablesample_sql( self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " ) -> str: if self.alias_post_tablesample and expression.this.alias: - this = self.sql(expression.this, "this") + table = expression.this.copy() + table.set("alias", None) + this = self.sql(table) alias = f"{sep}{self.sql(expression.this, 'alias')}" else: this = self.sql(expression, "this") @@ -1177,14 +1220,22 @@ class Generator: return f"{this} {kind} {method}({bucket}{percent}{rows}{size}){seed}{alias}" def pivot_sql(self, expression: exp.Pivot) -> str: - this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + + if expression.this: + this = self.sql(expression, "this") + on = f"{self.seg('ON')} {expressions}" + using = self.expressions(expression, key="using", flat=True) + using = f"{self.seg('USING')} {using}" if using else "" + group = self.sql(expression, "group") + return f"PIVOT {this}{on}{using}{group}" + alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else "" unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" - expressions = self.expressions(expression, key="expressions") field = self.sql(expression, "field") - return f"{this} {direction}({expressions} FOR {field}){alias}" + return f"{direction}({expressions} FOR {field}){alias}" def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" @@ -1218,8 +1269,7 @@ class Generator: return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" def from_sql(self, expression: exp.From) -> str: - expressions = self.expressions(expression, flat=True) - return f"{self.seg('FROM')} {expressions}" + return f"{self.seg('FROM')} {self.sql(expression, 'this')}" def group_sql(self, expression: exp.Group) -> str: group_by = self.op_expressions("GROUP BY", expression) @@ -1242,10 +1292,16 @@ class Generator: rollup_sql = self.expressions(expression, key="rollup", indent=False) rollup_sql = f"{self.seg('ROLLUP')} {self.wrap(rollup_sql)}" if rollup_sql else "" - groupings = csv(grouping_sets, cube_sql, rollup_sql, sep=",") + groupings = csv( + grouping_sets, + cube_sql, + rollup_sql, + self.seg("WITH TOTALS") if expression.args.get("totals") else "", + sep=self.GROUPINGS_SEP, + ) if expression.args.get("expressions") and groupings: - group_by = f"{group_by}," + group_by = f"{group_by}{self.GROUPINGS_SEP}" return f"{group_by}{groupings}" @@ -1254,18 +1310,16 @@ class Generator: return f"{self.seg('HAVING')}{self.sep()}{this}" def join_sql(self, expression: exp.Join) -> str: - op_sql = self.seg( - " ".join( - op - for op in ( - "NATURAL" if expression.args.get("natural") else None, - expression.side, - expression.kind, - expression.hint if self.JOIN_HINTS else None, - "JOIN", - ) - if op + op_sql = " ".join( + op + for op in ( + "NATURAL" if expression.args.get("natural") else None, + "GLOBAL" if expression.args.get("global") else None, + expression.side, + expression.kind, + expression.hint if self.JOIN_HINTS else None, ) + if op ) on_sql = self.sql(expression, "on") using = expression.args.get("using") @@ -1273,6 +1327,8 @@ class Generator: if not on_sql and using: on_sql = csv(*(self.sql(column) for column in using)) + this_sql = self.sql(expression, "this") + if on_sql: on_sql = self.indent(on_sql, skip_first=True) space = self.seg(" " * self.pad) if self.pretty else " " @@ -1280,10 +1336,11 @@ class Generator: on_sql = f"{space}USING ({on_sql})" else: on_sql = f"{space}ON {on_sql}" + elif not op_sql: + return f", {this_sql}" - expression_sql = self.sql(expression, "expression") - this_sql = self.sql(expression, "this") - return f"{expression_sql}{op_sql} {this_sql}{on_sql}" + op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" + return f"{self.seg(op_sql)} {this_sql}{on_sql}" def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: args = self.expressions(expression, flat=True) @@ -1336,12 +1393,22 @@ class Generator: return f"PRAGMA {self.sql(expression, 'this')}" def lock_sql(self, expression: exp.Lock) -> str: - if self.LOCKING_READS_SUPPORTED: - lock_type = "UPDATE" if expression.args["update"] else "SHARE" - return self.seg(f"FOR {lock_type}") + if not self.LOCKING_READS_SUPPORTED: + self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") + return "" - self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") - return "" + lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE" + expressions = self.expressions(expression, flat=True) + expressions = f" OF {expressions}" if expressions else "" + wait = expression.args.get("wait") + + if wait is not None: + if isinstance(wait, exp.Literal): + wait = f" WAIT {self.sql(wait)}" + else: + wait = " NOWAIT" if wait else " SKIP LOCKED" + + return f"{lock_type}{expressions}{wait or ''}" def literal_sql(self, expression: exp.Literal) -> str: text = expression.this or "" @@ -1460,27 +1527,33 @@ class Generator: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("joins") or []], + *[self.sql(join) for join in expression.args.get("joins") or []], self.sql(expression, "match"), - *[self.sql(sql) for sql in expression.args.get("laterals") or []], + *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), - self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) - if expression.args.get("windows") - else "", - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), + *self.after_having_modifiers(expression), self.sql(expression, "order"), self.sql(expression, "offset") if fetch else self.sql(limit), self.sql(limit) if fetch else self.sql(expression, "offset"), - self.sql(expression, "lock"), - self.sql(expression, "sample"), + *self.after_limit_modifiers(expression), sep="", ) + def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]: + return [ + self.sql(expression, "qualify"), + self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True) + if expression.args.get("windows") + else "", + ] + + def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: + locks = self.expressions(expression, key="locks", sep=" ") + locks = f" {locks}" if locks else "" + return [locks, self.sql(expression, "sample")] + def select_sql(self, expression: exp.Select) -> str: hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") @@ -1529,13 +1602,10 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" - sql = self.query_modifiers( - expression, - self.wrap(expression), - alias, - self.expressions(expression, key="pivots", sep=" "), - ) + pivots = self.expressions(expression, key="pivots", sep=" ", flat=True) + pivots = f" {pivots}" if pivots else "" + sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots) return self.prepend_ctes(expression, sql) def qualify_sql(self, expression: exp.Qualify) -> str: @@ -1712,10 +1782,6 @@ class Generator: options = f" {options}" if options else "" return f"PRIMARY KEY ({expressions}){options}" - def unique_sql(self, expression: exp.Unique) -> str: - columns = self.expressions(expression, key="expressions") - return f"UNIQUE ({columns})" - def if_sql(self, expression: exp.If) -> str: return self.case_sql( exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) @@ -1745,6 +1811,26 @@ class Generator: encoding = f" ENCODING {encoding}" if encoding else "" return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: + this = self.sql(expression, "this") + kind = self.sql(expression, "kind") + path = self.sql(expression, "path") + path = f" {path}" if path else "" + as_json = " AS JSON" if expression.args.get("as_json") else "" + return f"{this} {kind}{path}{as_json}" + + def openjson_sql(self, expression: exp.OpenJSON) -> str: + this = self.sql(expression, "this") + path = self.sql(expression, "path") + path = f", {path}" if path else "" + expressions = self.expressions(expression) + with_ = ( + f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}" + if expressions + else "" + ) + return f"OPENJSON({this}{path}){with_}" + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") @@ -1773,7 +1859,7 @@ class Generator: if self.SINGLE_STRING_INTERVAL: this = expression.this.name if expression.this else "" - return f"INTERVAL '{this}{unit}'" + return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}" this = self.sql(expression, "this") if this: @@ -1883,6 +1969,28 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" + def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str: + this = self.sql(expression, "this") + delete = " DELETE" if expression.args.get("delete") else "" + recompress = self.sql(expression, "recompress") + recompress = f" RECOMPRESS {recompress}" if recompress else "" + to_disk = self.sql(expression, "to_disk") + to_disk = f" TO DISK {to_disk}" if to_disk else "" + to_volume = self.sql(expression, "to_volume") + to_volume = f" TO VOLUME {to_volume}" if to_volume else "" + return f"{this}{delete}{recompress}{to_disk}{to_volume}" + + def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str: + where = self.sql(expression, "where") + group = self.sql(expression, "group") + aggregates = self.expressions(expression, key="aggregates") + aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else "" + + if not (where or group or aggregates) and len(expression.expressions) == 1: + return f"TTL {self.expressions(expression, flat=True)}" + + return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}" + def transaction_sql(self, expression: exp.Transaction) -> str: return "BEGIN" @@ -1919,6 +2027,11 @@ class Generator: return f"ALTER COLUMN {this} DROP DEFAULT" def renametable_sql(self, expression: exp.RenameTable) -> str: + if not self.RENAME_TABLE_WITH_DB: + # Remove db from tables + expression = expression.transform( + lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n + ) this = self.sql(expression, "this") return f"RENAME TO {this}" @@ -2208,3 +2321,12 @@ class Generator: self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function") return self.sql(exp.cast(expression.this, "text")) + + +def cached_generator( + cache: t.Optional[t.Dict[int, str]] = None +) -> t.Callable[[exp.Expression], str]: + """Returns a cached generator.""" + cache = {} if cache is None else cache + generator = Generator(normalize=True, identify="safe") + return lambda e: generator.generate(e, cache) diff --git a/sqlglot/helper.py b/sqlglot/helper.py index b2f0520..4215fee 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -9,14 +9,14 @@ from collections.abc import Collection from contextlib import contextmanager from copy import copy from enum import Enum +from itertools import count if t.TYPE_CHECKING: from sqlglot import exp + from sqlglot._typing import E, T + from sqlglot.dialects.dialect import DialectType from sqlglot.expressions import Expression - T = t.TypeVar("T") - E = t.TypeVar("E", bound=Expression) - CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") PYTHON_VERSION = sys.version_info[:2] logger = logging.getLogger("sqlglot") @@ -25,7 +25,7 @@ logger = logging.getLogger("sqlglot") class AutoName(Enum): """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name.""" - def _generate_next_value_(name, _start, _count, _last_values): # type: ignore + def _generate_next_value_(name, _start, _count, _last_values): return name @@ -92,7 +92,7 @@ def ensure_collection(value): ) -def csv(*args, sep: str = ", ") -> str: +def csv(*args: str, sep: str = ", ") -> str: """ Formats any number of string arguments as CSV. @@ -304,9 +304,18 @@ def find_new_name(taken: t.Collection[str], base: str) -> str: return new +def name_sequence(prefix: str) -> t.Callable[[], str]: + """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" + sequence = count() + return lambda: f"{prefix}{next(sequence)}" + + def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: """Returns a dictionary created from an object's attributes.""" - return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs} + return { + **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, + **kwargs, + } def split_num_words( @@ -381,15 +390,6 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: yield value -def count_params(function: t.Callable) -> int: - """ - Returns the number of formal parameters expected by a function, without counting "self" - and "cls", in case of instance and class methods, respectively. - """ - count = function.__code__.co_argcount - return count - 1 if inspect.ismethod(function) else count - - def dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. @@ -430,12 +430,23 @@ def first(it: t.Iterable[T]) -> T: return next(i for i in it) -def should_identify(text: str, identify: str | bool) -> bool: +def case_sensitive(text: str, dialect: DialectType) -> bool: + """Checks if text contains any case sensitive characters depending on dialect.""" + from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE + + unsafe = str.islower if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else str.isupper + return any(unsafe(char) for char in text) + + +def should_identify(text: str, identify: str | bool, dialect: DialectType = None) -> bool: """Checks if text should be identified given an identify option. Args: text: the text to check. - identify: "always" | True - always returns true, "safe" - true if no upper case + identify: + "always" or `True`: always returns true. + "safe": true if there is no uppercase or lowercase character in `text`, depending on `dialect`. + dialect: the dialect to use in order to decide whether a text should be identified. Returns: Whether or not a string should be identified. @@ -443,5 +454,5 @@ def should_identify(text: str, identify: str | bool) -> bool: if identify is True or identify == "always": return True if identify == "safe": - return not any(char.isupper() for char in text) + return not case_sensitive(text, dialect) return False diff --git a/sqlglot/lineage.py b/sqlglot/lineage.py index 0eac870..04a8073 100644 --- a/sqlglot/lineage.py +++ b/sqlglot/lineage.py @@ -5,10 +5,8 @@ import typing as t from dataclasses import dataclass, field from sqlglot import Schema, exp, maybe_parse -from sqlglot.optimizer import Scope, build_scope, optimize -from sqlglot.optimizer.expand_laterals import expand_laterals -from sqlglot.optimizer.qualify_columns import qualify_columns -from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.errors import SqlglotError +from sqlglot.optimizer import Scope, build_scope, qualify if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType @@ -40,8 +38,8 @@ def lineage( sql: str | exp.Expression, schema: t.Optional[t.Dict | Schema] = None, sources: t.Optional[t.Dict[str, str | exp.Subqueryable]] = None, - rules: t.Sequence[t.Callable] = (qualify_tables, qualify_columns, expand_laterals), dialect: DialectType = None, + **kwargs, ) -> Node: """Build the lineage graph for a column of a SQL query. @@ -50,8 +48,8 @@ def lineage( sql: The SQL string or expression. schema: The schema of tables. sources: A mapping of queries which will be used to continue building lineage. - rules: Optimizer rules to apply, by default only qualifying tables and columns. dialect: The dialect of input SQL. + **kwargs: Qualification optimizer kwargs. Returns: A lineage node. @@ -68,8 +66,17 @@ def lineage( }, ) - optimized = optimize(expression, schema=schema, rules=rules) - scope = build_scope(optimized) + qualified = qualify.qualify( + expression, + dialect=dialect, + schema=schema, + **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore + ) + + scope = build_scope(qualified) + + if not scope: + raise SqlglotError("Cannot build lineage, sql must be SELECT") def to_node( column_name: str, @@ -109,10 +116,7 @@ def lineage( # a version that has only the column we care about. # "x", SELECT x, y FROM foo # => "x", SELECT x FROM foo - source = optimize( - scope.expression.select(select, append=False), schema=schema, rules=rules - ) - select = source.selects[0] + source = t.cast(exp.Expression, scope.expression.select(select, append=False)) else: source = scope.expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index ef929ac..da2fce8 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -3,10 +3,9 @@ from __future__ import annotations import itertools from sqlglot import exp -from sqlglot.helper import should_identify -def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression: +def canonicalize(expression: exp.Expression) -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr Args: expression: The expression to canonicalize. - identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize, identify=identify) + exp.replace_children(expression, canonicalize) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) expression = ensure_bool_predicates(expression) - if isinstance(expression, exp.Identifier): - if should_identify(expression.this, identify): - expression.set("quoted", True) - return expression diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py index 7b862c6..6f1865c 100644 --- a/sqlglot/optimizer/eliminate_ctes.py +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -19,24 +19,25 @@ def eliminate_ctes(expression): """ root = build_scope(expression) - ref_count = root.ref_count() - - # Traverse the scope tree in reverse so we can remove chains of unused CTEs - for scope in reversed(list(root.traverse())): - if scope.is_cte: - count = ref_count[id(scope)] - if count <= 0: - cte_node = scope.expression.parent - with_node = cte_node.parent - cte_node.pop() - - # Pop the entire WITH clause if this is the last CTE - if len(with_node.expressions) <= 0: - with_node.pop() - - # Decrement the ref count for all sources this CTE selects from - for _, source in scope.selected_sources.values(): - if isinstance(source, Scope): - ref_count[id(source)] -= 1 + if root: + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index a39fe96..84f50e9 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -16,9 +16,9 @@ def eliminate_subqueries(expression): 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' This also deduplicates common subqueries: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") >>> eliminate_subqueries(expression).sql() - 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z' + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' Args: expression (sqlglot.Expression): expression @@ -32,6 +32,9 @@ def eliminate_subqueries(expression): root = build_scope(expression) + if not root: + return expression + # Map of alias->Scope|Table # These are all aliases that are already used in the expression. # We don't want to create new CTEs that conflict with these names. @@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken): # Try to maintain the selections expressions = scope.selects selects = [ - exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) + exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) for e in expressions if e.alias_or_name ] @@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken): if len(selects) != len(expressions): selects = ["*"] - scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))) + scope.expression.replace( + exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False)) + ) if not duplicate_cte_alias: existing_ctes[scope.expression] = alias @@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + # This ensures we don't drop the "pivot" arg from a pivoted subquery + if scope.parent.pivots: + return None + parent = scope.expression.parent name, cte = _new_cte(scope, existing_ctes, taken) @@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken): for child_scope in scope.parent.traverse(): for table, source in child_scope.selected_sources.values(): if source is scope: - new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False) table.replace(new_table) return cte diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py deleted file mode 100644 index 5b2f706..0000000 --- a/sqlglot/optimizer/expand_laterals.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp - - -def expand_laterals(expression: exp.Expression) -> exp.Expression: - """ - Expand lateral column alias references. - - This assumes `qualify_columns` as already run. - - Example: - >>> import sqlglot - >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x" - >>> expression = sqlglot.parse_one(sql) - >>> expand_laterals(expression).sql() - 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x' - - Args: - expression: expression to optimize - Returns: - optimized expression - """ - for select in expression.find_all(exp.Select): - alias_to_expression: t.Dict[str, exp.Expression] = {} - for projection in select.expressions: - for column in projection.find_all(exp.Column): - if not column.table and column.name in alias_to_expression: - column.replace(alias_to_expression[column.name].copy()) - if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = projection.this - return expression diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py deleted file mode 100644 index 86f0c2d..0000000 --- a/sqlglot/optimizer/expand_multi_table_selects.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlglot import exp - - -def expand_multi_table_selects(expression): - """ - Replace multiple FROM expressions with JOINs. - - Example: - >>> from sqlglot import parse_one - >>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql() - 'SELECT * FROM x CROSS JOIN y' - """ - for from_ in expression.find_all(exp.From): - parent = from_.parent - - for query in from_.expressions[1:]: - parent.join( - query, - join_type="CROSS", - copy=False, - ) - from_.expressions.remove(query) - - return expression diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 5d78353..5dfa4aa 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None): source.replace( exp.select("*") .from_( - alias(source.copy(), source.name or source.alias, table=True), + alias(source, source.name or source.alias, table=True), copy=False, ) .subquery(source.alias, copy=False) diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py deleted file mode 100644 index fae1726..0000000 --- a/sqlglot/optimizer/lower_identities.py +++ /dev/null @@ -1,88 +0,0 @@ -from sqlglot import exp - - -def lower_identities(expression): - """ - Convert all unquoted identifiers to lower case. - - Assuming the schema is all lower case, this essentially makes identifiers case-insensitive. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') - >>> lower_identities(expression).sql() - 'SELECT bar.a AS A FROM "Foo".bar' - - Args: - expression (sqlglot.Expression): expression to quote - Returns: - sqlglot.Expression: quoted expression - """ - # We need to leave the output aliases unchanged, so the selects need special handling - _lower_selects(expression) - - # These clauses can reference output aliases and also need special handling - _lower_order(expression) - _lower_having(expression) - - # We've already handled these args, so don't traverse into them - traversed = {"expressions", "order", "having"} - - if isinstance(expression, exp.Subquery): - # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1 - lower_identities(expression.this) - traversed |= {"this"} - - if isinstance(expression, exp.Union): - # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X - lower_identities(expression.left) - lower_identities(expression.right) - traversed |= {"this", "expression"} - - for k, v in expression.iter_expressions(): - if k in traversed: - continue - v.transform(_lower, copy=False) - - return expression - - -def _lower_selects(expression): - for e in expression.expressions: - # Leave output aliases as-is - e.unalias().transform(_lower, copy=False) - - -def _lower_order(expression): - order = expression.args.get("order") - - if not order: - return - - output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} - - for ordered in order.expressions: - # Don't lower references to output aliases - if not ( - isinstance(ordered.this, exp.Column) - and not ordered.this.table - and ordered.this.name in output_aliases - ): - ordered.transform(_lower, copy=False) - - -def _lower_having(expression): - having = expression.args.get("having") - - if not having: - return - - # Don't lower references to output aliases - for agg in having.find_all(exp.AggFunc): - agg.transform(_lower, copy=False) - - -def _lower(node): - if isinstance(node, exp.Identifier) and not node.quoted: - node.set("this", node.this.lower()) - return node diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index c3467b2..f9c9664 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False): Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression).sql() - 'SELECT x.a FROM x JOIN y' + 'SELECT x.a FROM x CROSS JOIN y' If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression, leave_tables_isolated=True).sql() - 'SELECT a FROM (SELECT x.a FROM x) JOIN y' + 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y' Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html @@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): inner_from = inner_scope.expression.args.get("from") if not inner_from: return False - inner_from_table = inner_from.expressions[0].alias_or_name + inner_from_table = inner_from.alias_or_name inner_projections = {s.alias_or_name: s for s in inner_scope.selects} return any( col.table != inner_from_table @@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") + and not outer_scope.pivots and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) and not ( @@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): elif isinstance(source, exp.Table) and source.alias: source.set("alias", new_alias) elif isinstance(source, exp.Table): - source.replace(exp.alias_(source.copy(), new_alias)) + source.replace(exp.alias_(source, new_alias)) for column in inner_scope.source_columns(conflict): column.set("table", exp.to_identifier(new_name)) @@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): node_to_replace (exp.Subquery|exp.Table) alias (str) """ - new_subquery = inner_scope.expression.args.get("from").expressions[0] + new_subquery = inner_scope.expression.args["from"].this node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: tables = join_hint.find_all(exp.Table) @@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): # Merge predicates from an outer join to the ON clause # if it only has columns that are already joined from_ = expression.args.get("from") - sources = {table.alias_or_name for table in from_.expressions} if from_ else {} + sources = {from_.alias_or_name} if from_ else {} for join in expression.args["joins"]: source = join.alias_or_name diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index b013312..1db094e 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,12 +1,12 @@ from __future__ import annotations import logging -import typing as t from sqlglot import exp from sqlglot.errors import OptimizeError +from sqlglot.generator import cached_generator from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, uniq_sort +from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - cache: t.Dict[int, str] = {} + generate = cached_generator() for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): continue + root = node is expression + original = node.copy() + node.transform(rewrite_between, copy=False) distance = normalization_distance(node, dnf=dnf) if distance > max_distance: @@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = ) return expression - root = node is expression - original = node.copy() try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) ) except OptimizeError as e: logger.info(e) @@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, cache=None): +def distributive_law(expression, dnf, max_distance, generate): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, cache) - return _distribute(b, a, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func, generate) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, generate) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) return expression -def _distribute(a, b, from_func, to_func, cache): +def _distribute(a, b, from_func, to_func, generate): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), cache), - uniq_sort(flatten(from_func(c, b.right)), cache), + uniq_sort(flatten(from_func(c, b.left)), generate), + uniq_sort(flatten(from_func(c, b.right)), generate), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), cache), - uniq_sort(flatten(from_func(a, b.right)), cache), + uniq_sort(flatten(from_func(a, b.left)), generate), + uniq_sort(flatten(from_func(a, b.right)), generate), copy=False, ) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py new file mode 100644 index 0000000..1e5c104 --- /dev/null +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -0,0 +1,36 @@ +from sqlglot import exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType + + +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: + """ + Normalize all unquoted identifiers to either lower or upper case, depending on + the dialect. This essentially makes those identifiers case-insensitive. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> normalize_identifiers(expression).sql() + 'SELECT bar.a AS a FROM "Foo".bar' + + Args: + expression: The expression to transform. + dialect: The dialect to use in order to decide how to normalize identifiers. + + Returns: + The transformed expression. + """ + return expression.transform(_normalize, dialect, copy=False) + + +def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression: + if isinstance(node, exp.Identifier) and not node.quoted: + node.set( + "this", + node.this.upper() + if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE + else node.this.lower(), + ) + + return node diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 8589657..43436cb 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,6 +1,8 @@ from sqlglot import exp from sqlglot.helper import tsort +JOIN_ATTRS = ("on", "side", "kind", "using", "natural") + def optimize_joins(expression): """ @@ -45,7 +47,7 @@ def reorder_joins(expression): Reorder joins by topological sort order based on predicate references. """ for from_ in expression.find_all(exp.From): - head = from_.expressions[0] + head = from_.this parent = from_.parent joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} dag = {head.alias_or_name: []} @@ -65,6 +67,9 @@ def normalize(expression): Remove INNER and OUTER from joins as they are optional. """ for join in expression.find_all(exp.Join): + if not any(join.args.get(k) for k in JOIN_ATTRS): + join.set("kind", "CROSS") + if join.kind != "CROSS": join.set("kind", None) return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c165ffe..dbe33a2 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries -from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects -from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns -from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.qualify import qualify +from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema RULES = ( - lower_identities, - qualify_tables, - isolate_table_selects, - qualify_columns, + qualify, pushdown_projections, - validate_qualify_columns, normalize, unnest_subqueries, - expand_multi_table_selects, pushdown_predicates, optimize_joins, eliminate_subqueries, merge_subqueries, eliminate_joins, eliminate_ctes, + quote_identifiers, annotate_types, canonicalize, simplify, @@ -54,7 +47,7 @@ def optimize( dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, -): +) -> exp.Expression: """ Rewrite a sqlglot AST into an optimized form. @@ -72,14 +65,23 @@ def optimize( dialect: The dialect to parse the sql string. rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. - Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know - what you're doing! + Do not remove `qualify` from the sequence of rules unless you know what you're doing! **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. + Returns: - sqlglot.Expression: optimized expression + The optimized expression. """ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} + possible_kwargs = { + "db": db, + "catalog": catalog, + "schema": schema, + "dialect": dialect, + "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, # this happens in canonicalize + **kwargs, + } + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` @@ -88,4 +90,5 @@ def optimize( param: possible_kwargs[param] for param in rule_params if param in possible_kwargs } expression = rule(expression, **rule_kwargs) - return expression + + return t.cast(exp.Expression, expression) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index ba5c8b5..96dda33 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -21,26 +21,28 @@ def pushdown_predicates(expression): sqlglot.Expression: optimized expression """ root = build_scope(expression) - scope_ref_count = root.ref_count() - - for scope in reversed(list(root.traverse())): - select = scope.expression - where = select.args.get("where") - if where: - selected_sources = scope.selected_sources - # a right join can only push down to itself and not the source FROM table - for k, (node, source) in selected_sources.items(): - parent = node.find_ancestor(exp.Join, exp.From) - if isinstance(parent, exp.Join) and parent.side == "RIGHT": - selected_sources = {k: (node, source)} - break - pushdown(where.this, selected_sources, scope_ref_count) - - # joins should only pushdown into itself, not to other joins - # so we limit the selected sources to only itself - for join in select.args.get("joins") or []: - name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) + + if root: + scope_ref_count = root.ref_count() + + for scope in reversed(list(root.traverse())): + select = scope.expression + where = select.args.get("where") + if where: + selected_sources = scope.selected_sources + # a right join can only push down to itself and not the source FROM table + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join) and parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + pushdown(where.this, selected_sources, scope_ref_count) + + # joins should only pushdown into itself, not to other joins + # so we limit the selected sources to only itself + for join in select.args.get("joins") or []: + name = join.this.alias_or_name + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 2e51117..be3ddb2 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) for scope in reversed(traverse_scope(expression)): parent_selections = referenced_columns.get(scope, {SELECT_ALL}) - if scope.expression.args.get("distinct"): - # We can't remove columns SELECT DISTINCT nor UNION DISTINCT + if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots: + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if + # we select from a pivoted source in the parent scope. parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): @@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema): for name in sorted(parent_selections): if name not in names: - new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name)) + new_selections.append( + alias(exp.column(name, table=resolver.get_table(name)), name, copy=False) + ) # If there are no remaining selections, just select a single constant if not new_selections: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py new file mode 100644 index 0000000..5fdbde8 --- /dev/null +++ b/sqlglot/optimizer/qualify.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.dialects.dialect import DialectType +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import ( + qualify_columns as qualify_columns_func, + quote_identifiers as quote_identifiers_func, + validate_qualify_columns as validate_qualify_columns_func, +) +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.schema import Schema, ensure_schema + + +def qualify( + expression: exp.Expression, + dialect: DialectType = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[dict | Schema] = None, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, + isolate_tables: bool = False, + qualify_columns: bool = True, + validate_qualify_columns: bool = True, + quote_identifiers: bool = True, + identify: bool = True, +) -> exp.Expression: + """ + Rewrite sqlglot AST to have normalized and qualified tables and columns. + + This step is necessary for all further SQLGlot optimizations. + + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify(expression, schema=schema).sql() + 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"' + + Args: + expression: Expression to qualify. + db: Default database name for tables. + catalog: Default catalog name for tables. + schema: Schema to infer column names and types. + expand_alias_refs: Whether or not to expand references to aliases. + infer_schema: Whether or not to infer the schema if missing. + isolate_tables: Whether or not to isolate table selects. + qualify_columns: Whether or not to qualify columns. + validate_qualify_columns: Whether or not to validate columns. + quote_identifiers: Whether or not to run the quote_identifiers step. + This step is necessary to ensure correctness for case sensitive queries. + But this flag is provided in case this step is performed at a later time. + identify: If True, quote all identifiers, else only necessary ones. + + Returns: + The qualified expression. + """ + schema = ensure_schema(schema, dialect=dialect) + expression = normalize_identifiers(expression, dialect=dialect) + expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) + + if isolate_tables: + expression = isolate_table_selects(expression, schema=schema) + + if qualify_columns: + expression = qualify_columns_func( + expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema + ) + + if quote_identifiers: + expression = quote_identifiers_func(expression, dialect=dialect, identify=identify) + + if validate_qualify_columns: + validate_qualify_columns_func(expression) + + return expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 6ac39f0..4a31171 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,14 +1,23 @@ +from __future__ import annotations + import itertools import typing as t from sqlglot import alias, exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError -from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.helper import case_sensitive, seq_get +from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.schema import Schema, ensure_schema -def qualify_columns(expression, schema, expand_laterals=True): +def qualify_columns( + expression: exp.Expression, + schema: dict | Schema, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, +) -> exp.Expression: """ Rewrite sqlglot AST to have fully qualified columns. @@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True): 'SELECT tbl.col AS col FROM tbl' Args: - expression (sqlglot.Expression): expression to qualify - schema (dict|sqlglot.optimizer.Schema): Database schema + expression: expression to qualify + schema: Database schema + expand_alias_refs: whether or not to expand references to aliases + infer_schema: whether or not to infer the schema if missing Returns: sqlglot.Expression: qualified expression """ schema = ensure_schema(schema) - - if not schema.mapping and expand_laterals: - expression = _expand_laterals(expression) + infer_schema = schema.empty if infer_schema is None else infer_schema for scope in traverse_scope(expression): - resolver = Resolver(scope, schema) + resolver = Resolver(scope, schema, infer_schema=infer_schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) using_column_tables = _expand_using(scope, resolver) + + if schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + _qualify_columns(scope, resolver) + + if not schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) - _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) - if schema.mapping and expand_laterals: - expression = _expand_laterals(expression) - return expression @@ -55,9 +68,11 @@ def validate_qualify_columns(expression): for scope in traverse_scope(expression): if isinstance(scope.expression, exp.Select): unqualified_columns.extend(scope.unqualified_columns) - if scope.external_columns and not scope.is_correlated_subquery: + if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: column = scope.external_columns[0] - raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'") + raise OptimizeError( + f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" + ) if unqualified_columns: raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") @@ -142,52 +157,48 @@ def _expand_using(scope, resolver): # Ensure selects keep their output name if isinstance(column.parent, exp.Select): - replacement = exp.alias_(replacement, alias=column.name) + replacement = alias(replacement, alias=column.name, copy=False) scope.replace(column, replacement) return column_tables -def _expand_alias_refs(scope, resolver): - selects = {} - - # Replace references to select aliases - def transform(node, source_first=True): - if isinstance(node, exp.Column) and not node.table: - table = resolver.get_table(node.name) +def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: + expression = scope.expression - # Source columns get priority over select aliases - if source_first and table: - node.set("table", table) - return node + if not isinstance(expression, exp.Select): + return - if not selects: - for s in scope.selects: - selects[s.alias_or_name] = s - select = selects.get(node.name) + alias_to_expression: t.Dict[str, exp.Expression] = {} - if select: - scope.clear_cache() - if isinstance(select, exp.Alias): - select = select.this - return select.copy() + def replace_columns( + node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False + ): + if not node: + return - node.set("table", table) - elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable): - exp.replace_children(node, transform, source_first) + for column, *_ in walk_in_scope(node): + if not isinstance(column, exp.Column): + continue + table = resolver.get_table(column.name) if resolve_agg and not column.table else None + if table and column.find_ancestor(exp.AggFunc): + column.set("table", table) + elif expand and not column.table and column.name in alias_to_expression: + column.replace(alias_to_expression[column.name].copy()) - return node + for projection in scope.selects: + replace_columns(projection) - for select in scope.expression.selects: - transform(select) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this - for modifier, source_first in ( - ("where", True), - ("group", True), - ("having", False), - ): - transform(scope.expression.args.get(modifier), source_first=source_first) + replace_columns(expression.args.get("where")) + replace_columns(expression.args.get("group")) + replace_columns(expression.args.get("having"), resolve_agg=True) + replace_columns(expression.args.get("qualify"), resolve_agg=True) + replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) + scope.clear_cache() def _expand_group_by(scope, resolver): @@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver): raise OptimizeError(f"Unknown column: {column_name}") if not column_table: + if scope.pivots and not column.find_ancestor(exp.Pivot): + # If the column is under the Pivot expression, we need to qualify it + # using the name of the pivoted source instead of the pivot's alias + column.set("table", exp.to_identifier(scope.pivots[0].alias)) + continue + column_table = resolver.get_table(column_name) # column_table can be a '' because bigquery unnest has no table alias @@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver): if column_table: column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) - columns_missing_from_scope = [] - - # Determine whether each reference in the order by clause is to a column or an alias. - order = scope.expression.args.get("order") - - if order: - for ordered in order.expressions: - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - # Determine whether each reference in the having clause is to a column or an alias. - having = scope.expression.args.get("having") - - if having: - for column in having.find_all(exp.Column): - if ( - not column.table - and column.find_ancestor(exp.AggFunc) - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - for column in columns_missing_from_scope: - column_table = resolver.get_table(column.name) - - if column_table: - column.set("table", column_table) + for pivot in scope.pivots: + for column in pivot.find_all(exp.Column): + if not column.table and column.name in resolver.all_columns: + column_table = resolver.get_table(column.name) + if column_table: + column.set("table", column_table) def _expand_stars(scope, resolver, using_column_tables): @@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables): replace_columns = {} coalesced_columns = set() + # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future + pivot_columns = None + pivot_output_columns = None + pivot = seq_get(scope.pivots, 0) + + has_pivoted_source = pivot and not pivot.args.get("unpivot") + if has_pivoted_source: + pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) + + pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + for expression in scope.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) @@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table, only_visible=True) if columns and "*" not in columns: + if has_pivoted_source: + implicit_columns = [col for col in columns if col not in pivot_columns] + new_selections.extend( + exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + for name in implicit_columns + pivot_output_columns + ) + continue + table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: @@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables): coalesce = [exp.column(name, table=table) for table in tables] new_selections.append( - exp.alias_( - exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + alias( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), + alias=name, + copy=False, ) ) elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table) - new_selections.append(alias(column, alias_) if alias_ != name else column) + column = exp.column(name, table=table) + new_selections.append( + alias(column, alias_, copy=False) if alias_ != name else column + ) else: return + scope.expression.set("expressions", new_selections) @@ -388,9 +406,6 @@ def _qualify_outputs(scope): selection = alias( selection, alias=selection.output_name or f"_col_{i}", - quoted=True - if isinstance(selection, exp.Column) and selection.this.quoted - else None, ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) @@ -400,6 +415,23 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) +def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=False) + + class Resolver: """ Helper for resolving columns. @@ -407,12 +439,13 @@ class Resolver: This is a class so we can lazily load some things and easily share them across functions. """ - def __init__(self, scope, schema): + def __init__(self, scope, schema, infer_schema: bool = True): self.scope = scope self.schema = schema self._source_columns = None - self._unambiguous_columns = None + self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None self._all_columns = None + self._infer_schema = infer_schema def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: """ @@ -430,7 +463,7 @@ class Resolver: table_name = self._unambiguous_columns.get(column_name) - if not table_name: + if not table_name and self._infer_schema: sources_without_schema = tuple( source for source, columns in self._get_all_source_columns().items() @@ -450,11 +483,9 @@ class Resolver: node_alias = node.args.get("alias") if node_alias: - return node_alias.this + return exp.to_identifier(node_alias.this) - return exp.to_identifier( - table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None - ) + return exp.to_identifier(table_name) @property def all_columns(self): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 1b451a6..fcc5f26 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,11 +1,19 @@ import itertools +import typing as t from sqlglot import alias, exp -from sqlglot.helper import csv_reader +from sqlglot._typing import E +from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import Schema -def qualify_tables(expression, db=None, catalog=None, schema=None): +def qualify_tables( + expression: E, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[Schema] = None, +) -> E: """ Rewrite sqlglot AST to have fully qualified tables. Additionally, this replaces "join constructs" (*) by equivalent SELECT * subqueries. @@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0' Args: - expression (sqlglot.Expression): expression to qualify - db (str): Database name - catalog (str): Catalog name + expression: Expression to qualify + db: Database name + catalog: Catalog name schema: A schema to populate Returns: - sqlglot.Expression: qualified expression + The qualified expression. (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html """ - sequence = itertools.count() - - next_name = lambda: f"_q_{next(sequence)}" + next_alias_name = name_sequence("_q_") for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): @@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) if not derived_table.args.get("alias"): - alias_ = f"_q_{next(sequence)}" + alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) + pivots = derived_table.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) + for name, source in scope.sources.items(): if isinstance(source, exp.Table): if isinstance(source.this, exp.Identifier): @@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): if not source.alias: source = source.replace( alias( - source.copy(), - name if name else next_name(), + source, + name or source.name or next_alias_name(), + copy=True, table=True, ) ) + pivots = source.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set( + "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())) + ) + if schema and isinstance(source.this, exp.ReadCSV): with csv_reader(source.this) as reader: header = next(reader) @@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name()) udtf.set("alias", table_alias) if not table_alias.name: - table_alias.set("this", next_name()) + table_alias.set("this", next_alias_name()) if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e00b3c9..9ffb4d6 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +import typing as t from collections import defaultdict from enum import Enum, auto @@ -83,6 +84,7 @@ class Scope: self._columns = None self._external_columns = None self._join_hints = None + self._pivots = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -261,12 +263,14 @@ class Scope: self._columns = [] for column in columns + external_columns: - ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) + ancestor = column.find_ancestor( + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint + ) if ( not ancestor - # Window functions can have an ORDER BY clause - or not isinstance(ancestor.parent, exp.Select) or column.table + or isinstance(ancestor, exp.Select) + or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window)) or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): self._columns.append(column) @@ -370,6 +374,17 @@ class Scope: return [] return self._join_hints + @property + def pivots(self): + if not self._pivots: + self._pivots = [ + pivot + for node in self.tables + self.derived_tables + for pivot in node.args.get("pivots") or [] + ] + + return self._pivots + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. @@ -463,7 +478,7 @@ class Scope: return scope_ref_count -def traverse_scope(expression): +def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ Traverse an expression by it's "scopes". @@ -488,10 +503,12 @@ def traverse_scope(expression): Returns: list[Scope]: scope instances """ + if not isinstance(expression, exp.Unionable): + return [] return list(_traverse_scope(Scope(expression))) -def build_scope(expression): +def build_scope(expression: exp.Expression) -> t.Optional[Scope]: """ Build a scope tree. @@ -500,7 +517,10 @@ def build_scope(expression): Returns: Scope: root scope """ - return traverse_scope(expression)[-1] + scopes = traverse_scope(expression) + if scopes: + return scopes[-1] + return None def _traverse_scope(scope): @@ -585,7 +605,7 @@ def _traverse_tables(scope): expressions = [] from_ = scope.expression.args.get("from") if from_: - expressions.extend(from_.expressions) + expressions.append(from_.this) for join in scope.expression.args.get("joins") or []: expressions.append(join.this) @@ -601,8 +621,13 @@ def _traverse_tables(scope): source_name = expression.alias_or_name if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - sources[source_name] = scope.sources[table_name] + # This is a reference to a parent source (e.g. a CTE), not an actual table, unless + # it is pivoted, because then we get back a new table and hence a new source. + pivots = expression.args.get("pivots") + if pivots: + sources[pivots[0].alias] = expression + else: + sources[source_name] = scope.sources[table_name] elif source_name in sources: sources[find_new_name(sources, table_name)] = expression else: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 0904189..e2772a0 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,11 +5,9 @@ from collections import deque from decimal import Decimal from sqlglot import exp -from sqlglot.generator import Generator +from sqlglot.generator import cached_generator from sqlglot.helper import first, while_changing -GENERATOR = Generator(normalize=True, identify="safe") - def simplify(expression): """ @@ -27,12 +25,12 @@ def simplify(expression): sqlglot.Expression: simplified expression """ - cache = {} + generate = cached_generator() def _simplify(expression, root=True): node = expression node = rewrite_between(node) - node = uniq_sort(node, cache, root) + node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) exp.replace_children(node, lambda e: _simplify(e, False)) node = simplify_not(node) @@ -247,7 +245,7 @@ def remove_compliments(expression, root=True): return expression -def uniq_sort(expression, cache=None, root=True): +def uniq_sort(expression, generate, root=True): """ Uniq and sort a connector. @@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True): if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {GENERATOR.generate(e, cache): e for e in flattened} + deduped = {generate(e): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them @@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b): def simplify_parens(expression): - if ( - isinstance(expression, exp.Paren) - and not isinstance(expression.this, exp.Select) - and ( - not isinstance(expression.parent, (exp.Condition, exp.Binary)) - or isinstance(expression.this, exp.Predicate) - or not isinstance(expression.this, exp.Binary) - ) + if not isinstance(expression, exp.Paren): + return expression + + this = expression.this + parent = expression.parent + + if not isinstance(this, exp.Select) and ( + not isinstance(parent, (exp.Condition, exp.Binary)) + or isinstance(this, exp.Predicate) + or not isinstance(this, exp.Binary) + or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) + or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) ): return expression.this return expression diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index a515489..09e3f2a 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,6 +1,5 @@ -import itertools - from sqlglot import exp +from sqlglot.helper import name_sequence from sqlglot.optimizer.scope import ScopeType, traverse_scope @@ -22,7 +21,7 @@ def unnest_subqueries(expression): Returns: sqlglot.Expression: unnested expression """ - sequence = itertools.count() + next_alias_name = name_sequence("_u_") for scope in traverse_scope(expression): select = scope.expression @@ -30,19 +29,19 @@ def unnest_subqueries(expression): if not parent: continue if scope.external_columns: - decorrelate(select, parent, scope.external_columns, sequence) + decorrelate(select, parent, scope.external_columns, next_alias_name) elif scope.scope_type == ScopeType.SUBQUERY: - unnest(select, parent, sequence) + unnest(select, parent, next_alias_name) return expression -def unnest(select, parent_select, sequence): +def unnest(select, parent_select, next_alias_name): if len(select.selects) > 1: return predicate = select.find_ancestor(exp.Condition) - alias = _alias(sequence) + alias = next_alias_name() if not predicate or parent_select is not predicate.parent_select: return @@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence): ) -def decorrelate(select, parent_select, external_columns, sequence): +def decorrelate(select, parent_select, external_columns, next_alias_name): where = select.args.get("where") if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): return - table_alias = _alias(sequence) + table_alias = next_alias_name() keys = [] # for all external columns in the where statement, find the relevant predicate @@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence): group_by.append(key) else: if key not in key_aliases: - key_aliases[key] = _alias(sequence) + key_aliases[key] = next_alias_name() # all predicates that are equalities must also be in the unique # so that we don't do a many to many join if isinstance(predicate, exp.EQ) and key not in group_by: @@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence): ) -def _alias(sequence): - return f"_u_{next(sequence)}" - - def _replace(expression, condition): return expression.replace(exp.condition(condition)) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index d8d9f88..e77bb5a 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -6,22 +6,17 @@ from collections import defaultdict from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import ( - apply_index_offset, - count_params, - ensure_collection, - ensure_list, - seq_get, -) +from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie -logger = logging.getLogger("sqlglot") +if t.TYPE_CHECKING: + from sqlglot._typing import E -E = t.TypeVar("E", bound=exp.Expression) +logger = logging.getLogger("sqlglot") -def parse_var_map(args: t.Sequence) -> exp.Expression: +def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap: if len(args) == 1 and args[0].is_star: return exp.StarMap(this=args[0]) @@ -36,7 +31,7 @@ def parse_var_map(args: t.Sequence) -> exp.Expression: ) -def parse_like(args): +def parse_like(args: t.List) -> exp.Expression: like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like @@ -65,7 +60,7 @@ class Parser(metaclass=_Parser): Args: error_level: the desired error level. - Default: ErrorLevel.RAISE + Default: ErrorLevel.IMMEDIATE error_message_context: determines the amount of context to capture from a query string when displaying the error message (in number of characters). Default: 50. @@ -118,8 +113,8 @@ class Parser(metaclass=_Parser): NESTED_TYPE_TOKENS = { TokenType.ARRAY, TokenType.MAP, - TokenType.STRUCT, TokenType.NULLABLE, + TokenType.STRUCT, } TYPE_TOKENS = { @@ -158,6 +153,7 @@ class Parser(metaclass=_Parser): TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, TokenType.DATETIME, + TokenType.DATETIME64, TokenType.DATE, TokenType.DECIMAL, TokenType.BIGDECIMAL, @@ -211,20 +207,18 @@ class Parser(metaclass=_Parser): TokenType.VAR, TokenType.ANTI, TokenType.APPLY, + TokenType.ASC, TokenType.AUTO_INCREMENT, TokenType.BEGIN, - TokenType.BOTH, - TokenType.BUCKET, TokenType.CACHE, - TokenType.CASCADE, TokenType.COLLATE, TokenType.COMMAND, TokenType.COMMENT, TokenType.COMMIT, - TokenType.COMPOUND, TokenType.CONSTRAINT, TokenType.DEFAULT, TokenType.DELETE, + TokenType.DESC, TokenType.DESCRIBE, TokenType.DIV, TokenType.END, @@ -233,7 +227,6 @@ class Parser(metaclass=_Parser): TokenType.FALSE, TokenType.FIRST, TokenType.FILTER, - TokenType.FOLLOWING, TokenType.FORMAT, TokenType.FULL, TokenType.IF, @@ -241,41 +234,31 @@ class Parser(metaclass=_Parser): TokenType.ISNULL, TokenType.INTERVAL, TokenType.KEEP, - TokenType.LAZY, - TokenType.LEADING, TokenType.LEFT, - TokenType.LOCAL, - TokenType.MATERIALIZED, + TokenType.LOAD, TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, TokenType.OFFSET, - TokenType.ONLY, - TokenType.OPTIONS, TokenType.ORDINALITY, TokenType.OVERWRITE, TokenType.PARTITION, TokenType.PERCENT, TokenType.PIVOT, TokenType.PRAGMA, - TokenType.PRECEDING, TokenType.RANGE, TokenType.REFERENCES, TokenType.RIGHT, TokenType.ROW, TokenType.ROWS, - TokenType.SEED, TokenType.SEMI, TokenType.SET, + TokenType.SETTINGS, TokenType.SHOW, - TokenType.SORTKEY, TokenType.TEMPORARY, TokenType.TOP, - TokenType.TRAILING, TokenType.TRUE, - TokenType.UNBOUNDED, TokenType.UNIQUE, - TokenType.UNLOGGED, TokenType.UNPIVOT, TokenType.VOLATILE, TokenType.WINDOW, @@ -291,6 +274,7 @@ class Parser(metaclass=_Parser): TokenType.APPLY, TokenType.FULL, TokenType.LEFT, + TokenType.LOCK, TokenType.NATURAL, TokenType.OFFSET, TokenType.RIGHT, @@ -301,7 +285,7 @@ class Parser(metaclass=_Parser): UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} - TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} + TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"} FUNC_TOKENS = { TokenType.COMMAND, @@ -322,6 +306,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE, TokenType.OFFSET, TokenType.PRIMARY_KEY, + TokenType.RANGE, TokenType.REPLACE, TokenType.ROW, TokenType.UNNEST, @@ -455,31 +440,31 @@ class Parser(metaclass=_Parser): } EXPRESSION_PARSERS = { + exp.Cluster: lambda self: self._parse_sort(exp.Cluster, "CLUSTER", "BY"), exp.Column: lambda self: self._parse_column(), + exp.Condition: lambda self: self._parse_conjunction(), exp.DataType: lambda self: self._parse_types(), + exp.Expression: lambda self: self._parse_statement(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), + exp.Having: lambda self: self._parse_having(), exp.Identifier: lambda self: self._parse_id_var(), - exp.Lateral: lambda self: self._parse_lateral(), exp.Join: lambda self: self._parse_join(), - exp.Order: lambda self: self._parse_order(), - exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), - exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), exp.Lambda: lambda self: self._parse_lambda(), + exp.Lateral: lambda self: self._parse_lateral(), exp.Limit: lambda self: self._parse_limit(), exp.Offset: lambda self: self._parse_offset(), - exp.TableAlias: lambda self: self._parse_table_alias(), - exp.Table: lambda self: self._parse_table(), - exp.Condition: lambda self: self._parse_conjunction(), - exp.Expression: lambda self: self._parse_statement(), - exp.Properties: lambda self: self._parse_properties(), - exp.Where: lambda self: self._parse_where(), + exp.Order: lambda self: self._parse_order(), exp.Ordered: lambda self: self._parse_ordered(), - exp.Having: lambda self: self._parse_having(), - exp.With: lambda self: self._parse_with(), - exp.Window: lambda self: self._parse_named_window(), + exp.Properties: lambda self: self._parse_properties(), exp.Qualify: lambda self: self._parse_qualify(), exp.Returning: lambda self: self._parse_returning(), + exp.Sort: lambda self: self._parse_sort(exp.Sort, "SORT", "BY"), + exp.Table: lambda self: self._parse_table_parts(), + exp.TableAlias: lambda self: self._parse_table_alias(), + exp.Where: lambda self: self._parse_where(), + exp.Window: lambda self: self._parse_named_window(), + exp.With: lambda self: self._parse_with(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -495,9 +480,13 @@ class Parser(metaclass=_Parser): TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.END: lambda self: self._parse_commit_or_rollback(), + TokenType.FROM: lambda self: exp.select("*").from_( + t.cast(exp.From, self._parse_from(skip_from_token=True)) + ), TokenType.INSERT: lambda self: self._parse_insert(), - TokenType.LOAD_DATA: lambda self: self._parse_load_data(), + TokenType.LOAD: lambda self: self._parse_load(), TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.SET: lambda self: self._parse_set(), @@ -536,7 +525,10 @@ class Parser(metaclass=_Parser): TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), - TokenType.NATIONAL: lambda self, token: self._parse_national(token), + TokenType.NATIONAL_STRING: lambda self, token: self.expression( + exp.National, this=token.text + ), + TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -551,91 +543,76 @@ class Parser(metaclass=_Parser): RANGE_PARSERS = { TokenType.BETWEEN: lambda self, this: self._parse_between(this), TokenType.GLOB: binary_range_parser(exp.Glob), - TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), + TokenType.ILIKE: binary_range_parser(exp.ILike), TokenType.IN: lambda self, this: self._parse_in(this), + TokenType.IRLIKE: binary_range_parser(exp.RegexpILike), TokenType.IS: lambda self, this: self._parse_is(this), TokenType.LIKE: binary_range_parser(exp.Like), - TokenType.ILIKE: binary_range_parser(exp.ILike), - TokenType.IRLIKE: binary_range_parser(exp.RegexpILike), + TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), TokenType.RLIKE: binary_range_parser(exp.RegexpLike), TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo), } - PROPERTY_PARSERS = { - "AFTER": lambda self: self._parse_afterjournal( - no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" - ), + PROPERTY_PARSERS: t.Dict[str, t.Callable] = { "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), - "BEFORE": lambda self: self._parse_journal( - no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" - ), "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), "CHARACTER SET": lambda self: self._parse_character_set(), "CHECKSUM": lambda self: self._parse_checksum(), - "CLUSTER BY": lambda self: self.expression( - exp.Cluster, expressions=self._parse_csv(self._parse_ordered) - ), + "CLUSTER": lambda self: self._parse_cluster(), "COLLATE": lambda self: self._parse_property_assignment(exp.CollateProperty), "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), - "DATABLOCKSIZE": lambda self: self._parse_datablocksize( - default=self._prev.text.upper() == "DEFAULT" - ), + "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), "DEFINER": lambda self: self._parse_definer(), "DETERMINISTIC": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), "DISTKEY": lambda self: self._parse_distkey(), "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), + "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), "EXTERNAL": lambda self: self.expression(exp.ExternalProperty), - "FALLBACK": lambda self: self._parse_fallback(no=self._prev.text.upper() == "NO"), + "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs), "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), "FREESPACE": lambda self: self._parse_freespace(), - "GLOBAL": lambda self: self._parse_temporary(global_=True), "IMMUTABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") ), - "JOURNAL": lambda self: self._parse_journal( - no=self._prev.text.upper() == "NO", dual=self._prev.text.upper() == "DUAL" - ), + "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), "LIKE": lambda self: self._parse_create_like(), - "LOCAL": lambda self: self._parse_afterjournal(no=False, dual=False, local=True), "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), "LOCK": lambda self: self._parse_locking(), "LOCKING": lambda self: self._parse_locking(), - "LOG": lambda self: self._parse_log(no=self._prev.text.upper() == "NO"), + "LOG": lambda self, **kwargs: self._parse_log(**kwargs), "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), - "MAX": lambda self: self._parse_datablocksize(), - "MAXIMUM": lambda self: self._parse_datablocksize(), - "MERGEBLOCKRATIO": lambda self: self._parse_mergeblockratio( - no=self._prev.text.upper() == "NO", default=self._prev.text.upper() == "DEFAULT" - ), - "MIN": lambda self: self._parse_datablocksize(), - "MINIMUM": lambda self: self._parse_datablocksize(), + "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs), "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), - "NO": lambda self: self._parse_noprimaryindex(), - "NOT": lambda self: self._parse_afterjournal(no=False, dual=False, local=False), - "ON": lambda self: self._parse_oncommit(), + "NO": lambda self: self._parse_no_property(), + "ON": lambda self: self._parse_on_property(), + "ORDER BY": lambda self: self._parse_order(skip_order_token=True), "PARTITION BY": lambda self: self._parse_partitioned_by(), "PARTITIONED BY": lambda self: self._parse_partitioned_by(), "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), + "PRIMARY KEY": lambda self: self._parse_primary_key(), "RETURNS": lambda self: self._parse_returns(), "ROW": lambda self: self._parse_row(), "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), "SET": lambda self: self.expression(exp.SetProperty, multi=False), + "SETTINGS": lambda self: self.expression( + exp.SettingsProperty, expressions=self._parse_csv(self._parse_set_item) + ), "SORTKEY": lambda self: self._parse_sortkey(), "STABLE": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("STABLE") ), "STORED": lambda self: self._parse_stored(), - "TABLE_FORMAT": lambda self: self._parse_property_assignment(exp.TableFormatProperty), "TBLPROPERTIES": lambda self: self._parse_wrapped_csv(self._parse_property), - "TEMP": lambda self: self._parse_temporary(global_=False), - "TEMPORARY": lambda self: self._parse_temporary(global_=False), + "TEMP": lambda self: self.expression(exp.TemporaryProperty), + "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), "TRANSIENT": lambda self: self.expression(exp.TransientProperty), - "USING": lambda self: self._parse_property_assignment(exp.TableFormatProperty), + "TTL": lambda self: self._parse_ttl(), + "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), "VOLATILE": lambda self: self._parse_volatile_property(), "WITH": lambda self: self._parse_with_property(), } @@ -679,6 +656,7 @@ class Parser(metaclass=_Parser): "TITLE": lambda self: self.expression( exp.TitleColumnConstraint, this=self._parse_var_or_string() ), + "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]), "UNIQUE": lambda self: self._parse_unique(), "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), } @@ -704,6 +682,8 @@ class Parser(metaclass=_Parser): ), } + FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} + FUNCTION_PARSERS: t.Dict[str, t.Callable] = { "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), @@ -712,7 +692,9 @@ class Parser(metaclass=_Parser): "JSON_OBJECT": lambda self: self._parse_json_object(), "LOG": lambda self: self._parse_logarithm(), "MATCH": lambda self: self._parse_match_against(), + "OPENJSON": lambda self: self._parse_open_json(), "POSITION": lambda self: self._parse_position(), + "SAFE_CAST": lambda self: self._parse_cast(False), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), "TRIM": lambda self: self._parse_trim(), @@ -721,19 +703,18 @@ class Parser(metaclass=_Parser): } QUERY_MODIFIER_PARSERS = { + "joins": lambda self: list(iter(self._parse_join, None)), + "laterals": lambda self: list(iter(self._parse_lateral, None)), "match": lambda self: self._parse_match_recognize(), "where": lambda self: self._parse_where(), "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), "windows": lambda self: self._parse_window_clause(), - "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), - "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), - "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), "order": lambda self: self._parse_order(), "limit": lambda self: self._parse_limit(), "offset": lambda self: self._parse_offset(), - "lock": lambda self: self._parse_lock(), + "locks": lambda self: self._parse_locks(), "sample": lambda self: self._parse_table_sample(as_modifier=True), } @@ -763,8 +744,11 @@ class Parser(metaclass=_Parser): INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} + CLONE_KINDS = {"TIMESTAMP", "OFFSET", "STATEMENT"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} + WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} ADD_CONSTRAINT_TOKENS = {TokenType.CONSTRAINT, TokenType.PRIMARY_KEY, TokenType.FOREIGN_KEY} @@ -772,8 +756,8 @@ class Parser(metaclass=_Parser): CONVERT_TYPE_FIRST = False - QUOTED_PIVOT_COLUMNS: t.Optional[bool] = None PREFIXED_PIVOT_COLUMNS = False + IDENTIFY_PIVOT_STRINGS = False LOG_BASE_FIRST = True LOG_DEFAULTS_TO_LN = False @@ -875,7 +859,7 @@ class Parser(metaclass=_Parser): e.errors[0]["into_expression"] = expression_type errors.append(e) raise ParseError( - f"Failed to parse into {expression_types}", + f"Failed to parse '{sql or raw_tokens}' into {expression_types}", errors=merge_errors(errors), ) from errors[-1] @@ -933,7 +917,7 @@ class Parser(metaclass=_Parser): """ token = token or self._curr or self._prev or Token.string("") start = token.start - end = token.end + end = token.end + 1 start_context = self.sql[max(start - self.error_message_context, 0) : start] highlight = self.sql[start:end] end_context = self.sql[end : end + self.error_message_context] @@ -996,7 +980,7 @@ class Parser(metaclass=_Parser): self.raise_error(error_message) def _find_sql(self, start: Token, end: Token) -> str: - return self.sql[start.start : end.end] + return self.sql[start.start : end.end + 1] def _advance(self, times: int = 1) -> None: self._index += times @@ -1042,6 +1026,44 @@ class Parser(metaclass=_Parser): exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists ) + # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl + def _parse_ttl(self) -> exp.Expression: + def _parse_ttl_action() -> t.Optional[exp.Expression]: + this = self._parse_bitwise() + + if self._match_text_seq("DELETE"): + return self.expression(exp.MergeTreeTTLAction, this=this, delete=True) + if self._match_text_seq("RECOMPRESS"): + return self.expression( + exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise() + ) + if self._match_text_seq("TO", "DISK"): + return self.expression( + exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string() + ) + if self._match_text_seq("TO", "VOLUME"): + return self.expression( + exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string() + ) + + return this + + expressions = self._parse_csv(_parse_ttl_action) + where = self._parse_where() + group = self._parse_group() + + aggregates = None + if group and self._match(TokenType.SET): + aggregates = self._parse_csv(self._parse_set_item) + + return self.expression( + exp.MergeTreeTTL, + expressions=expressions, + where=where, + group=group, + aggregates=aggregates, + ) + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -1054,14 +1076,12 @@ class Parser(metaclass=_Parser): expression = self._parse_expression() expression = self._parse_set_operations(expression) if expression else self._parse_select() - - self._parse_query_modifiers(expression) - return expression + return self._parse_query_modifiers(expression) def _parse_drop(self) -> t.Optional[exp.Drop | exp.Command]: start = self._prev temporary = self._match(TokenType.TEMPORARY) - materialized = self._match(TokenType.MATERIALIZED) + materialized = self._match_text_seq("MATERIALIZED") kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: return self._parse_as_command(start) @@ -1073,7 +1093,7 @@ class Parser(metaclass=_Parser): kind=kind, temporary=temporary, materialized=materialized, - cascade=self._match(TokenType.CASCADE), + cascade=self._match_text_seq("CASCADE"), constraints=self._match_text_seq("CONSTRAINTS"), purge=self._match_text_seq("PURGE"), ) @@ -1111,6 +1131,7 @@ class Parser(metaclass=_Parser): indexes = None no_schema_binding = None begin = None + clone = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function(kind=create_token.token_type) @@ -1128,7 +1149,7 @@ class Parser(metaclass=_Parser): if return_: expression = self.expression(exp.Return, this=expression) elif create_token.token_type == TokenType.INDEX: - this = self._parse_index() + this = self._parse_index(index=self._parse_id_var()) elif create_token.token_type in self.DB_CREATABLES: table_parts = self._parse_table_parts(schema=True) @@ -1166,33 +1187,40 @@ class Parser(metaclass=_Parser): expression = self._parse_ddl_select() if create_token.token_type == TokenType.TABLE: - # exp.Properties.Location.POST_EXPRESSION - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties - indexes = [] while True: - index = self._parse_create_table_index() + index = self._parse_index() - # exp.Properties.Location.POST_INDEX - if self._match(TokenType.PARTITION_BY, advance=False): - temp_properties = self._parse_properties() - if properties and temp_properties: - properties.expressions.extend(temp_properties.expressions) - elif temp_properties: - properties = temp_properties + # exp.Properties.Location.POST_EXPRESSION or exp.Properties.Location.POST_INDEX + temp_properties = self._parse_properties() + if properties and temp_properties: + properties.expressions.extend(temp_properties.expressions) + elif temp_properties: + properties = temp_properties if not index: break else: + self._match(TokenType.COMMA) indexes.append(index) elif create_token.token_type == TokenType.VIEW: if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"): no_schema_binding = True + if self._match_text_seq("CLONE"): + clone = self._parse_table(schema=True) + when = self._match_texts({"AT", "BEFORE"}) and self._prev.text.upper() + clone_kind = ( + self._match(TokenType.L_PAREN) + and self._match_texts(self.CLONE_KINDS) + and self._prev.text.upper() + ) + clone_expression = self._match(TokenType.FARROW) and self._parse_bitwise() + self._match(TokenType.R_PAREN) + clone = self.expression( + exp.Clone, this=clone, when=when, kind=clone_kind, expression=clone_expression + ) + return self.expression( exp.Create, this=this, @@ -1205,18 +1233,31 @@ class Parser(metaclass=_Parser): indexes=indexes, no_schema_binding=no_schema_binding, begin=begin, + clone=clone, ) def _parse_property_before(self) -> t.Optional[exp.Expression]: + # only used for teradata currently self._match(TokenType.COMMA) - # parsers look to _prev for no/dual/default, so need to consume first - self._match_text_seq("NO") - self._match_text_seq("DUAL") - self._match_text_seq("DEFAULT") + kwargs = { + "no": self._match_text_seq("NO"), + "dual": self._match_text_seq("DUAL"), + "before": self._match_text_seq("BEFORE"), + "default": self._match_text_seq("DEFAULT"), + "local": (self._match_text_seq("LOCAL") and "LOCAL") + or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"), + "after": self._match_text_seq("AFTER"), + "minimum": self._match_texts(("MIN", "MINIMUM")), + "maximum": self._match_texts(("MAX", "MAXIMUM")), + } - if self.PROPERTY_PARSERS.get(self._curr.text.upper()): - return self.PROPERTY_PARSERS[self._curr.text.upper()](self) + if self._match_texts(self.PROPERTY_PARSERS): + parser = self.PROPERTY_PARSERS[self._prev.text.upper()] + try: + return parser(self, **{k: v for k, v in kwargs.items() if v}) + except TypeError: + self.raise_error(f"Cannot parse property '{self._prev.text}'") return None @@ -1227,7 +1268,7 @@ class Parser(metaclass=_Parser): if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): return self._parse_character_set(default=True) - if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY): + if self._match_text_seq("COMPOUND", "SORTKEY"): return self._parse_sortkey(compound=True) if self._match_text_seq("SQL", "SECURITY"): @@ -1262,23 +1303,20 @@ class Parser(metaclass=_Parser): def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: self._match(TokenType.EQ) self._match(TokenType.ALIAS) - return self.expression( - exp_class, - this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), - ) + return self.expression(exp_class, this=self._parse_field()) - def _parse_properties(self, before=None) -> t.Optional[exp.Expression]: + def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Expression]: properties = [] while True: if before: - identified_property = self._parse_property_before() + prop = self._parse_property_before() else: - identified_property = self._parse_property() + prop = self._parse_property() - if not identified_property: + if not prop: break - for p in ensure_list(identified_property): + for p in ensure_list(prop): properties.append(p) if properties: @@ -1286,8 +1324,7 @@ class Parser(metaclass=_Parser): return None - def _parse_fallback(self, no=False) -> exp.Expression: - self._match_text_seq("FALLBACK") + def _parse_fallback(self, no: bool = False) -> exp.Expression: return self.expression( exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") ) @@ -1345,23 +1382,13 @@ class Parser(metaclass=_Parser): self._match(TokenType.EQ) return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) - def _parse_log(self, no=False) -> exp.Expression: - self._match_text_seq("LOG") + def _parse_log(self, no: bool = False) -> exp.Expression: return self.expression(exp.LogProperty, no=no) - def _parse_journal(self, no=False, dual=False) -> exp.Expression: - before = self._match_text_seq("BEFORE") - self._match_text_seq("JOURNAL") - return self.expression(exp.JournalProperty, no=no, dual=dual, before=before) - - def _parse_afterjournal(self, no=False, dual=False, local=None) -> exp.Expression: - self._match_text_seq("NOT") - self._match_text_seq("LOCAL") - self._match_text_seq("AFTER", "JOURNAL") - return self.expression(exp.AfterJournalProperty, no=no, dual=dual, local=local) + def _parse_journal(self, **kwargs) -> exp.Expression: + return self.expression(exp.JournalProperty, **kwargs) def _parse_checksum(self) -> exp.Expression: - self._match_text_seq("CHECKSUM") self._match(TokenType.EQ) on = None @@ -1377,49 +1404,55 @@ class Parser(metaclass=_Parser): default=default, ) + def _parse_cluster(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("BY"): + self._retreat(self._index - 1) + return None + return self.expression( + exp.Cluster, + expressions=self._parse_csv(self._parse_ordered), + ) + def _parse_freespace(self) -> exp.Expression: - self._match_text_seq("FREESPACE") self._match(TokenType.EQ) return self.expression( exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT) ) - def _parse_mergeblockratio(self, no=False, default=False) -> exp.Expression: - self._match_text_seq("MERGEBLOCKRATIO") + def _parse_mergeblockratio(self, no: bool = False, default: bool = False) -> exp.Expression: if self._match(TokenType.EQ): return self.expression( exp.MergeBlockRatioProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT), ) - else: - return self.expression( - exp.MergeBlockRatioProperty, - no=no, - default=default, - ) + return self.expression( + exp.MergeBlockRatioProperty, + no=no, + default=default, + ) - def _parse_datablocksize(self, default=None) -> exp.Expression: - if default: - self._match_text_seq("DATABLOCKSIZE") - return self.expression(exp.DataBlocksizeProperty, default=True) - elif self._match_texts(("MIN", "MINIMUM")): - self._match_text_seq("DATABLOCKSIZE") - return self.expression(exp.DataBlocksizeProperty, min=True) - elif self._match_texts(("MAX", "MAXIMUM")): - self._match_text_seq("DATABLOCKSIZE") - return self.expression(exp.DataBlocksizeProperty, min=False) - - self._match_text_seq("DATABLOCKSIZE") + def _parse_datablocksize( + self, + default: t.Optional[bool] = None, + minimum: t.Optional[bool] = None, + maximum: t.Optional[bool] = None, + ) -> exp.Expression: self._match(TokenType.EQ) size = self._parse_number() units = None if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): units = self._prev.text - return self.expression(exp.DataBlocksizeProperty, size=size, units=units) + return self.expression( + exp.DataBlocksizeProperty, + size=size, + units=units, + default=default, + minimum=minimum, + maximum=maximum, + ) def _parse_blockcompression(self) -> exp.Expression: - self._match_text_seq("BLOCKCOMPRESSION") self._match(TokenType.EQ) always = self._match_text_seq("ALWAYS") manual = self._match_text_seq("MANUAL") @@ -1516,7 +1549,7 @@ class Parser(metaclass=_Parser): this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_withdata(self, no=False) -> exp.Expression: + def _parse_withdata(self, no: bool = False) -> exp.Expression: if self._match_text_seq("AND", "STATISTICS"): statistics = True elif self._match_text_seq("AND", "NO", "STATISTICS"): @@ -1526,13 +1559,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - def _parse_noprimaryindex(self) -> exp.Expression: - self._match_text_seq("PRIMARY", "INDEX") - return exp.NoPrimaryIndexProperty() + def _parse_no_property(self) -> t.Optional[exp.Property]: + if self._match_text_seq("PRIMARY", "INDEX"): + return exp.NoPrimaryIndexProperty() + return None - def _parse_oncommit(self) -> exp.Expression: - self._match_text_seq("COMMIT", "PRESERVE", "ROWS") - return exp.OnCommitProperty() + def _parse_on_property(self) -> t.Optional[exp.Property]: + if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): + return exp.OnCommitProperty() + elif self._match_text_seq("COMMIT", "DELETE", "ROWS"): + return exp.OnCommitProperty(delete=True) + return None def _parse_distkey(self) -> exp.Expression: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) @@ -1587,10 +1624,6 @@ class Parser(metaclass=_Parser): return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_temporary(self, global_=False) -> exp.Expression: - self._match(TokenType.TEMPORARY) # in case calling from "GLOBAL" - return self.expression(exp.TemporaryProperty, global_=global_) - def _parse_describe(self) -> exp.Expression: kind = self._match_set(self.CREATABLES) and self._prev.text this = self._parse_table() @@ -1599,7 +1632,7 @@ class Parser(metaclass=_Parser): def _parse_insert(self) -> exp.Expression: overwrite = self._match(TokenType.OVERWRITE) - local = self._match(TokenType.LOCAL) + local = self._match_text_seq("LOCAL") alternative = None if self._match_text_seq("DIRECTORY"): @@ -1700,23 +1733,25 @@ class Parser(metaclass=_Parser): return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore - def _parse_load_data(self) -> exp.Expression: - local = self._match(TokenType.LOCAL) - self._match_text_seq("INPATH") - inpath = self._parse_string() - overwrite = self._match(TokenType.OVERWRITE) - self._match_pair(TokenType.INTO, TokenType.TABLE) + def _parse_load(self) -> exp.Expression: + if self._match_text_seq("DATA"): + local = self._match_text_seq("LOCAL") + self._match_text_seq("INPATH") + inpath = self._parse_string() + overwrite = self._match(TokenType.OVERWRITE) + self._match_pair(TokenType.INTO, TokenType.TABLE) - return self.expression( - exp.LoadData, - this=self._parse_table(schema=True), - local=local, - overwrite=overwrite, - inpath=inpath, - partition=self._parse_partition(), - input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), - serde=self._match_text_seq("SERDE") and self._parse_string(), - ) + return self.expression( + exp.LoadData, + this=self._parse_table(schema=True), + local=local, + overwrite=overwrite, + inpath=inpath, + partition=self._parse_partition(), + input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), + serde=self._match_text_seq("SERDE") and self._parse_string(), + ) + return self._parse_as_command(self._prev) def _parse_delete(self) -> exp.Expression: self._match(TokenType.FROM) @@ -1735,7 +1770,7 @@ class Parser(metaclass=_Parser): **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), - "from": self._parse_from(), + "from": self._parse_from(modifiers=True), "where": self._parse_where(), "returning": self._parse_returning(), }, @@ -1752,12 +1787,12 @@ class Parser(metaclass=_Parser): ) def _parse_cache(self) -> exp.Expression: - lazy = self._match(TokenType.LAZY) + lazy = self._match_text_seq("LAZY") self._match(TokenType.TABLE) table = self._parse_table(schema=True) options = [] - if self._match(TokenType.OPTIONS): + if self._match_text_seq("OPTIONS"): self._match_l_paren() k = self._parse_string() self._match(TokenType.EQ) @@ -1851,11 +1886,10 @@ class Parser(metaclass=_Parser): if from_: this.set("from", from_) - self._parse_query_modifiers(this) + this = self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) - self._parse_query_modifiers(this) - this = self._parse_set_operations(this) + this = self._parse_set_operations(self._parse_query_modifiers(this)) self._match_r_paren() # early return so that subquery unions aren't parsed again @@ -1868,6 +1902,10 @@ class Parser(metaclass=_Parser): expressions=self._parse_csv(self._parse_value), alias=self._parse_table_alias(), ) + elif self._match(TokenType.PIVOT): + this = self._parse_simplified_pivot() + elif self._match(TokenType.FROM): + this = exp.select("*").from_(t.cast(exp.From, self._parse_from(skip_from_token=True))) else: this = None @@ -1929,7 +1967,9 @@ class Parser(metaclass=_Parser): def _parse_subquery( self, this: t.Optional[exp.Expression], parse_alias: bool = True - ) -> exp.Expression: + ) -> t.Optional[exp.Expression]: + if not this: + return None return self.expression( exp.Subquery, this=this, @@ -1937,35 +1977,16 @@ class Parser(metaclass=_Parser): alias=self._parse_table_alias() if parse_alias else None, ) - def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None: - if not isinstance(this, self.MODIFIABLES): - return - - table = isinstance(this, exp.Table) - - while True: - join = self._parse_join() - if join: - this.append("joins", join) - - lateral = None - if not join: - lateral = self._parse_lateral() - if lateral: - this.append("laterals", lateral) - - comma = None if table else self._match(TokenType.COMMA) - if comma: - this.args["from"].append("expressions", self._parse_table()) - - if not (lateral or join or comma): - break - - for key, parser in self.QUERY_MODIFIER_PARSERS.items(): - expression = parser(self) + def _parse_query_modifiers( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + if isinstance(this, self.MODIFIABLES): + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): + expression = parser(self) - if expression: - this.set(key, expression) + if expression: + this.set(key, expression) + return this def _parse_hint(self) -> t.Optional[exp.Expression]: if self._match(TokenType.HINT): @@ -1981,19 +2002,26 @@ class Parser(metaclass=_Parser): return None temp = self._match(TokenType.TEMPORARY) - unlogged = self._match(TokenType.UNLOGGED) + unlogged = self._match_text_seq("UNLOGGED") self._match(TokenType.TABLE) return self.expression( exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged ) - def _parse_from(self) -> t.Optional[exp.Expression]: - if not self._match(TokenType.FROM): + def _parse_from( + self, modifiers: bool = False, skip_from_token: bool = False + ) -> t.Optional[exp.From]: + if not skip_from_token and not self._match(TokenType.FROM): return None + comments = self._prev_comments + this = self._parse_table() + return self.expression( - exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) + exp.From, + comments=comments, + this=self._parse_query_modifiers(this) if modifiers else this, ) def _parse_match_recognize(self) -> t.Optional[exp.Expression]: @@ -2136,6 +2164,9 @@ class Parser(metaclass=_Parser): ) def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: + if self._match(TokenType.COMMA): + return self.expression(exp.Join, this=self._parse_table()) + index = self._index natural, side, kind = self._parse_join_side_and_kind() hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None @@ -2176,55 +2207,66 @@ class Parser(metaclass=_Parser): return self.expression(exp.Join, **kwargs) # type: ignore - def _parse_index(self) -> exp.Expression: - index = self._parse_id_var() - self._match(TokenType.ON) - self._match(TokenType.TABLE) # hive + def _parse_index( + self, + index: t.Optional[exp.Expression] = None, + ) -> t.Optional[exp.Expression]: + if index: + unique = None + primary = None + amp = None - return self.expression( - exp.Index, - this=index, - table=self.expression(exp.Table, this=self._parse_id_var()), - columns=self._parse_expression(), - ) + self._match(TokenType.ON) + self._match(TokenType.TABLE) # hive + table = self._parse_table_parts(schema=True) + else: + unique = self._match(TokenType.UNIQUE) + primary = self._match_text_seq("PRIMARY") + amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): + return None + index = self._parse_id_var() + table = None - def _parse_create_table_index(self) -> t.Optional[exp.Expression]: - unique = self._match(TokenType.UNIQUE) - primary = self._match_text_seq("PRIMARY") - amp = self._match_text_seq("AMP") - if not self._match(TokenType.INDEX): - return None - index = self._parse_id_var() - columns = None if self._match(TokenType.L_PAREN, advance=False): - columns = self._parse_wrapped_csv(self._parse_column) + columns = self._parse_wrapped_csv(self._parse_ordered) + else: + columns = None + return self.expression( exp.Index, this=index, + table=table, columns=columns, unique=unique, primary=primary, amp=amp, + partition_by=self._parse_partition_by(), ) - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: - catalog = None - db = None - - table = ( + def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: + return ( (not schema and self._parse_function()) or self._parse_id_var(any_token=False) or self._parse_string_as_identifier() + or self._parse_placeholder() ) + def _parse_table_parts(self, schema: bool = False) -> exp.Table: + catalog = None + db = None + table = self._parse_table_part(schema=schema) + while self._match(TokenType.DOT): if catalog: # This allows nesting the table in arbitrarily many dot expressions if needed - table = self.expression(exp.Dot, this=table, expression=self._parse_id_var()) + table = self.expression( + exp.Dot, this=table, expression=self._parse_table_part(schema=schema) + ) else: catalog = db db = table - table = self._parse_id_var() + table = self._parse_table_part(schema=schema) if not table: self.raise_error(f"Expected table name but got {self._curr}") @@ -2237,28 +2279,24 @@ class Parser(metaclass=_Parser): self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() - if lateral: return lateral unnest = self._parse_unnest() - if unnest: return unnest values = self._parse_derived_table_values() - if values: return values subquery = self._parse_select(table=True) - if subquery: if not subquery.args.get("pivots"): subquery.set("pivots", self._parse_pivots()) return subquery - this = self._parse_table_parts(schema=schema) + this: exp.Expression = self._parse_table_parts(schema=schema) if schema: return self._parse_schema(this=this) @@ -2267,7 +2305,6 @@ class Parser(metaclass=_Parser): table_sample = self._parse_table_sample() alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) - if alias: this.set("alias", alias) @@ -2352,9 +2389,9 @@ class Parser(metaclass=_Parser): num = self._parse_number() - if self._match(TokenType.BUCKET): + if self._match_text_seq("BUCKET"): bucket_numerator = self._parse_number() - self._match(TokenType.OUT_OF) + self._match_text_seq("OUT", "OF") bucket_denominator = bucket_denominator = self._parse_number() self._match(TokenType.ON) bucket_field = self._parse_field() @@ -2390,6 +2427,22 @@ class Parser(metaclass=_Parser): def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: return list(iter(self._parse_pivot, None)) + # https://duckdb.org/docs/sql/statements/pivot + def _parse_simplified_pivot(self) -> exp.Pivot: + def _parse_on() -> t.Optional[exp.Expression]: + this = self._parse_bitwise() + return self._parse_in(this) if self._match(TokenType.IN) else this + + this = self._parse_table() + expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on) + using = self._match(TokenType.USING) and self._parse_csv( + lambda: self._parse_alias(self._parse_function()) + ) + group = self._parse_group() + return self.expression( + exp.Pivot, this=this, expressions=expressions, using=using, group=group + ) + def _parse_pivot(self) -> t.Optional[exp.Expression]: index = self._index @@ -2423,7 +2476,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.IN): self.raise_error("Expecting IN") - field = self._parse_in(value) + field = self._parse_in(value, alias=True) self._match_r_paren() @@ -2436,21 +2489,22 @@ class Parser(metaclass=_Parser): names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions)) columns: t.List[exp.Expression] = [] - for col in pivot.args["field"].expressions: + for fld in pivot.args["field"].expressions: + field_name = fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name for name in names: if self.PREFIXED_PIVOT_COLUMNS: - name = f"{name}_{col.alias_or_name}" if name else col.alias_or_name + name = f"{name}_{field_name}" if name else field_name else: - name = f"{col.alias_or_name}_{name}" if name else col.alias_or_name + name = f"{field_name}_{name}" if name else field_name - columns.append(exp.to_identifier(name, quoted=self.QUOTED_PIVOT_COLUMNS)) + columns.append(exp.to_identifier(name)) pivot.set("columns", columns) return pivot - def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: - return [agg.alias for agg in pivot_columns] + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: + return [agg.alias for agg in aggregations] def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): @@ -2477,6 +2531,7 @@ class Parser(metaclass=_Parser): rollup = None cube = None + totals = None with_ = self._match(TokenType.WITH) if self._match(TokenType.ROLLUP): @@ -2487,7 +2542,11 @@ class Parser(metaclass=_Parser): cube = with_ or self._parse_wrapped_csv(self._parse_column) elements["cube"].extend(ensure_list(cube)) - if not (expressions or grouping_sets or rollup or cube): + if self._match_text_seq("TOTALS"): + totals = True + elements["totals"] = True # type: ignore + + if not (grouping_sets or rollup or cube or totals): break return self.expression(exp.Group, **elements) # type: ignore @@ -2527,9 +2586,9 @@ class Parser(metaclass=_Parser): ) def _parse_sort( - self, token_type: TokenType, exp_class: t.Type[exp.Expression] + self, exp_class: t.Type[exp.Expression], *texts: str ) -> t.Optional[exp.Expression]: - if not self._match(token_type): + if not self._match_text_seq(*texts): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) @@ -2537,8 +2596,8 @@ class Parser(metaclass=_Parser): this = self._parse_conjunction() self._match(TokenType.ASC) is_desc = self._match(TokenType.DESC) - is_nulls_first = self._match(TokenType.NULLS_FIRST) - is_nulls_last = self._match(TokenType.NULLS_LAST) + is_nulls_first = self._match_text_seq("NULLS", "FIRST") + is_nulls_last = self._match_text_seq("NULLS", "LAST") desc = is_desc or False asc = not desc nulls_first = is_nulls_first or False @@ -2578,7 +2637,7 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) - only = self._match(TokenType.ONLY) + only = self._match_text_seq("ONLY") with_ties = self._match_text_seq("WITH", "TIES") if only and with_ties: @@ -2602,13 +2661,37 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_lock(self) -> t.Optional[exp.Expression]: - if self._match_text_seq("FOR", "UPDATE"): - return self.expression(exp.Lock, update=True) - if self._match_text_seq("FOR", "SHARE"): - return self.expression(exp.Lock, update=False) + def _parse_locks(self) -> t.List[exp.Expression]: + # Lists are invariant, so we need to use a type hint here + locks: t.List[exp.Expression] = [] - return None + while True: + if self._match_text_seq("FOR", "UPDATE"): + update = True + elif self._match_text_seq("FOR", "SHARE") or self._match_text_seq( + "LOCK", "IN", "SHARE", "MODE" + ): + update = False + else: + break + + expressions = None + if self._match_text_seq("OF"): + expressions = self._parse_csv(lambda: self._parse_table(schema=True)) + + wait: t.Optional[bool | exp.Expression] = None + if self._match_text_seq("NOWAIT"): + wait = True + elif self._match_text_seq("WAIT"): + wait = self._parse_primary() + elif self._match_text_seq("SKIP", "LOCKED"): + wait = False + + locks.append( + self.expression(exp.Lock, update=update, expressions=expressions, wait=wait) + ) + + return locks def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): @@ -2672,7 +2755,7 @@ class Parser(metaclass=_Parser): def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: index = self._index - 1 negate = self._match(TokenType.NOT) - if self._match(TokenType.DISTINCT_FROM): + if self._match_text_seq("DISTINCT", "FROM"): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ return self.expression(klass, this=this, expression=self._parse_expression()) @@ -2684,12 +2767,12 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Is, this=this, expression=expression) return self.expression(exp.Not, this=this) if negate else this - def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: + def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In: unnest = self._parse_unnest() if unnest: this = self.expression(exp.In, this=this, unnest=unnest) elif self._match(TokenType.L_PAREN): - expressions = self._parse_csv(self._parse_select_or_expression) + expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): this = self.expression(exp.In, this=this, query=expressions[0]) @@ -2722,15 +2805,19 @@ class Parser(metaclass=_Parser): # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse # each INTERVAL expression into this canonical form so it's easy to transpile - if this and isinstance(this, exp.Literal): - if this.is_number: - this = exp.Literal.string(this.name) - - # Try to not clutter Snowflake's multi-part intervals like INTERVAL '1 day, 1 year' + if this and this.is_number: + this = exp.Literal.string(this.name) + elif this and this.is_string: parts = this.name.split() - if not unit and len(parts) <= 2: - this = exp.Literal.string(seq_get(parts, 0)) - unit = self.expression(exp.Var, this=seq_get(parts, 1)) + + if len(parts) == 2: + if unit: + # this is not actually a unit, it's something else + unit = None + self._retreat(self._index - 1) + else: + this = exp.Literal.string(parts[0]) + unit = self.expression(exp.Var, this=parts[1]) return self.expression(exp.Interval, this=this, unit=unit) @@ -2783,13 +2870,22 @@ class Parser(metaclass=_Parser): if parser: return parser(self, this, data_type) return self.expression(exp.Cast, this=this, to=data_type) - if not data_type.args.get("expressions"): + if not data_type.expressions: self._retreat(index) return self._parse_column() - return data_type + return self._parse_column_ops(data_type) return this + def _parse_type_size(self) -> t.Optional[exp.Expression]: + this = self._parse_type() + if not this: + return None + + return self.expression( + exp.DataTypeSize, this=this, expression=self._parse_var(any_token=True) + ) + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: index = self._index @@ -2814,7 +2910,7 @@ class Parser(metaclass=_Parser): elif nested: expressions = self._parse_csv(self._parse_types) else: - expressions = self._parse_csv(self._parse_conjunction) + expressions = self._parse_csv(self._parse_type_size) if not expressions or not self._match(TokenType.R_PAREN): self._retreat(index) @@ -2858,13 +2954,14 @@ class Parser(metaclass=_Parser): value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: - if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: + if self._match_text_seq("WITH", "TIME", "ZONE") or type_token == TokenType.TIMESTAMPTZ: value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) elif ( - self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ + self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE") + or type_token == TokenType.TIMESTAMPLTZ ): value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) - elif self._match(TokenType.WITHOUT_TIME_ZONE): + elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): if type_token == TokenType.TIME: value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) else: @@ -2909,7 +3006,7 @@ class Parser(metaclass=_Parser): return self._parse_column_def(this) def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not self._match(TokenType.AT_TIME_ZONE): + if not self._match_text_seq("AT", "TIME", "ZONE"): return this return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) @@ -2919,6 +3016,9 @@ class Parser(metaclass=_Parser): this = self.expression(exp.Column, this=this) elif not this: return self._parse_bracket(this) + return self._parse_column_ops(this) + + def _parse_column_ops(self, this: exp.Expression) -> exp.Expression: this = self._parse_bracket(this) while self._match_set(self.COLUMN_OPERATORS): @@ -2929,7 +3029,7 @@ class Parser(metaclass=_Parser): field = self._parse_types() if not field: self.raise_error("Expected type") - elif op: + elif op and self._curr: self._advance() value = self._prev.text field = ( @@ -2963,7 +3063,6 @@ class Parser(metaclass=_Parser): else: this = self.expression(exp.Dot, this=this, expression=field) this = self._parse_bracket(this) - return this def _parse_primary(self) -> t.Optional[exp.Expression]: @@ -2989,12 +3088,9 @@ class Parser(metaclass=_Parser): if query: expressions = [query] else: - expressions = self._parse_csv( - lambda: self._parse_alias(self._parse_conjunction(), explicit=True) - ) + expressions = self._parse_csv(self._parse_expression) - this = seq_get(expressions, 0) - self._parse_query_modifiers(this) + this = self._parse_query_modifiers(seq_get(expressions, 0)) if isinstance(this, exp.Subqueryable): this = self._parse_set_operations( @@ -3065,20 +3161,12 @@ class Parser(metaclass=_Parser): functions = self.FUNCTIONS function = functions.get(upper) - args = self._parse_csv(self._parse_lambda) - if function and not anonymous: - # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the - # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists. - if count_params(function) == 2: - params = None - if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): - params = self._parse_csv(self._parse_lambda) - - this = function(args, params) - else: - this = function(args) + alias = upper in self.FUNCTIONS_WITH_ALIASED_ARGS + args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) + if function and not anonymous: + this = function(args) self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -3113,9 +3201,6 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=token.text) - def _parse_national(self, token: Token) -> exp.Expression: - return self.expression(exp.National, this=exp.Literal.string(token.text)) - def _parse_session_parameter(self) -> exp.Expression: kind = None this = self._parse_id_var() or self._parse_primary() @@ -3126,7 +3211,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.SessionParameter, this=this, kind=kind) - def _parse_lambda(self) -> t.Optional[exp.Expression]: + def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): @@ -3149,7 +3234,7 @@ class Parser(metaclass=_Parser): exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) ) else: - this = self._parse_select_or_expression() + this = self._parse_select_or_expression(alias=alias) if isinstance(this, exp.EQ): left = this.this @@ -3161,13 +3246,15 @@ class Parser(metaclass=_Parser): def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index - try: - if self._parse_select(nested=True): - return this - except Exception: - pass - finally: - self._retreat(index) + if not self.errors: + try: + if self._parse_select(nested=True): + return this + except ParseError: + pass + finally: + self.errors.clear() + self._retreat(index) if not self._match(TokenType.L_PAREN): return this @@ -3227,13 +3314,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) def _parse_generated_as_identity(self) -> exp.Expression: - if self._match(TokenType.BY_DEFAULT): - this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=False) + if self._match_text_seq("BY", "DEFAULT"): + on_null = self._match_pair(TokenType.ON, TokenType.NULL) + this = self.expression( + exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null + ) else: self._match_text_seq("ALWAYS") this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) - self._match_text_seq("AS", "IDENTITY") + self._match(TokenType.ALIAS) + identity = self._match_text_seq("IDENTITY") + if self._match(TokenType.L_PAREN): if self._match_text_seq("START", "WITH"): this.set("start", self._parse_bitwise()) @@ -3249,6 +3341,9 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("NO", "CYCLE"): this.set("cycle", False) + if not identity: + this.set("expression", self._parse_bitwise()) + self._match_r_paren() return this @@ -3307,9 +3402,10 @@ class Parser(metaclass=_Parser): return self.CONSTRAINT_PARSERS[constraint](self) def _parse_unique(self) -> exp.Expression: - if not self._match(TokenType.L_PAREN, advance=False): - return self.expression(exp.UniqueColumnConstraint) - return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) + self._match_text_seq("KEY") + return self.expression( + exp.UniqueColumnConstraint, this=self._parse_schema(self._parse_id_var(any_token=False)) + ) def _parse_key_constraint_options(self) -> t.List[str]: options = [] @@ -3321,9 +3417,9 @@ class Parser(metaclass=_Parser): action = None on = self._advance_any() and self._prev.text - if self._match(TokenType.NO_ACTION): + if self._match_text_seq("NO", "ACTION"): action = "NO ACTION" - elif self._match(TokenType.CASCADE): + elif self._match_text_seq("CASCADE"): action = "CASCADE" elif self._match_pair(TokenType.SET, TokenType.NULL): action = "SET NULL" @@ -3348,7 +3444,7 @@ class Parser(metaclass=_Parser): return options - def _parse_references(self, match=True) -> t.Optional[exp.Expression]: + def _parse_references(self, match: bool = True) -> t.Optional[exp.Expression]: if match and not self._match(TokenType.REFERENCES): return None @@ -3372,7 +3468,7 @@ class Parser(metaclass=_Parser): kind = self._prev.text.lower() - if self._match(TokenType.NO_ACTION): + if self._match_text_seq("NO", "ACTION"): action = "NO ACTION" elif self._match(TokenType.SET): self._match_set((TokenType.NULL, TokenType.DEFAULT)) @@ -3396,11 +3492,19 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN, advance=False): return self.expression(exp.PrimaryKeyColumnConstraint, desc=desc) - expressions = self._parse_wrapped_id_vars() + expressions = self._parse_wrapped_csv(self._parse_field) options = self._parse_key_constraint_options() return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + @t.overload + def _parse_bracket(self, this: exp.Expression) -> exp.Expression: + ... + + @t.overload def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + ... + + def _parse_bracket(self, this): if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): return this @@ -3493,7 +3597,12 @@ class Parser(metaclass=_Parser): this = self._parse_conjunction() if not self._match(TokenType.ALIAS): - self.raise_error("Expected AS after CAST") + if self._match(TokenType.COMMA): + return self.expression( + exp.CastToStrType, this=this, expression=self._parse_string() + ) + else: + self.raise_error("Expected AS after CAST") to = self._parse_types() @@ -3524,7 +3633,7 @@ class Parser(metaclass=_Parser): # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY <order_by_expression_list> [ASC | DESC]). # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. - if not self._match(TokenType.WITHIN_GROUP): + if not self._match_text_seq("WITHIN", "GROUP"): self._retreat(index) this = exp.GroupConcat.from_arg_list(args) self.validate_expression(this, args) @@ -3674,6 +3783,27 @@ class Parser(metaclass=_Parser): exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier ) + # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16 + def _parse_open_json(self) -> exp.Expression: + this = self._parse_bitwise() + path = self._match(TokenType.COMMA) and self._parse_string() + + def _parse_open_json_column_def() -> exp.Expression: + this = self._parse_field(any_token=True) + kind = self._parse_types() + path = self._parse_string() + as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON) + return self.expression( + exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json + ) + + expressions = None + if self._match_pair(TokenType.R_PAREN, TokenType.WITH): + self._match_l_paren() + expressions = self._parse_csv(_parse_open_json_column_def) + + return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions) + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) @@ -3722,7 +3852,7 @@ class Parser(metaclass=_Parser): position = None collation = None - if self._match_set(self.TRIM_TYPES): + if self._match_texts(self.TRIM_TYPES): position = self._prev.text.upper() expression = self._parse_bitwise() @@ -3752,9 +3882,9 @@ class Parser(metaclass=_Parser): def _parse_respect_or_ignore_nulls( self, this: t.Optional[exp.Expression] ) -> t.Optional[exp.Expression]: - if self._match(TokenType.IGNORE_NULLS): + if self._match_text_seq("IGNORE", "NULLS"): return self.expression(exp.IgnoreNulls, this=this) - if self._match(TokenType.RESPECT_NULLS): + if self._match_text_seq("RESPECT", "NULLS"): return self.expression(exp.RespectNulls, this=this) return this @@ -3767,7 +3897,7 @@ class Parser(metaclass=_Parser): # T-SQL allows the OVER (...) syntax after WITHIN GROUP. # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 - if self._match(TokenType.WITHIN_GROUP): + if self._match_text_seq("WITHIN", "GROUP"): order = self._parse_wrapped(self._parse_order) this = self.expression(exp.WithinGroup, this=this, expression=order) @@ -3846,10 +3976,11 @@ class Parser(metaclass=_Parser): return { "value": ( - self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text - ) - or self._parse_bitwise(), - "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, + (self._match_text_seq("UNBOUNDED") and "UNBOUNDED") + or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW") + or self._parse_bitwise() + ), + "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text, } def _parse_alias( @@ -3956,7 +4087,7 @@ class Parser(metaclass=_Parser): def _parse_parameter(self) -> exp.Expression: wrapped = self._match(TokenType.L_BRACE) - this = self._parse_var() or self._parse_primary() + this = self._parse_var() or self._parse_identifier() or self._parse_primary() self._match(TokenType.R_BRACE) return self.expression(exp.Parameter, this=this, wrapped=wrapped) @@ -4011,26 +4142,33 @@ class Parser(metaclass=_Parser): return this - def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]: - return self._parse_wrapped_csv(self._parse_id_var) + def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[t.Optional[exp.Expression]]: + return self._parse_wrapped_csv(self._parse_id_var, optional=optional) def _parse_wrapped_csv( - self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False ) -> t.List[t.Optional[exp.Expression]]: - return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) + return self._parse_wrapped( + lambda: self._parse_csv(parse_method, sep=sep), optional=optional + ) - def _parse_wrapped(self, parse_method: t.Callable) -> t.Any: - self._match_l_paren() + def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any: + wrapped = self._match(TokenType.L_PAREN) + if not wrapped and not optional: + self.raise_error("Expecting (") parse_result = parse_method() - self._match_r_paren() + if wrapped: + self._match_r_paren() return parse_result - def _parse_select_or_expression(self) -> t.Optional[exp.Expression]: - return self._parse_select() or self._parse_set_operations(self._parse_expression()) + def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]: + return self._parse_select() or self._parse_set_operations( + self._parse_expression() if alias else self._parse_conjunction() + ) def _parse_ddl_select(self) -> t.Optional[exp.Expression]: - return self._parse_set_operations( - self._parse_select(nested=True, parse_subquery_alias=False) + return self._parse_query_modifiers( + self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) ) def _parse_transaction(self) -> exp.Expression: @@ -4391,11 +4529,11 @@ class Parser(metaclass=_Parser): return None - def _match_l_paren(self, expression=None): + def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None: if not self._match(TokenType.L_PAREN, expression=expression): self.raise_error("Expecting (") - def _match_r_paren(self, expression=None): + def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None: if not self._match(TokenType.R_PAREN, expression=expression): self.raise_error("Expecting )") @@ -4420,6 +4558,16 @@ class Parser(metaclass=_Parser): return True + @t.overload + def _replace_columns_with_dots(self, this: exp.Expression) -> exp.Expression: + ... + + @t.overload + def _replace_columns_with_dots( + self, this: t.Optional[exp.Expression] + ) -> t.Optional[exp.Expression]: + ... + def _replace_columns_with_dots(self, this): if isinstance(this, exp.Dot): exp.replace_children(this, self._replace_columns_with_dots) @@ -4433,9 +4581,15 @@ class Parser(metaclass=_Parser): ) elif isinstance(this, exp.Identifier): this = self.expression(exp.Var, this=this.name) + return this - def _replace_lambda(self, node, lambda_variables): + def _replace_lambda( + self, node: t.Optional[exp.Expression], lambda_variables: t.Set[str] + ) -> t.Optional[exp.Expression]: + if not node: + return node + for column in node.find_all(exp.Column): if column.parts[0].name in lambda_variables: dot_or_id = column.to_dot() if column.table else column.this diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 5fd96ef..eccad35 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -1,11 +1,10 @@ from __future__ import annotations -import itertools import math import typing as t from sqlglot import alias, exp -from sqlglot.errors import UnsupportedError +from sqlglot.helper import name_sequence from sqlglot.optimizer.eliminate_joins import join_condition @@ -105,13 +104,7 @@ class Step: from_ = expression.args.get("from") if isinstance(expression, exp.Select) and from_: - from_ = from_.expressions - if len(from_) > 1: - raise UnsupportedError( - "Multi-from statements are unsupported. Run it through the optimizer" - ) - - step = Scan.from_expression(from_[0], ctes) + step = Scan.from_expression(from_.this, ctes) elif isinstance(expression, exp.Union): step = SetOperation.from_expression(expression, ctes) else: @@ -128,7 +121,7 @@ class Step: projections = [] # final selects in this chain of steps representing a select operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) aggregations = [] - sequence = itertools.count() + next_operand_name = name_sequence("_a_") def extract_agg_operands(expression): for agg in expression.find_all(exp.AggFunc): @@ -136,7 +129,7 @@ class Step: if isinstance(operand, exp.Column): continue if operand not in operands: - operands[operand] = f"_a_{next(sequence)}" + operands[operand] = next_operand_name() operand.replace(exp.column(operands[operand], quoted=True)) for e in expression.expressions: @@ -310,7 +303,7 @@ class Join(Step): for join in joins: source_key, join_key, condition = join_condition(join) step.joins[join.this.alias_or_name] = { - "side": join.side, + "side": join.side, # type: ignore "join_key": join_key, "source_key": source_key, "condition": condition, diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 5d60eb9..f1c4a09 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -5,6 +5,8 @@ import typing as t import sqlglot from sqlglot import expressions as exp +from sqlglot._typing import T +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE from sqlglot.errors import ParseError, SchemaError from sqlglot.helper import dict_depth from sqlglot.trie import in_trie, new_trie @@ -17,62 +19,83 @@ if t.TYPE_CHECKING: TABLE_ARGS = ("this", "db", "catalog") -T = t.TypeVar("T") - class Schema(abc.ABC): """Abstract base class for database schemas""" + dialect: DialectType + @abc.abstractmethod def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, ) -> None: """ Register or update a table. Some implementing classes may require column information to also be provided. Args: - table: table expression instance or string representing the table. + table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. """ @abc.abstractmethod - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + ) -> t.List[str]: """ Get the column names for a table. Args: table: the `Table` expression instance. only_visible: whether to include invisible columns. + dialect: the SQL dialect that will be used to parse `table` if it's a string. Returns: The list of column names. """ @abc.abstractmethod - def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType: + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column, + dialect: DialectType = None, + ) -> exp.DataType: """ - Get the :class:`sqlglot.exp.DataType` type of a column in the schema. + Get the `sqlglot.exp.DataType` type of a column in the schema. Args: table: the source table. column: the target column. + dialect: the SQL dialect that will be used to parse `table` if it's a string. Returns: The resulting column type. """ @property + @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: """ Table arguments this schema support, e.g. `("this", "db", "catalog")` """ - raise NotImplementedError + + @property + def empty(self) -> bool: + """Returns whether or not the schema is empty.""" + return True class AbstractMappingSchema(t.Generic[T]): def __init__( self, - mapping: dict | None = None, + mapping: t.Optional[t.Dict] = None, ) -> None: self.mapping = mapping or {} self.mapping_trie = new_trie( @@ -80,6 +103,10 @@ class AbstractMappingSchema(t.Generic[T]): ) self._supported_table_args: t.Tuple[str, ...] = tuple() + @property + def empty(self) -> bool: + return not self.mapping + def _depth(self) -> int: return dict_depth(self.mapping) @@ -110,8 +137,10 @@ class AbstractMappingSchema(t.Generic[T]): if value == 0: return None - elif value == 1: + + if value == 1: possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1) + if len(possibilities) == 1: parts.extend(possibilities[0]) else: @@ -119,12 +148,13 @@ class AbstractMappingSchema(t.Generic[T]): if raise_on_missing: raise SchemaError(f"Ambiguous mapping for {table}: {message}.") return None - return self._nested_get(parts, raise_on_missing=raise_on_missing) - def _nested_get( + return self.nested_get(parts, raise_on_missing=raise_on_missing) + + def nested_get( self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True ) -> t.Optional[t.Any]: - return _nested_get( + return nested_get( d or self.mapping, *zip(self.supported_table_args, reversed(parts)), raise_on_missing=raise_on_missing, @@ -136,17 +166,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Schema based on a nested mapping. Args: - schema (dict): Mapping in one of the following forms: + schema: Mapping in one of the following forms: 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} 4. None - Tables will be added later - visible (dict): Optional mapping of which columns in the schema are visible. If not provided, all columns + visible: Optional mapping of which columns in the schema are visible. If not provided, all columns are assumed to be visible. The nesting should mirror that of the schema: 1. {table: set(*cols)}} 2. {db: {table: set(*cols)}}} 3. {catalog: {db: {table: set(*cols)}}}} - dialect (str): The dialect to be used for custom type mappings. + dialect: The dialect to be used for custom type mappings & parsing string arguments. + normalize: Whether to normalize identifier names according to the given dialect or not. """ def __init__( @@ -154,10 +185,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): schema: t.Optional[t.Dict] = None, visible: t.Optional[t.Dict] = None, dialect: DialectType = None, + normalize: bool = True, ) -> None: self.dialect = dialect self.visible = visible or {} + self.normalize = normalize self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + super().__init__(self._normalize(schema or {})) @classmethod @@ -179,7 +213,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): ) def add_table( - self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None + self, + table: exp.Table | str, + column_mapping: t.Optional[ColumnMapping] = None, + dialect: DialectType = None, ) -> None: """ Register or update a table. Updates are only performed if a new column mapping is provided. @@ -187,10 +224,13 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): Args: table: the `Table` expression instance or string representing the table. column_mapping: a column mapping that describes the structure of the table. + dialect: the SQL dialect that will be used to parse `table` if it's a string. """ - normalized_table = self._normalize_table(self._ensure_table(table)) + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) normalized_column_mapping = { - self._normalize_name(key): value + self._normalize_name(key, dialect=dialect): value for key, value in ensure_column_mapping(column_mapping).items() } @@ -200,38 +240,51 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): parts = self.table_parts(normalized_table) - _nested_set( - self.mapping, - tuple(reversed(parts)), - normalized_column_mapping, - ) + nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) new_trie([parts], self.mapping_trie) - def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]: - table_ = self._normalize_table(self._ensure_table(table)) - schema = self.find(table_) + def column_names( + self, + table: exp.Table | str, + only_visible: bool = False, + dialect: DialectType = None, + ) -> t.List[str]: + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) + schema = self.find(normalized_table) if schema is None: return [] if not only_visible or not self.visible: return list(schema) - visible = self._nested_get(self.table_parts(table_), self.visible) - return [col for col in schema if col in visible] # type: ignore + visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] + return [col for col in schema if col in visible] - def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType: - column_name = self._normalize_name(column if isinstance(column, str) else column.this) - table_ = self._normalize_table(self._ensure_table(table)) + def get_column_type( + self, + table: exp.Table | str, + column: exp.Column, + dialect: DialectType = None, + ) -> exp.DataType: + normalized_table = self._normalize_table( + self._ensure_table(table, dialect=dialect), dialect=dialect + ) + normalized_column_name = self._normalize_name( + column if isinstance(column, str) else column.this, dialect=dialect + ) - table_schema = self.find(table_, raise_on_missing=False) + table_schema = self.find(normalized_table, raise_on_missing=False) if table_schema: - column_type = table_schema.get(column_name) + column_type = table_schema.get(normalized_column_name) if isinstance(column_type, exp.DataType): return column_type elif isinstance(column_type, str): - return self._to_data_type(column_type.upper()) + return self._to_data_type(column_type.upper(), dialect=dialect) + raise SchemaError(f"Unknown column type '{column_type}'") return exp.DataType.build("unknown") @@ -250,81 +303,88 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): normalized_mapping: t.Dict = {} for keys in flattened_schema: - columns = _nested_get(schema, *zip(keys, keys)) + columns = nested_get(schema, *zip(keys, keys)) assert columns is not None - normalized_keys = [self._normalize_name(key) for key in keys] + normalized_keys = [self._normalize_name(key, dialect=self.dialect) for key in keys] for column_name, column_type in columns.items(): - _nested_set( + nested_set( normalized_mapping, - normalized_keys + [self._normalize_name(column_name)], + normalized_keys + [self._normalize_name(column_name, dialect=self.dialect)], column_type, ) return normalized_mapping - def _normalize_table(self, table: exp.Table) -> exp.Table: + def _normalize_table(self, table: exp.Table, dialect: DialectType = None) -> exp.Table: normalized_table = table.copy() + for arg in TABLE_ARGS: value = normalized_table.args.get(arg) if isinstance(value, (str, exp.Identifier)): - normalized_table.set(arg, self._normalize_name(value)) + normalized_table.set( + arg, exp.to_identifier(self._normalize_name(value, dialect=dialect)) + ) return normalized_table - def _normalize_name(self, name: str | exp.Identifier) -> str: + def _normalize_name(self, name: str | exp.Identifier, dialect: DialectType = None) -> str: + dialect = dialect or self.dialect + try: - identifier = sqlglot.maybe_parse(name, dialect=self.dialect, into=exp.Identifier) + identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) except ParseError: return name if isinstance(name, str) else name.name - return identifier.name if identifier.quoted else identifier.name.lower() + name = identifier.name + + if not self.normalize or identifier.quoted: + return name + + return name.upper() if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE else name.lower() def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 - def _ensure_table(self, table: exp.Table | str) -> exp.Table: - if isinstance(table, exp.Table): - return table - - table_ = sqlglot.parse_one(table, read=self.dialect, into=exp.Table) - if not table_: - raise SchemaError(f"Not a valid table '{table}'") + def _ensure_table(self, table: exp.Table | str, dialect: DialectType = None) -> exp.Table: + return exp.maybe_parse(table, into=exp.Table, dialect=dialect or self.dialect) - return table_ - - def _to_data_type(self, schema_type: str) -> exp.DataType: + def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: """ - Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object. + Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. Args: schema_type: the type we want to convert. + dialect: the SQL dialect that will be used to parse `schema_type`, if needed. Returns: The resulting expression type. """ if schema_type not in self._type_mapping_cache: + dialect = dialect or self.dialect + try: - expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect) - if expression is None: - raise ValueError(f"Could not parse {schema_type}") - self._type_mapping_cache[schema_type] = expression # type: ignore + expression = exp.DataType.build(schema_type, dialect=dialect) + self._type_mapping_cache[schema_type] = expression except AttributeError: - raise SchemaError(f"Failed to convert type {schema_type}") + in_dialect = f" in dialect {dialect}" if dialect else "" + raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") return self._type_mapping_cache[schema_type] -def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema: +def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: if isinstance(schema, Schema): return schema - return MappingSchema(schema, dialect=dialect) + return MappingSchema(schema, **kwargs) def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: - if isinstance(mapping, dict): + if mapping is None: + return {} + elif isinstance(mapping, dict): return mapping elif isinstance(mapping, str): col_name_type_strs = [x.strip() for x in mapping.split(",")] @@ -334,11 +394,10 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: } # Check if mapping looks like a DataFrame StructType elif hasattr(mapping, "simpleString"): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore + return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} elif isinstance(mapping, list): return {x.strip(): None for x in mapping} - elif mapping is None: - return {} + raise ValueError(f"Invalid mapping provided: {type(mapping)}") @@ -353,10 +412,11 @@ def flatten_schema( tables.extend(flatten_schema(v, depth - 1, keys + [k])) elif depth == 1: tables.append(keys + [k]) + return tables -def _nested_get( +def nested_get( d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True ) -> t.Optional[t.Any]: """ @@ -378,18 +438,19 @@ def _nested_get( name = "table" if name == "this" else name raise ValueError(f"Unknown {name}: {key}") return None + return d -def _nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: +def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: """ In-place set a value for a nested dictionary Example: - >>> _nested_set({}, ["top_key", "second_key"], "value") + >>> nested_set({}, ["top_key", "second_key"], "value") {'top_key': {'second_key': 'value'}} - >>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") + >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} Args: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 5e50b7c..ad329d2 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -51,7 +51,6 @@ class TokenType(AutoName): DOLLAR = auto() PARAMETER = auto() SESSION_PARAMETER = auto() - NATIONAL = auto() DAMP = auto() BLOCK_START = auto() @@ -72,6 +71,8 @@ class TokenType(AutoName): BIT_STRING = auto() HEX_STRING = auto() BYTE_STRING = auto() + NATIONAL_STRING = auto() + RAW_STRING = auto() # types BIT = auto() @@ -110,6 +111,7 @@ class TokenType(AutoName): TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() DATETIME = auto() + DATETIME64 = auto() DATE = auto() UUID = auto() GEOGRAPHY = auto() @@ -142,30 +144,22 @@ class TokenType(AutoName): ARRAY = auto() ASC = auto() ASOF = auto() - AT_TIME_ZONE = auto() AUTO_INCREMENT = auto() BEGIN = auto() BETWEEN = auto() - BOTH = auto() - BUCKET = auto() - BY_DEFAULT = auto() CACHE = auto() - CASCADE = auto() CASE = auto() CHARACTER_SET = auto() - CLUSTER_BY = auto() COLLATE = auto() COMMAND = auto() COMMENT = auto() COMMIT = auto() - COMPOUND = auto() CONSTRAINT = auto() CREATE = auto() CROSS = auto() CUBE = auto() CURRENT_DATE = auto() CURRENT_DATETIME = auto() - CURRENT_ROW = auto() CURRENT_TIME = auto() CURRENT_TIMESTAMP = auto() CURRENT_USER = auto() @@ -174,8 +168,6 @@ class TokenType(AutoName): DESC = auto() DESCRIBE = auto() DISTINCT = auto() - DISTINCT_FROM = auto() - DISTRIBUTE_BY = auto() DIV = auto() DROP = auto() ELSE = auto() @@ -189,7 +181,6 @@ class TokenType(AutoName): FILTER = auto() FINAL = auto() FIRST = auto() - FOLLOWING = auto() FOR = auto() FOREIGN_KEY = auto() FORMAT = auto() @@ -203,7 +194,6 @@ class TokenType(AutoName): HAVING = auto() HINT = auto() IF = auto() - IGNORE_NULLS = auto() ILIKE = auto() ILIKE_ANY = auto() IN = auto() @@ -222,36 +212,27 @@ class TokenType(AutoName): KEEP = auto() LANGUAGE = auto() LATERAL = auto() - LAZY = auto() - LEADING = auto() LEFT = auto() LIKE = auto() LIKE_ANY = auto() LIMIT = auto() - LOAD_DATA = auto() - LOCAL = auto() + LOAD = auto() + LOCK = auto() MAP = auto() MATCH_RECOGNIZE = auto() - MATERIALIZED = auto() MERGE = auto() MOD = auto() NATURAL = auto() NEXT = auto() NEXT_VALUE_FOR = auto() - NO_ACTION = auto() NOTNULL = auto() NULL = auto() - NULLS_FIRST = auto() - NULLS_LAST = auto() OFFSET = auto() ON = auto() - ONLY = auto() - OPTIONS = auto() ORDER_BY = auto() ORDERED = auto() ORDINALITY = auto() OUTER = auto() - OUT_OF = auto() OVER = auto() OVERLAPS = auto() OVERWRITE = auto() @@ -261,7 +242,6 @@ class TokenType(AutoName): PIVOT = auto() PLACEHOLDER = auto() PRAGMA = auto() - PRECEDING = auto() PRIMARY_KEY = auto() PROCEDURE = auto() PROPERTIES = auto() @@ -271,7 +251,6 @@ class TokenType(AutoName): RANGE = auto() RECURSIVE = auto() REPLACE = auto() - RESPECT_NULLS = auto() RETURNING = auto() REFERENCES = auto() RIGHT = auto() @@ -280,28 +259,23 @@ class TokenType(AutoName): ROLLUP = auto() ROW = auto() ROWS = auto() - SEED = auto() SELECT = auto() SEMI = auto() SEPARATOR = auto() SERDE_PROPERTIES = auto() SET = auto() + SETTINGS = auto() SHOW = auto() SIMILAR_TO = auto() SOME = auto() - SORTKEY = auto() - SORT_BY = auto() STRUCT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() TOP = auto() THEN = auto() - TRAILING = auto() TRUE = auto() - UNBOUNDED = auto() UNCACHE = auto() UNION = auto() - UNLOGGED = auto() UNNEST = auto() UNPIVOT = auto() UPDATE = auto() @@ -314,15 +288,11 @@ class TokenType(AutoName): WHERE = auto() WINDOW = auto() WITH = auto() - WITH_TIME_ZONE = auto() - WITH_LOCAL_TIME_ZONE = auto() - WITHIN_GROUP = auto() - WITHOUT_TIME_ZONE = auto() UNIQUE = auto() class Token: - __slots__ = ("token_type", "text", "line", "col", "end", "comments") + __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments") @classmethod def number(cls, number: int) -> Token: @@ -350,22 +320,28 @@ class Token: text: str, line: int = 1, col: int = 1, + start: int = 0, end: int = 0, comments: t.List[str] = [], ) -> None: + """Token initializer. + + Args: + token_type: The TokenType Enum. + text: The text of the token. + line: The line that the token ends on. + col: The column that the token ends on. + start: The start index of the token. + end: The ending index of the token. + """ self.token_type = token_type self.text = text self.line = line - size = len(text) self.col = col - self.end = end if end else size + self.start = start + self.end = end self.comments = comments - @property - def start(self) -> int: - """Returns the start of the token.""" - return self.end - len(self.text) - def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) return f"<Token {attributes}>" @@ -375,15 +351,31 @@ class _Tokenizer(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) - klass._QUOTES = { - f"{prefix}{s}": e - for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items() - for prefix in (("",) if s[0].isalpha() else ("", "n", "N")) + def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: + return dict( + (item, item) if isinstance(item, str) else (item[0], item[1]) for item in arr + ) + + def _quotes_to_format( + token_type: TokenType, arr: t.List[str | t.Tuple[str, str]] + ) -> t.Dict[str, t.Tuple[str, TokenType]]: + return {k: (v, token_type) for k, v in _convert_quotes(arr).items()} + + klass._QUOTES = _convert_quotes(klass.QUOTES) + klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS) + + klass._FORMAT_STRINGS = { + **{ + p + s: (e, TokenType.NATIONAL_STRING) + for s, e in klass._QUOTES.items() + for p in ("n", "N") + }, + **_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS), + **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS), + **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS), + **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS), } - klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) - klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) - klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) - klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS) + klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES) klass._COMMENTS = dict( @@ -393,23 +385,17 @@ class _Tokenizer(type): klass.KEYWORD_TRIE = new_trie( key.upper() - for key in { - **klass.KEYWORDS, - **{comment: TokenType.COMMENT for comment in klass._COMMENTS}, - **{quote: TokenType.QUOTE for quote in klass._QUOTES}, - **{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS}, - **{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS}, - **{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS}, - } + for key in ( + *klass.KEYWORDS, + *klass._COMMENTS, + *klass._QUOTES, + *klass._FORMAT_STRINGS, + ) if " " in key or any(single in key for single in klass.SINGLE_TOKENS) ) return klass - @staticmethod - def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: - return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list) - class Tokenizer(metaclass=_Tokenizer): SINGLE_TOKENS = { @@ -450,6 +436,7 @@ class Tokenizer(metaclass=_Tokenizer): BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] + RAW_STRINGS: t.List[str | t.Tuple[str, str]] = [] IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] IDENTIFIER_ESCAPES = ['"'] QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] @@ -457,9 +444,7 @@ class Tokenizer(metaclass=_Tokenizer): VAR_SINGLE_TOKENS: t.Set[str] = set() _COMMENTS: t.Dict[str, str] = {} - _BIT_STRINGS: t.Dict[str, str] = {} - _BYTE_STRINGS: t.Dict[str, str] = {} - _HEX_STRINGS: t.Dict[str, str] = {} + _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} _IDENTIFIERS: t.Dict[str, str] = {} _IDENTIFIER_ESCAPES: t.Set[str] = set() _QUOTES: t.Dict[str, str] = {} @@ -495,30 +480,22 @@ class Tokenizer(metaclass=_Tokenizer): "ANY": TokenType.ANY, "ASC": TokenType.ASC, "AS": TokenType.ALIAS, - "AT TIME ZONE": TokenType.AT_TIME_ZONE, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, "BEGIN": TokenType.BEGIN, "BETWEEN": TokenType.BETWEEN, - "BOTH": TokenType.BOTH, - "BUCKET": TokenType.BUCKET, - "BY DEFAULT": TokenType.BY_DEFAULT, "CACHE": TokenType.CACHE, "UNCACHE": TokenType.UNCACHE, "CASE": TokenType.CASE, - "CASCADE": TokenType.CASCADE, "CHARACTER SET": TokenType.CHARACTER_SET, - "CLUSTER BY": TokenType.CLUSTER_BY, "COLLATE": TokenType.COLLATE, "COLUMN": TokenType.COLUMN, "COMMIT": TokenType.COMMIT, - "COMPOUND": TokenType.COMPOUND, "CONSTRAINT": TokenType.CONSTRAINT, "CREATE": TokenType.CREATE, "CROSS": TokenType.CROSS, "CUBE": TokenType.CUBE, "CURRENT_DATE": TokenType.CURRENT_DATE, - "CURRENT ROW": TokenType.CURRENT_ROW, "CURRENT_TIME": TokenType.CURRENT_TIME, "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, "CURRENT_USER": TokenType.CURRENT_USER, @@ -528,8 +505,6 @@ class Tokenizer(metaclass=_Tokenizer): "DESC": TokenType.DESC, "DESCRIBE": TokenType.DESCRIBE, "DISTINCT": TokenType.DISTINCT, - "DISTINCT FROM": TokenType.DISTINCT_FROM, - "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, "DIV": TokenType.DIV, "DROP": TokenType.DROP, "ELSE": TokenType.ELSE, @@ -544,18 +519,18 @@ class Tokenizer(metaclass=_Tokenizer): "FIRST": TokenType.FIRST, "FULL": TokenType.FULL, "FUNCTION": TokenType.FUNCTION, - "FOLLOWING": TokenType.FOLLOWING, "FOR": TokenType.FOR, "FOREIGN KEY": TokenType.FOREIGN_KEY, "FORMAT": TokenType.FORMAT, "FROM": TokenType.FROM, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "GEOMETRY": TokenType.GEOMETRY, "GLOB": TokenType.GLOB, "GROUP BY": TokenType.GROUP_BY, "GROUPING SETS": TokenType.GROUPING_SETS, "HAVING": TokenType.HAVING, "IF": TokenType.IF, "ILIKE": TokenType.ILIKE, - "IGNORE NULLS": TokenType.IGNORE_NULLS, "IN": TokenType.IN, "INDEX": TokenType.INDEX, "INET": TokenType.INET, @@ -569,34 +544,25 @@ class Tokenizer(metaclass=_Tokenizer): "JOIN": TokenType.JOIN, "KEEP": TokenType.KEEP, "LATERAL": TokenType.LATERAL, - "LAZY": TokenType.LAZY, - "LEADING": TokenType.LEADING, "LEFT": TokenType.LEFT, "LIKE": TokenType.LIKE, "LIMIT": TokenType.LIMIT, - "LOAD DATA": TokenType.LOAD_DATA, - "LOCAL": TokenType.LOCAL, - "MATERIALIZED": TokenType.MATERIALIZED, + "LOAD": TokenType.LOAD, + "LOCK": TokenType.LOCK, "MERGE": TokenType.MERGE, "NATURAL": TokenType.NATURAL, "NEXT": TokenType.NEXT, "NEXT VALUE FOR": TokenType.NEXT_VALUE_FOR, - "NO ACTION": TokenType.NO_ACTION, "NOT": TokenType.NOT, "NOTNULL": TokenType.NOTNULL, "NULL": TokenType.NULL, - "NULLS FIRST": TokenType.NULLS_FIRST, - "NULLS LAST": TokenType.NULLS_LAST, "OBJECT": TokenType.OBJECT, "OFFSET": TokenType.OFFSET, "ON": TokenType.ON, - "ONLY": TokenType.ONLY, - "OPTIONS": TokenType.OPTIONS, "OR": TokenType.OR, "ORDER BY": TokenType.ORDER_BY, "ORDINALITY": TokenType.ORDINALITY, "OUTER": TokenType.OUTER, - "OUT OF": TokenType.OUT_OF, "OVER": TokenType.OVER, "OVERLAPS": TokenType.OVERLAPS, "OVERWRITE": TokenType.OVERWRITE, @@ -607,7 +573,6 @@ class Tokenizer(metaclass=_Tokenizer): "PERCENT": TokenType.PERCENT, "PIVOT": TokenType.PIVOT, "PRAGMA": TokenType.PRAGMA, - "PRECEDING": TokenType.PRECEDING, "PRIMARY KEY": TokenType.PRIMARY_KEY, "PROCEDURE": TokenType.PROCEDURE, "QUALIFY": TokenType.QUALIFY, @@ -615,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer): "RECURSIVE": TokenType.RECURSIVE, "REGEXP": TokenType.RLIKE, "REPLACE": TokenType.REPLACE, - "RESPECT NULLS": TokenType.RESPECT_NULLS, "REFERENCES": TokenType.REFERENCES, "RIGHT": TokenType.RIGHT, "RLIKE": TokenType.RLIKE, @@ -624,25 +588,20 @@ class Tokenizer(metaclass=_Tokenizer): "ROW": TokenType.ROW, "ROWS": TokenType.ROWS, "SCHEMA": TokenType.SCHEMA, - "SEED": TokenType.SEED, "SELECT": TokenType.SELECT, "SEMI": TokenType.SEMI, "SET": TokenType.SET, + "SETTINGS": TokenType.SETTINGS, "SHOW": TokenType.SHOW, "SIMILAR TO": TokenType.SIMILAR_TO, "SOME": TokenType.SOME, - "SORTKEY": TokenType.SORTKEY, - "SORT BY": TokenType.SORT_BY, "TABLE": TokenType.TABLE, "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, - "TRAILING": TokenType.TRAILING, - "UNBOUNDED": TokenType.UNBOUNDED, "UNION": TokenType.UNION, - "UNLOGGED": TokenType.UNLOGGED, "UNNEST": TokenType.UNNEST, "UNPIVOT": TokenType.UNPIVOT, "UPDATE": TokenType.UPDATE, @@ -656,10 +615,6 @@ class Tokenizer(metaclass=_Tokenizer): "WHERE": TokenType.WHERE, "WINDOW": TokenType.WINDOW, "WITH": TokenType.WITH, - "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, - "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, - "WITHIN GROUP": TokenType.WITHIN_GROUP, - "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, "APPLY": TokenType.APPLY, "ARRAY": TokenType.ARRAY, "BIT": TokenType.BIT, @@ -718,15 +673,6 @@ class Tokenizer(metaclass=_Tokenizer): "STRUCT": TokenType.STRUCT, "VARIANT": TokenType.VARIANT, "ALTER": TokenType.ALTER, - "ALTER AGGREGATE": TokenType.COMMAND, - "ALTER DEFAULT": TokenType.COMMAND, - "ALTER DOMAIN": TokenType.COMMAND, - "ALTER ROLE": TokenType.COMMAND, - "ALTER RULE": TokenType.COMMAND, - "ALTER SEQUENCE": TokenType.COMMAND, - "ALTER TYPE": TokenType.COMMAND, - "ALTER USER": TokenType.COMMAND, - "ALTER VIEW": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND, "CALL": TokenType.COMMAND, "COMMENT": TokenType.COMMENT, @@ -790,7 +736,7 @@ class Tokenizer(metaclass=_Tokenizer): self._start = 0 self._current = 0 self._line = 1 - self._col = 1 + self._col = 0 self._comments: t.List[str] = [] self._char = "" @@ -803,13 +749,12 @@ class Tokenizer(metaclass=_Tokenizer): self.reset() self.sql = sql self.size = len(sql) + try: self._scan() except Exception as e: - start = self._current - 50 - end = self._current + 50 - start = start if start > 0 else 0 - end = end if end < self.size else self.size - 1 + start = max(self._current - 50, 0) + end = min(self._current + 50, self.size - 1) context = self.sql[start:end] raise ValueError(f"Error tokenizing '{context}'") from e @@ -834,17 +779,17 @@ class Tokenizer(metaclass=_Tokenizer): if until and until(): break - if self.tokens: + if self.tokens and self._comments: self.tokens[-1].comments.extend(self._comments) def _chars(self, size: int) -> str: if size == 1: return self._char + start = self._current - 1 end = start + size - if end <= self.size: - return self.sql[start:end] - return "" + + return self.sql[start:end] if end <= self.size else "" def _advance(self, i: int = 1, alnum: bool = False) -> None: if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: @@ -859,6 +804,7 @@ class Tokenizer(metaclass=_Tokenizer): self._peek = "" if self._end else self.sql[self._current] if alnum and self._char.isalnum(): + # Here we use local variables instead of attributes for better performance _col = self._col _current = self._current _end = self._end @@ -885,11 +831,12 @@ class Tokenizer(metaclass=_Tokenizer): self.tokens.append( Token( token_type, - self._text if text is None else text, - self._line, - self._col, - self._current, - self._comments, + text=self._text if text is None else text, + line=self._line, + col=self._col, + start=self._start, + end=self._current - 1, + comments=self._comments, ) ) self._comments = [] @@ -929,6 +876,7 @@ class Tokenizer(metaclass=_Tokenizer): break if result == 2: word = chars + size += 1 end = self._current - 1 + size @@ -946,6 +894,7 @@ class Tokenizer(metaclass=_Tokenizer): else: skip = True else: + char = "" chars = " " word = None if not single_token and chars[-1] not in self.WHITE_SPACE else word @@ -959,8 +908,6 @@ class Tokenizer(metaclass=_Tokenizer): if self._scan_string(word): return - if self._scan_formatted_string(word): - return if self._scan_comment(word): return @@ -1004,9 +951,9 @@ class Tokenizer(metaclass=_Tokenizer): if self._char == "0": peek = self._peek.upper() if peek == "B": - return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER) + return self._scan_bits() if self.BIT_STRINGS else self._add(TokenType.NUMBER) elif peek == "X": - return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER) + return self._scan_hex() if self.HEX_STRINGS else self._add(TokenType.NUMBER) decimal = False scientific = 0 @@ -1075,37 +1022,24 @@ class Tokenizer(metaclass=_Tokenizer): return self._text - def _scan_string(self, quote: str) -> bool: - quote_end = self._QUOTES.get(quote) - if quote_end is None: - return False + def _scan_string(self, start: str) -> bool: + base = None + token_type = TokenType.STRING - self._advance(len(quote)) - text = self._extract_string(quote_end) - text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text - self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) - return True + if start in self._QUOTES: + end = self._QUOTES[start] + elif start in self._FORMAT_STRINGS: + end, token_type = self._FORMAT_STRINGS[start] - # X'1234', b'0110', E'\\\\\' etc. - def _scan_formatted_string(self, string_start: str) -> bool: - if string_start in self._HEX_STRINGS: - delimiters = self._HEX_STRINGS - token_type = TokenType.HEX_STRING - base = 16 - elif string_start in self._BIT_STRINGS: - delimiters = self._BIT_STRINGS - token_type = TokenType.BIT_STRING - base = 2 - elif string_start in self._BYTE_STRINGS: - delimiters = self._BYTE_STRINGS - token_type = TokenType.BYTE_STRING - base = None + if token_type == TokenType.HEX_STRING: + base = 16 + elif token_type == TokenType.BIT_STRING: + base = 2 else: return False - self._advance(len(string_start)) - string_end = delimiters[string_start] - text = self._extract_string(string_end) + self._advance(len(start)) + text = self._extract_string(end) if base: try: @@ -1114,6 +1048,8 @@ class Tokenizer(metaclass=_Tokenizer): raise RuntimeError( f"Numeric string contains invalid characters from {self._line}:{self._start}" ) + else: + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text self._add(token_type, text) return True diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 3643cd7..a1ec1bd 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import expressions as exp -from sqlglot.helper import find_new_name +from sqlglot.helper import find_new_name, name_sequence if t.TYPE_CHECKING: from sqlglot.generator import Generator @@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: distinct_cols = expression.args["distinct"].pop().args["on"].expressions outer_selects = expression.selects row_number = find_new_name(expression.named_selects, "_row_number") - window = exp.Window( - this=exp.RowNumber(), - partition_by=distinct_cols, - ) + window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) order = expression.args.get("order") + if order: window.set("order", order.pop().copy()) + window = exp.alias_(window, row_number) expression.select(window, copy=False) + return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') + return expression @@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: for select in expression.selects: if not select.alias_or_name: alias = find_new_name(taken, "_c") - select.replace(exp.alias_(select.copy(), alias)) + select.replace(exp.alias_(select, alias)) taken.add(alias) outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) @@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: for expr in qualify_filters.find_all((exp.Window, exp.Column)): if isinstance(expr, exp.Window): alias = find_new_name(expression.named_selects, "_w") - expression.select(exp.alias_(expr.copy(), alias), copy=False) + expression.select(exp.alias_(expr, alias), copy=False) column = exp.column(alias) + if isinstance(expr.parent, exp.Qualify): qualify_filters = column else: @@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr """ for node in expression.find_all(exp.DataType): node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) + return expression @@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore ), ) + return expression @@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression: from sqlglot.optimizer.scope import build_scope taken_select_names = set(expression.named_selects) - taken_source_names = set(build_scope(expression).selected_sources) + scope = build_scope(expression) + if not scope: + return expression + taken_source_names = set(scope.selected_sources) for select in expression.selects: to_replace = select @@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: else node, copy=False, ) + return expression @@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre return expression -def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Pivot): - expression.args["field"].transform( - lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node, - copy=False, - ) +def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.With) and expression.recursive: + next_name = name_sequence("_c_") + + for cte in expression.expressions: + if not cte.args["alias"].columns: + query = cte.this + if isinstance(query, exp.Union): + query = query.this + + cte.args["alias"].set( + "columns", + [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], + ) return expression |