From 66af5c6fc22f6f11e9ea807b274e011a6f64efb7 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Mar 2023 11:22:09 +0100 Subject: Merging upstream version 11.4.1. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/__main__.py | 19 ++- sqlglot/dataframe/sql/dataframe.py | 2 +- sqlglot/dialects/bigquery.py | 2 - sqlglot/dialects/clickhouse.py | 40 ++++- sqlglot/dialects/databricks.py | 1 + sqlglot/dialects/dialect.py | 34 +++- sqlglot/dialects/drill.py | 1 + sqlglot/dialects/duckdb.py | 7 + sqlglot/dialects/hive.py | 43 +++-- sqlglot/dialects/mysql.py | 72 +-------- sqlglot/dialects/postgres.py | 11 +- sqlglot/dialects/presto.py | 28 +++- sqlglot/dialects/redshift.py | 12 +- sqlglot/dialects/snowflake.py | 10 +- sqlglot/dialects/spark.py | 1 + sqlglot/dialects/sqlite.py | 58 ++++--- sqlglot/dialects/starrocks.py | 12 ++ sqlglot/dialects/tsql.py | 6 +- sqlglot/expressions.py | 251 +++++++++++++++++++++++++----- sqlglot/generator.py | 31 +++- sqlglot/helper.py | 17 ++ sqlglot/optimizer/canonicalize.py | 28 +++- sqlglot/optimizer/pushdown_projections.py | 17 +- sqlglot/optimizer/qualify_columns.py | 22 ++- sqlglot/optimizer/qualify_tables.py | 8 +- sqlglot/optimizer/scope.py | 5 + sqlglot/parser.py | 157 ++++++++++++++----- sqlglot/tokens.py | 8 +- sqlglot/transforms.py | 46 +++++- 30 files changed, 700 insertions(+), 251 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 4a30008..10046d1 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.3.6" +__version__ = "11.4.1" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index f9613b2..f3433d3 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import argparse import sys +import typing as t import sqlglot @@ -42,6 +45,12 @@ parser.add_argument( action="store_true", help="Parse and return the expression tree", ) +parser.add_argument( + "--tokenize", + dest="tokenize", + action="store_true", + help="Tokenize and return the tokens list", +) parser.add_argument( "--error-level", dest="error_level", @@ -57,7 +66,7 @@ error_level = sqlglot.ErrorLevel[args.error_level.upper()] sql = sys.stdin.read() if args.sql == "-" else args.sql if args.parse: - sqls = [ + objs: t.Union[t.List[str], t.List[sqlglot.tokens.Token]] = [ repr(expression) for expression in sqlglot.parse( sql, @@ -65,8 +74,10 @@ if args.parse: error_level=error_level, ) ] +elif args.tokenize: + objs = sqlglot.Dialect.get_or_raise(args.read)().tokenize(sql) else: - sqls = sqlglot.transpile( + objs = sqlglot.transpile( sql, read=args.read, write=args.write, @@ -75,5 +86,5 @@ else: error_level=error_level, ) -for sql in sqls: - print(sql) +for obj in objs: + print(obj) diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 32ee927..93bdf75 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -299,7 +299,7 @@ class DataFrame: for expression_type, select_expression in select_expressions: select_expression = select_expression.transform(replace_id_value, replacement_mapping) if optimize: - select_expression = optimize_func(select_expression) + select_expression = optimize_func(select_expression, identify="always") select_expression = df._replace_cte_names_with_hashes(select_expression) expression: t.Union[exp.Select, exp.Cache, exp.Drop] if expression_type == exp.Cache: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 0c2105b..6a43846 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -144,7 +144,6 @@ class BigQuery(Dialect): "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, - "CURRENT_TIME": TokenType.CURRENT_TIME, "DECLARE": TokenType.COMMAND, "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, @@ -194,7 +193,6 @@ class BigQuery(Dialect): NO_PAREN_FUNCTIONS = { **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore TokenType.CURRENT_DATETIME: exp.CurrentDatetime, - TokenType.CURRENT_TIME: exp.CurrentTime, } NESTED_TYPE_TOKENS = { diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b553df2..b54a77d 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.errors import ParseError +from sqlglot.helper import ensure_list, seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import TokenType @@ -40,7 +41,18 @@ class ClickHouse(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "EXPONENTIALTIMEDECAYEDAVG": lambda params, args: exp.ExponentialTimeDecayedAvg( + this=seq_get(args, 0), + time=seq_get(args, 1), + decay=seq_get(params, 0), + ), "MAP": parse_var_map, + "HISTOGRAM": lambda params, args: exp.Histogram( + this=seq_get(args, 0), bins=seq_get(params, 0) + ), + "GROUPUNIQARRAY": lambda params, args: exp.GroupUniqArray( + this=seq_get(args, 0), size=seq_get(params, 0) + ), "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), @@ -113,22 +125,40 @@ class ClickHouse(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, - exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.ExponentialTimeDecayedAvg: lambda self, e: f"exponentialTimeDecayedAvg{self._param_args_sql(e, 'decay', ['this', 'time'])}", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.GroupUniqArray: lambda self, e: f"groupUniqArray{self._param_args_sql(e, 'size', 'this')}", + exp.Histogram: lambda self, e: f"histogram{self._param_args_sql(e, 'bins', 'this')}", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), - exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", + exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), } EXPLICIT_UNION = True def _param_args_sql( - self, expression: exp.Expression, params_name: str, args_name: str + self, + expression: exp.Expression, + param_names: str | t.List[str], + arg_names: str | t.List[str], ) -> str: - params = self.format_args(self.expressions(expression, params_name)) - args = self.format_args(self.expressions(expression, args_name)) + params = self.format_args( + *( + arg + for name in ensure_list(param_names) + for arg in ensure_list(expression.args.get(name)) + ) + ) + args = self.format_args( + *( + arg + for name in ensure_list(arg_names) + for arg in ensure_list(expression.args.get(name)) + ) + ) return f"({params})({args})" def cte_sql(self, expression: exp.CTE) -> str: diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 4ff3594..4268f1b 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -23,6 +23,7 @@ class Databricks(Spark): exp.DateDiff: generate_date_delta_with_unit_sql, exp.ToChar: lambda self, e: self.function_fallback_sql(e), } + TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation PARAMETER_TOKEN = "$" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 25490cb..b267521 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,7 +8,7 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer +from sqlglot.tokens import Token, Tokenizer from sqlglot.trie import new_trie E = t.TypeVar("E", bound=exp.Expression) @@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect): return expression def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse(self.tokenize(sql), sql) def parse_into( self, expression_type: exp.IntoType, sql: str, **opts ) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: return self.generator(**opts).generate(expression) @@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect): def transpile(self, sql: str, **opts) -> t.List[str]: return [self.generate(expression, **opts) for expression in self.parse(sql)] + def tokenize(self, sql: str) -> t.List[Token]: + return self.tokenizer.tokenize(sql) + @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): @@ -385,6 +388,21 @@ def parse_date_delta( return inner_func +def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: + unit = seq_get(args, 0) + this = seq_get(args, 1) + + if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE): + return exp.DateTrunc(unit=unit, this=this) + return exp.TimestampTrunc(this=this, unit=unit) + + +def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: + return self.func( + "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + ) + + def locate_to_strposition(args: t.Sequence) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), @@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str: return rename_func(name)(self, expression) +def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: + cond = expression.this + + if isinstance(expression.this, exp.Distinct): + cond = expression.this.expressions[0] + self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") + + return self.func("sum", exp.func("if", cond, 1, 0)) + + def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 208e2ab..dc0e519 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -97,6 +97,7 @@ class Drill(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "drill"), "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 43f538c..f1d2266 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, str_to_time_sql, + timestamptrunc_sql, timestrtotime_sql, ts_or_ds_to_date_sql, ) @@ -148,6 +149,9 @@ class DuckDB(Dialect): exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, exp.DateDiff: lambda self, e: self.func( @@ -162,6 +166,7 @@ class DuckDB(Dialect): exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpExtract: _regexp_extract_sql, @@ -175,6 +180,7 @@ class DuckDB(Dialect): exp.StrToTime: str_to_time_sql, exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", exp.Struct: _struct_sql, + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", @@ -186,6 +192,7 @@ class DuckDB(Dialect): exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("TO_TIMESTAMP"), exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", + exp.WeekOfYear: rename_func("WEEKOFYEAR"), } TYPE_MAPPING = { diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c4b8fa9..0110eee 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -35,7 +37,7 @@ DATE_DELTA_INTERVAL = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self, expression): +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) modified_increment = ( @@ -47,7 +49,7 @@ def _add_date_sql(self, expression): return self.func(func, expression.this, modified_increment.this) -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) @@ -56,21 +58,21 @@ def _date_diff_sql(self, expression): return f"{diff_sql}{multiplier_sql}" -def _array_sort(self, expression): +def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" -def _property_sql(self, expression): +def _property_sql(self: generator.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix(self, expression): +def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) -def _str_to_date(self, expression): +def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -78,7 +80,7 @@ def _str_to_date(self, expression): return f"CAST({this} AS DATE)" -def _str_to_time(self, expression): +def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -86,20 +88,22 @@ def _str_to_time(self, expression): return f"CAST({this} AS TIMESTAMP)" -def _time_format(self, expression): +def _time_format( + self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix +) -> t.Optional[str]: time_format = self.format_time(expression) if time_format == Hive.time_format: return None return time_format -def _time_to_str(self, expression): +def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) return f"DATE_FORMAT({this}, {time_format})" -def _to_date_sql(self, expression): +def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.time_format, Hive.date_format): @@ -107,7 +111,7 @@ def _to_date_sql(self, expression): return f"TO_DATE({this})" -def _unnest_to_explode_sql(self, expression): +def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: unnest = expression.this if isinstance(unnest, exp.Unnest): alias = unnest.args.get("alias") @@ -117,7 +121,7 @@ def _unnest_to_explode_sql(self, expression): exp.Lateral( this=udtf(this=expression), view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), + alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore ) ) for expression, column in zip(unnest.expressions, alias.columns if alias else []) @@ -125,7 +129,7 @@ def _unnest_to_explode_sql(self, expression): return self.join_sql(expression) -def _index_sql(self, expression): +def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") columns = self.sql(expression, "columns") @@ -263,14 +267,15 @@ class Hive(Dialect): exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, - exp.ArrayAgg: rename_func("COLLECT_LIST"), exp.ArrayConcat: rename_func("CONCAT"), exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, @@ -333,13 +338,19 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } - def with_properties(self, properties): + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: + return self.func( + "COLLECT_LIST", + expression.this.this if isinstance(expression.this, exp.Order) else expression.this, + ) + + def with_properties(self, properties: exp.Properties) -> str: return self.properties( properties, prefix=self.seg("TBLPROPERTIES"), ) - def datatype_sql(self, expression): + def datatype_sql(self, expression: exp.DataType) -> str: if ( expression.this in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) and not expression.expressions diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index a831235..1e2cfa3 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -177,7 +177,7 @@ class MySQL(Dialect): "@@": TokenType.SESSION_PARAMETER, } - COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore @@ -211,7 +211,6 @@ class MySQL(Dialect): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, # type: ignore TokenType.SHOW: lambda self: self._parse_show(), - TokenType.SET: lambda self: self._parse_set(), } SHOW_PARSERS = { @@ -269,15 +268,12 @@ class MySQL(Dialect): } SET_PARSERS = { - "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + **parser.Parser.SET_PARSERS, "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), - "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), - "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "NAMES": lambda self: self._parse_set_item_names(), - "TRANSACTION": lambda self: self._parse_set_transaction(), } PROFILE_TYPES = { @@ -292,15 +288,6 @@ class MySQL(Dialect): "SWAPS", } - TRANSACTION_CHARACTERISTICS = { - "ISOLATION LEVEL REPEATABLE READ", - "ISOLATION LEVEL READ COMMITTED", - "ISOLATION LEVEL READ UNCOMMITTED", - "ISOLATION LEVEL SERIALIZABLE", - "READ WRITE", - "READ ONLY", - } - def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -354,12 +341,6 @@ class MySQL(Dialect): **{"global": global_}, ) - def _parse_var_from_options(self, options): - for option in options: - if self._match_text_seq(*option.split(" ")): - return exp.Var(this=option) - return None - def _parse_oldstyle_limit(self): limit = None offset = None @@ -372,30 +353,6 @@ class MySQL(Dialect): offset = parts[0] return offset, limit - def _default_parse_set_item(self): - return self._parse_set_item_assignment(kind=None) - - def _parse_set_item_assignment(self, kind): - if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): - return self._parse_set_transaction(global_=kind == "GLOBAL") - - left = self._parse_primary() or self._parse_id_var() - if not self._match(TokenType.EQ): - self.raise_error("Expected =") - right = self._parse_statement() or self._parse_id_var() - - this = self.expression( - exp.EQ, - this=left, - expression=right, - ) - - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) - def _parse_set_item_charset(self, kind): this = self._parse_string() or self._parse_id_var() @@ -418,18 +375,6 @@ class MySQL(Dialect): kind="NAMES", ) - def _parse_set_transaction(self, global_=False): - self._match_text_seq("TRANSACTION") - characteristics = self._parse_csv( - lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) - ) - return self.expression( - exp.SetItem, - expressions=characteristics, - kind="TRANSACTION", - **{"global": global_}, - ) - class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -523,16 +468,3 @@ class MySQL(Dialect): limit_offset = f"{offset}, {limit}" if offset else limit return f" LIMIT {limit_offset}" return "" - - def setitem_sql(self, expression): - kind = self.sql(expression, "kind") - kind = f"{kind} " if kind else "" - this = self.sql(expression, "this") - expressions = self.expressions(expression) - collate = self.sql(expression, "collate") - collate = f" COLLATE {collate}" if collate else "" - global_ = "GLOBAL " if expression.args.get("global") else "" - return f"{global_}{kind}{this}{expressions}{collate}" - - def set_sql(self, expression): - return f"SET {self.expressions(expression)}" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index d7cbac4..5f556a5 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, rename_func, str_position_sql, + timestamptrunc_sql, trim_sql, ) from sqlglot.helper import seq_get @@ -34,7 +35,7 @@ def _date_add_sql(kind): from sqlglot.optimizer.simplify import simplify this = self.sql(expression, "this") - unit = self.sql(expression, "unit") + unit = expression.args.get("unit") expression = simplify(expression.args["expression"]) if not isinstance(expression, exp.Literal): @@ -92,8 +93,7 @@ def _string_agg_sql(self, expression): this = expression.this if isinstance(this, exp.Order): if this.this: - this = this.this - this.pop() + this = this.this.pop() order = self.sql(expression.this) # Order has a leading space return f"STRING_AGG({self.format_args(this, separator)}{order})" @@ -256,6 +256,9 @@ class Postgres(Dialect): "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), "GENERATE_SERIES": _generate_series, + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), unit=seq_get(args, 0) + ), } BITWISE = { @@ -311,6 +314,7 @@ class Postgres(Dialect): exp.DateSub: _date_add_sql("-"), exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -320,6 +324,7 @@ class Postgres(Dialect): exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index aef9de3..07e8f43 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -3,12 +3,14 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + date_trunc_to_time, format_time_lambda, if_sql, no_ilike_sql, no_safe_divide_sql, rename_func, struct_extract_sql, + timestamptrunc_sql, timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL @@ -98,10 +100,16 @@ def _ts_or_ds_to_date_sql(self, expression): def _ts_or_ds_add_sql(self, expression): - this = self.sql(expression, "this") - e = self.sql(expression, "expression") - unit = self.sql(expression, "unit") or "'day'" - return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" + return self.func( + "DATE_ADD", + exp.Literal.string(expression.text("unit") or "day"), + expression.expression, + self.func( + "DATE_PARSE", + self.func("SUBSTR", expression.this, exp.Literal.number(1), exp.Literal.number(10)), + Presto.date_format, + ), + ) def _sequence_sql(self, expression): @@ -195,6 +203,7 @@ class Presto(Dialect): ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "DATE_TRUNC": date_trunc_to_time, "FROM_UNIXTIME": _from_unixtime, "NOW": exp.CurrentTimestamp.from_arg_list, "STRPOS": lambda args: exp.StrPosition( @@ -237,6 +246,7 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -250,8 +260,12 @@ class Presto(Dialect): exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DataType: _datatype_sql, - exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", - exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", exp.Decode: _decode_sql, @@ -265,6 +279,7 @@ class Presto(Dialect): exp.Lateral: _explode_to_unnest_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.Quantile: _quantile_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.SafeDivide: no_safe_divide_sql, @@ -277,6 +292,7 @@ class Presto(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToDate: timestrtotime_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index dc881b9..ebd5216 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -20,6 +20,11 @@ class Redshift(Postgres): class Parser(Postgres.Parser): FUNCTIONS = { **Postgres.Parser.FUNCTIONS, # type: ignore + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ), "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -76,13 +81,16 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore **transforms.ELIMINATE_DISTINCT_ON, # type: ignore + exp.DateAdd: lambda self, e: self.func( + "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this + ), exp.DateDiff: lambda self, e: self.func( - "DATEDIFF", e.args.get("unit") or "day", e.expression, e.this + "DATEDIFF", exp.var(e.text("unit") or "day"), e.expression, e.this ), exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})", - exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.Matches: rename_func("DECODE"), + exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9b159a4..799e9a6 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -5,11 +5,13 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + date_trunc_to_time, datestrtodate_sql, format_time_lambda, inline_array_sql, min_or_least, rename_func, + timestamptrunc_sql, timestrtotime_sql, ts_or_ds_to_date_sql, var_map_sql, @@ -176,6 +178,7 @@ class Snowflake(Dialect): "ARRAYAGG": exp.ArrayAgg.from_arg_list, "ARRAY_CONSTRUCT": exp.Array.from_arg_list, "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list, + "DATE_TRUNC": date_trunc_to_time, "DATEADD": lambda args: exp.DateAdd( this=seq_get(args, 2), expression=seq_get(args, 1), @@ -186,10 +189,6 @@ class Snowflake(Dialect): expression=seq_get(args, 1), unit=seq_get(args, 0), ), - "DATE_TRUNC": lambda args: exp.DateTrunc( - unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore - this=seq_get(args, 1), - ), "DECODE": exp.Matches.from_arg_list, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, @@ -280,6 +279,8 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), + exp.LogicalOr: rename_func("BOOLOR_AGG"), + exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Matches: rename_func("DECODE"), @@ -287,6 +288,7 @@ class Snowflake(Dialect): "POSITION", e.args.get("substr"), e.this, e.args.get("position") ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 05ee53f..c271f6f 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -157,6 +157,7 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), exp.LogicalOr: rename_func("BOOL_OR"), + exp.LogicalAnd: rename_func("BOOL_AND"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfYear: rename_func("DAYOFYEAR"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index ed7c741..ab78b6e 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,10 +1,11 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + count_if_to_sum, no_ilike_sql, no_tablesample_sql, no_trycast_sql, @@ -13,23 +14,6 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -# https://www.sqlite.org/lang_aggfunc.html#group_concat -def _group_concat_sql(self, expression): - this = expression.this - distinct = expression.find(exp.Distinct) - if distinct: - this = distinct.expressions[0] - distinct = "DISTINCT " - - if isinstance(expression.this, exp.Order): - self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") - if expression.this.this and not distinct: - this = expression.this.this - - separator = expression.args.get("separator") - return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" - - def _date_add_sql(self, expression): modifier = expression.expression modifier = expression.name if modifier.is_string else self.sql(modifier) @@ -78,20 +62,32 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.CountIf: count_if_to_sum, + exp.CurrentDate: lambda *_: "CURRENT_DATE", + exp.CurrentTime: lambda *_: "CURRENT_TIME", + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql, + exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, exp.Levenshtein: rename_func("EDITDIST3"), + exp.LogicalOr: rename_func("MAX"), + exp.LogicalAnd: rename_func("MIN"), exp.TableSample: no_tablesample_sql, - exp.DateStrToDate: lambda self, e: self.sql(e, "this"), exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, - exp.GroupConcat: _group_concat_sql, } + def cast_sql(self, expression: exp.Cast) -> str: + if expression.to.this == exp.DataType.Type.DATE: + return self.func("DATE", expression.this) + + return super().cast_sql(expression) + def datediff_sql(self, expression: exp.DateDiff) -> str: unit = expression.args.get("unit") unit = unit.name.upper() if unit else "DAY" @@ -119,16 +115,32 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" - def fetch_sql(self, expression): + def fetch_sql(self, expression: exp.Fetch) -> str: return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - def least_sql(self, expression): + # https://www.sqlite.org/lang_aggfunc.html#group_concat + def groupconcat_sql(self, expression): + this = expression.this + distinct = expression.find(exp.Distinct) + if distinct: + this = distinct.expressions[0] + distinct = "DISTINCT " + + if isinstance(expression.this, exp.Order): + self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") + if expression.this.this and not distinct: + this = expression.this.this + + separator = expression.args.get("separator") + return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})" + + def least_sql(self, expression: exp.Least) -> str: if len(expression.expressions) > 1: return rename_func("MIN")(self, expression) return self.expressions(expression) - def transaction_sql(self, expression): + def transaction_sql(self, expression: exp.Transaction) -> str: this = expression.this this = f" {this}" if this else "" return f"BEGIN{this} TRANSACTION" diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 01e6357..2ba1a92 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -3,9 +3,18 @@ from __future__ import annotations from sqlglot import exp from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL +from sqlglot.helper import seq_get class StarRocks(MySQL): + class Parser(MySQL.Parser): # type: ignore + FUNCTIONS = { + **MySQL.Parser.FUNCTIONS, + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), unit=seq_get(args, 0) + ), + } + class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { **MySQL.Generator.TYPE_MAPPING, # type: ignore @@ -20,6 +29,9 @@ class StarRocks(MySQL): exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this + ), exp.TimeStrToDate: rename_func("TO_DATE"), exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.UnixToTime: rename_func("FROM_UNIXTIME"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 371e888..7b52047 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -117,14 +117,12 @@ def _string_agg_sql(self, e): if distinct: # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") - this = distinct.expressions[0] - distinct.pop() + this = distinct.pop().expressions[0] order = "" if isinstance(e.this, exp.Order): if e.this.this: - this = e.this.this - e.this.this.pop() + this = e.this.this.pop() order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space separator = e.args.get("separator") or exp.Literal.string(",") diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 0c345b3..b9da4cc 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -301,7 +301,7 @@ class Expression(metaclass=_Expression): the specified types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The node which matches the criteria or None if no such node was found. @@ -314,7 +314,7 @@ class Expression(metaclass=_Expression): yields those that match at least one of the specified expression types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The generator object. @@ -328,7 +328,7 @@ class Expression(metaclass=_Expression): Returns a nearest parent matching expression_types. Args: - expression_types (type): the expression type(s) to match. + expression_types: the expression type(s) to match. Returns: The parent node. @@ -336,8 +336,7 @@ class Expression(metaclass=_Expression): ancestor = self.parent while ancestor and not isinstance(ancestor, expression_types): ancestor = ancestor.parent - # ignore type because mypy doesn't know that we're checking type in the loop - return ancestor # type: ignore[return-value] + return t.cast(E, ancestor) @property def parent_select(self): @@ -549,8 +548,12 @@ class Expression(metaclass=_Expression): def pop(self): """ Remove this expression from its AST. + + Returns: + The popped expression. """ self.replace(None) + return self def assert_is(self, type_): """ @@ -626,6 +629,7 @@ IntoType = t.Union[ t.Type[Expression], t.Collection[t.Union[str, t.Type[Expression]]], ] +ExpOrStr = t.Union[str, Expression] class Condition(Expression): @@ -809,7 +813,7 @@ class Describe(Expression): class Set(Expression): - arg_types = {"expressions": True} + arg_types = {"expressions": False} class SetItem(Expression): @@ -905,6 +909,23 @@ class Column(Condition): def output_name(self) -> str: return self.name + @property + def parts(self) -> t.List[Identifier]: + """Return the parts of a column in order catalog, db, table, name.""" + return [part for part in reversed(list(self.args.values())) if part] + + def to_dot(self) -> Dot: + """Converts the column into a dot expression.""" + parts = self.parts + parent = self.parent + + while parent: + if isinstance(parent, Dot): + parts.append(parent.expression) + parent = parent.parent + + return Dot.build(parts) + class ColumnDef(Expression): arg_types = { @@ -1033,6 +1054,113 @@ class Constraint(Expression): class Delete(Expression): arg_types = {"with": False, "this": False, "using": False, "where": False, "returning": False} + def delete( + self, + table: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Create a DELETE expression or replace the table on an existing DELETE expression. + + Example: + >>> delete("tbl").sql() + 'DELETE FROM tbl' + + Args: + table: the table from which to delete. + dialect: the dialect used to parse the input expression. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=table, + instance=self, + arg="this", + dialect=dialect, + into=Table, + copy=copy, + **opts, + ) + + def where( + self, + *expressions: ExpOrStr, + append: bool = True, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Append to or set the WHERE expressions. + + Example: + >>> delete("tbl").where("x = 'a' OR x < 'b'").sql() + "DELETE FROM tbl WHERE x = 'a' OR x < 'b'" + + Args: + *expressions: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append: if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def returning( + self, + expression: ExpOrStr, + dialect: DialectType = None, + copy: bool = True, + **opts, + ) -> Delete: + """ + Set the RETURNING expression. Not supported by all dialects. + + Example: + >>> delete("tbl").returning("*", dialect="postgres").sql() + 'DELETE FROM tbl RETURNING *' + + Args: + expression: the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect: the dialect used to parse the input expressions. + copy: if `False`, modify this expression instance in-place. + opts: other options to use to parse the input expressions. + + Returns: + Delete: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="returning", + prefix="RETURNING", + dialect=dialect, + copy=copy, + into=Returning, + **opts, + ) + class Drop(Expression): arg_types = { @@ -1824,7 +1952,7 @@ class Union(Subqueryable): def select( self, - *expressions: str | Expression, + *expressions: ExpOrStr, append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2170,7 +2298,7 @@ class Select(Subqueryable): def select( self, - *expressions: str | Expression, + *expressions: ExpOrStr, append: bool = True, dialect: DialectType = None, copy: bool = True, @@ -2875,6 +3003,20 @@ class Dot(Binary): def name(self) -> str: return self.expression.name + @classmethod + def build(self, expressions: t.Sequence[Expression]) -> Dot: + """Build a Dot object with a sequence of expressions.""" + if len(expressions) < 2: + raise ValueError(f"Dot requires >= 2 expressions.") + + a, b, *expressions = expressions + dot = Dot(this=a, expression=b) + + for expression in expressions: + dot = Dot(this=dot, expression=expression) + + return dot + class DPipe(Binary): pass @@ -3049,7 +3191,7 @@ class TimeUnit(Expression): def __init__(self, **args): unit = args.get("unit") - if isinstance(unit, Column): + if isinstance(unit, (Column, Literal)): args["unit"] = Var(this=unit.name) elif isinstance(unit, Week): unit.set("this", Var(this=unit.this.name)) @@ -3261,6 +3403,10 @@ class Count(AggFunc): arg_types = {"this": False} +class CountIf(AggFunc): + pass + + class CurrentDate(Func): arg_types = {"this": False} @@ -3407,6 +3553,10 @@ class Explode(Func): pass +class ExponentialTimeDecayedAvg(AggFunc): + arg_types = {"this": True, "time": False, "decay": False} + + class Floor(Func): arg_types = {"this": True, "decimals": False} @@ -3420,10 +3570,18 @@ class GroupConcat(Func): arg_types = {"this": True, "separator": False} +class GroupUniqArray(AggFunc): + arg_types = {"this": True, "size": False} + + class Hex(Func): pass +class Histogram(AggFunc): + arg_types = {"this": True, "bins": False} + + class If(Func): arg_types = {"this": True, "true": True, "false": False} @@ -3493,7 +3651,11 @@ class Log10(Func): class LogicalOr(AggFunc): - _sql_names = ["LOGICAL_OR", "BOOL_OR"] + _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"] + + +class LogicalAnd(AggFunc): + _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"] class Lower(Func): @@ -3561,6 +3723,7 @@ class Quantile(AggFunc): # https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles class Quantiles(AggFunc): arg_types = {"parameters": True, "expressions": True} + is_var_len_args = True class QuantileIf(AggFunc): @@ -3830,7 +3993,7 @@ ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) # Helpers def maybe_parse( - sql_or_expression: str | Expression, + sql_or_expression: ExpOrStr, *, into: t.Optional[IntoType] = None, dialect: DialectType = None, @@ -4091,7 +4254,7 @@ def except_(left, right, distinct=True, dialect=None, **opts): return Except(this=left, expression=right, distinct=distinct) -def select(*expressions: str | Expression, dialect: DialectType = None, **opts) -> Select: +def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select: """ Initializes a syntax tree from one or multiple SELECT expressions. @@ -4135,7 +4298,14 @@ def from_(*expressions, dialect=None, **opts) -> Select: return Select().from_(*expressions, dialect=dialect, **opts) -def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update: +def update( + table: str | Table, + properties: dict, + where: t.Optional[ExpOrStr] = None, + from_: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + **opts, +) -> Update: """ Creates an update statement. @@ -4144,18 +4314,18 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U "UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz WHERE id > 1" Args: - *properties (Dict[str, Any]): dictionary of properties to set which are + *properties: dictionary of properties to set which are auto converted to sql objects eg None -> NULL - where (str): sql conditional parsed into a WHERE statement - from_ (str): sql statement parsed into a FROM statement - dialect (str): the dialect used to parse the input expressions. + where: sql conditional parsed into a WHERE statement + from_: sql statement parsed into a FROM statement + dialect: the dialect used to parse the input expressions. **opts: other options to use to parse the input expressions. Returns: Update: the syntax tree for the UPDATE statement. """ - update = Update(this=maybe_parse(table, into=Table, dialect=dialect)) - update.set( + update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect)) + update_expr.set( "expressions", [ EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) @@ -4163,21 +4333,27 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U ], ) if from_: - update.set( + update_expr.set( "from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), ) if isinstance(where, Condition): where = Where(this=where) if where: - update.set( + update_expr.set( "where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), ) - return update + return update_expr -def delete(table, where=None, dialect=None, **opts) -> Delete: +def delete( + table: ExpOrStr, + where: t.Optional[ExpOrStr] = None, + returning: t.Optional[ExpOrStr] = None, + dialect: DialectType = None, + **opts, +) -> Delete: """ Builds a delete statement. @@ -4186,19 +4362,20 @@ def delete(table, where=None, dialect=None, **opts) -> Delete: 'DELETE FROM my_table WHERE id > 1' Args: - where (str|Condition): sql conditional parsed into a WHERE statement - dialect (str): the dialect used to parse the input expressions. + where: sql conditional parsed into a WHERE statement + returning: sql conditional parsed into a RETURNING statement + dialect: the dialect used to parse the input expressions. **opts: other options to use to parse the input expressions. Returns: Delete: the syntax tree for the DELETE statement. """ - return Delete( - this=maybe_parse(table, into=Table, dialect=dialect, **opts), - where=Where(this=where) - if isinstance(where, Condition) - else maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), - ) + delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts) + if where: + delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) + if returning: + delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts) + return delete_expr def condition(expression, dialect=None, **opts) -> Condition: @@ -4414,7 +4591,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: def alias_( - expression: str | Expression, + expression: ExpOrStr, alias: str | Identifier, table: bool | t.Sequence[str | Identifier] = False, quoted: t.Optional[bool] = None, @@ -4516,7 +4693,7 @@ def column( ) -def cast(expression: str | Expression, to: str | DataType | DataType.Type, **opts) -> Cast: +def cast(expression: ExpOrStr, to: str | DataType | DataType.Type, **opts) -> Cast: """Cast an expression to a data type. Example: @@ -4595,7 +4772,7 @@ def values( ) -def var(name: t.Optional[str | Expression]) -> Var: +def var(name: t.Optional[ExpOrStr]) -> Var: """Build a SQL variable. Example: @@ -4612,7 +4789,7 @@ def var(name: t.Optional[str | Expression]) -> Var: The new variable node. """ if not name: - raise ValueError(f"Cannot convert empty name into var.") + raise ValueError("Cannot convert empty name into var.") if isinstance(name, Expression): name = name.name @@ -4682,7 +4859,7 @@ def convert(value) -> Expression: raise ValueError(f"Cannot convert {value}") -def replace_children(expression, fun): +def replace_children(expression, fun, *args, **kwargs): """ Replace children of an expression with the result of a lambda fun(child) -> exp. """ @@ -4694,7 +4871,7 @@ def replace_children(expression, fun): for cn in child_nodes: if isinstance(cn, Expression): - for child_node in ensure_collection(fun(cn)): + for child_node in ensure_collection(fun(cn, *args, **kwargs)): new_child_nodes.append(child_node) child_node.parent = expression child_node.arg_key = k diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 5936649..a6f4772 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages -from sqlglot.helper import apply_index_offset, csv, seq_get +from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -25,8 +25,7 @@ class Generator: quote_end (str): specifies which ending character to use to delimit quotes. Default: '. identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". - identify (bool): if set to True all identifiers will be delimited by the corresponding - character. + identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. identifier_escape (str): specifies an identifier escape character. Default: ". @@ -57,10 +56,10 @@ class Generator: TRANSFORMS = { exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", e.this, e.expression, e.args.get("unit") + "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), exp.TsOrDsAdd: lambda self, e: self.func( - "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") + "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", @@ -736,7 +735,7 @@ class Generator: text = expression.name text = text.lower() if self.normalize else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.args.get("quoted") or self.identify: + if expression.args.get("quoted") or should_identify(text, self.identify): text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1176,6 +1175,22 @@ class Generator: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + def setitem_sql(self, expression: exp.SetItem) -> str: + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + this = self.sql(expression, "this") + expressions = self.expressions(expression) + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + global_ = "GLOBAL " if expression.args.get("global") else "" + return f"{global_}{kind}{this}{expressions}{collate}" + + def set_sql(self, expression: exp.Set) -> str: + expressions = ( + f" {self.expressions(expression, flat=True)}" if expression.expressions else "" + ) + return f"SET{expressions}" + def lock_sql(self, expression: exp.Lock) -> str: if self.LOCKING_READS_SUPPORTED: lock_type = "UPDATE" if expression.args["update"] else "SHARE" @@ -1359,8 +1374,8 @@ class Generator: sql = self.query_modifiers( expression, self.wrap(expression), - self.expressions(expression, key="pivots", sep=" "), alias, + self.expressions(expression, key="pivots", sep=" "), ) return self.prepend_ctes(expression, sql) @@ -1668,7 +1683,7 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" - def transaction_sql(self, *_) -> str: + def transaction_sql(self, expression: exp.Transaction) -> str: return "BEGIN" def commit_sql(self, expression: exp.Commit) -> str: diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 68e0383..6eff974 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -403,3 +403,20 @@ def first(it: t.Iterable[T]) -> T: Useful for sets. """ return next(i for i in it) + + +def should_identify(text: str, identify: str | bool) -> bool: + """Checks if text should be identified given an identify option. + + Args: + text: the text to check. + identify: "always" | True - always returns true, "safe" - true if no upper case + + Returns: + Whether or not a string should be identified. + """ + if identify is True or identify == "always": + return True + if identify == "safe": + return not any(char.isupper() for char in text) + return False diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index fc37a54..c5c780d 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -1,9 +1,12 @@ +from __future__ import annotations + import itertools from sqlglot import exp +from sqlglot.helper import should_identify -def canonicalize(expression: exp.Expression) -> exp.Expression: +def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -11,15 +14,18 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: Args: expression: The expression to canonicalize. + identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize) + exp.replace_children(expression, canonicalize, identify=identify) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) + expression = ensure_bool_predicates(expression) if isinstance(expression, exp.Identifier): - expression.set("quoted", True) + if should_identify(expression.this, identify): + expression.set("quoted", True) return expression @@ -52,6 +58,17 @@ def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: return expression +def ensure_bool_predicates(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Connector): + _replace_int_predicate(expression.left) + _replace_int_predicate(expression.right) + + elif isinstance(expression, (exp.Where, exp.Having)): + _replace_int_predicate(expression.this) + + return expression + + def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: for a, b in itertools.permutations([a, b]): if ( @@ -68,3 +85,8 @@ def _replace_cast(node: exp.Expression, to: str) -> None: cast = exp.Cast(this=node.copy(), to=data_type) cast.type = data_type node.replace(cast) + + +def _replace_int_predicate(expression: exp.Expression) -> None: + if expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: + expression.replace(exp.NEQ(this=expression.copy(), expression=exp.Literal.number(0))) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 07a1b70..2e51117 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -1,7 +1,6 @@ from collections import defaultdict from sqlglot import alias, exp -from sqlglot.helper import flatten from sqlglot.optimizer.qualify_columns import Resolver from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -86,14 +85,15 @@ def _remove_unused_selections(scope, parent_selections, schema): else: order_refs = set() - new_selections = defaultdict(list) + new_selections = [] removed = False star = False + for selection in scope.selects: name = selection.alias_or_name if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: - new_selections[name].append(selection) + new_selections.append(selection) else: if selection.is_star: star = True @@ -101,18 +101,17 @@ def _remove_unused_selections(scope, parent_selections, schema): if star: resolver = Resolver(scope, schema) + names = {s.alias_or_name for s in new_selections} for name in sorted(parent_selections): - if name not in new_selections: - new_selections[name].append( - alias(exp.column(name, table=resolver.get_table(name)), name) - ) + if name not in names: + new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name)) # If there are no remaining selections, just select a single constant if not new_selections: - new_selections[""].append(DEFAULT_SELECTION()) + new_selections.append(DEFAULT_SELECTION()) - scope.expression.select(*flatten(new_selections.values()), append=False, copy=False) + scope.expression.select(*new_selections, append=False, copy=False) if removed: scope.clear_cache() diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e793e31..66b3170 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -37,6 +37,7 @@ def qualify_columns(expression, schema): _qualify_outputs(scope) _expand_group_by(scope, resolver) _expand_order_by(scope) + return expression @@ -213,6 +214,21 @@ def _qualify_columns(scope, resolver): # column_table can be a '' because bigquery unnest has no table alias if column_table: column.set("table", column_table) + elif column_table not in scope.sources: + # structs are used like tables (e.g. "struct"."field"), so they need to be qualified + # separately and represented as dot(dot(...(., field1), field2, ...)) + + root, *parts = column.parts + + if root.name in scope.sources: + # struct is already qualified, but we still need to change the AST representation + column_table = root + root, *parts = parts + else: + column_table = resolver.get_table(root.name) + + if column_table: + column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) columns_missing_from_scope = [] # Determine whether each reference in the order by clause is to a column or an alias. @@ -373,10 +389,14 @@ class Resolver: if isinstance(node, exp.Subqueryable): while node and node.alias != table_name: node = node.parent + node_alias = node.args.get("alias") if node_alias: return node_alias.this - return exp.to_identifier(table_name) + + return exp.to_identifier( + table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None + ) @property def all_columns(self): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 6e50182..93e1179 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -34,11 +34,9 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) - for source in scope.sources.values(): + for name, source in scope.sources.items(): if isinstance(source, exp.Table): - identifier = isinstance(source.this, exp.Identifier) - - if identifier: + if isinstance(source.this, exp.Identifier): if not source.args.get("db"): source.set("db", exp.to_identifier(db)) if not source.args.get("catalog"): @@ -48,7 +46,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): source = source.replace( alias( source.copy(), - source.this if identifier else next_name(), + name if name else next_name(), table=True, ) ) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 335ff3e..9c0768c 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -4,6 +4,7 @@ from enum import Enum, auto from sqlglot import exp from sqlglot.errors import OptimizeError +from sqlglot.helper import find_new_name class ScopeType(Enum): @@ -293,6 +294,8 @@ class Scope: result = {} for name, node in referenced_names: + if name in result: + raise OptimizeError(f"Alias already used: {name}") if name in self.sources: result[name] = (node, self.sources[name]) @@ -594,6 +597,8 @@ def _traverse_tables(scope): if table_name in scope.sources: # This is a reference to a parent source (e.g. a CTE), not an actual table. sources[source_name] = scope.sources[table_name] + elif source_name in sources: + sources[find_new_name(sources, table_name)] = expression else: sources[source_name] = expression continue diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 90fdade..a36251e 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -96,6 +96,7 @@ class Parser(metaclass=_Parser): NO_PAREN_FUNCTIONS = { TokenType.CURRENT_DATE: exp.CurrentDate, TokenType.CURRENT_DATETIME: exp.CurrentDate, + TokenType.CURRENT_TIME: exp.CurrentTime, TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, } @@ -198,7 +199,6 @@ class Parser(metaclass=_Parser): TokenType.COMMIT, TokenType.COMPOUND, TokenType.CONSTRAINT, - TokenType.CURRENT_TIME, TokenType.DEFAULT, TokenType.DELETE, TokenType.DESCRIBE, @@ -370,8 +370,9 @@ class Parser(metaclass=_Parser): LAMBDAS = { TokenType.ARROW: lambda self, expressions: self.expression( exp.Lambda, - this=self._parse_conjunction().transform( - self._replace_lambda, {node.name for node in expressions} + this=self._replace_lambda( + self._parse_conjunction(), + {node.name for node in expressions}, ), expressions=expressions, ), @@ -441,6 +442,7 @@ class Parser(metaclass=_Parser): exp.With: lambda self: self._parse_with(), exp.Window: lambda self: self._parse_named_window(), exp.Qualify: lambda self: self._parse_qualify(), + exp.Returning: lambda self: self._parse_returning(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } @@ -460,6 +462,7 @@ class Parser(metaclass=_Parser): TokenType.LOAD_DATA: lambda self: self._parse_load_data(), TokenType.MERGE: lambda self: self._parse_merge(), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), + TokenType.SET: lambda self: self._parse_set(), TokenType.UNCACHE: lambda self: self._parse_uncache(), TokenType.UPDATE: lambda self: self._parse_update(), TokenType.USE: lambda self: self.expression( @@ -656,15 +659,15 @@ class Parser(metaclass=_Parser): } FUNCTION_PARSERS: t.Dict[str, t.Callable] = { + "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), - "TRY_CONVERT": lambda self: self._parse_convert(False), "EXTRACT": lambda self: self._parse_extract(), "POSITION": lambda self: self._parse_position(), + "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), "TRIM": lambda self: self._parse_trim(), - "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "TRY_CAST": lambda self: self._parse_cast(False), - "STRING_AGG": lambda self: self._parse_string_agg(), + "TRY_CONVERT": lambda self: self._parse_convert(False), } QUERY_MODIFIER_PARSERS = { @@ -684,13 +687,28 @@ class Parser(metaclass=_Parser): "sample": lambda self: self._parse_table_sample(as_modifier=True), } + SET_PARSERS = { + "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), + "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), + "TRANSACTION": lambda self: self._parse_set_transaction(), + } + SHOW_PARSERS: t.Dict[str, t.Callable] = {} - SET_PARSERS: t.Dict[str, t.Callable] = {} MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + TRANSACTION_CHARACTERISTICS = { + "ISOLATION LEVEL REPEATABLE READ", + "ISOLATION LEVEL READ COMMITTED", + "ISOLATION LEVEL READ UNCOMMITTED", + "ISOLATION LEVEL SERIALIZABLE", + "READ WRITE", + "READ ONLY", + } + INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} @@ -1775,11 +1793,12 @@ class Parser(metaclass=_Parser): self, alias_tokens: t.Optional[t.Collection[TokenType]] = None ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) - alias = self._parse_id_var( - any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS + alias = ( + self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) + or self._parse_string_as_identifier() ) - index = self._index + index = self._index if self._match(TokenType.L_PAREN): columns = self._parse_csv(self._parse_function_parameter) self._match_r_paren() if columns else self._retreat(index) @@ -2046,7 +2065,12 @@ class Parser(metaclass=_Parser): def _parse_table_parts(self, schema: bool = False) -> exp.Expression: catalog = None db = None - table = (not schema and self._parse_function()) or self._parse_id_var(any_token=False) + + table = ( + (not schema and self._parse_function()) + or self._parse_id_var(any_token=False) + or self._parse_string_as_identifier() + ) while self._match(TokenType.DOT): if catalog: @@ -2085,6 +2109,8 @@ class Parser(metaclass=_Parser): subquery = self._parse_select(table=True) if subquery: + if not subquery.args.get("pivots"): + subquery.set("pivots", self._parse_pivots()) return subquery this = self._parse_table_parts(schema=schema) @@ -3370,9 +3396,9 @@ class Parser(metaclass=_Parser): def _parse_window( self, this: t.Optional[exp.Expression], alias: bool = False ) -> t.Optional[exp.Expression]: - if self._match(TokenType.FILTER): - where = self._parse_wrapped(self._parse_where) - this = self.expression(exp.Filter, this=this, expression=where) + if self._match_pair(TokenType.FILTER, TokenType.L_PAREN): + this = self.expression(exp.Filter, this=this, expression=self._parse_where()) + self._match_r_paren() # T-SQL allows the OVER (...) syntax after WITHIN GROUP. # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 @@ -3504,6 +3530,9 @@ class Parser(metaclass=_Parser): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() + def _parse_string_as_identifier(self) -> t.Optional[exp.Expression]: + return exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True) + def _parse_number(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NUMBER): return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) @@ -3778,23 +3807,6 @@ class Parser(metaclass=_Parser): ) return self._parse_as_command(start) - def _parse_show(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore - if parser: - return parser(self) - self._advance() - return self.expression(exp.Show, this=self._prev.text.upper()) - - def _default_parse_set_item(self) -> exp.Expression: - return self.expression( - exp.SetItem, - this=self._parse_statement(), - ) - - def _parse_set_item(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore - return parser(self) if parser else self._default_parse_set_item() - def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) target = self._parse_table() @@ -3861,8 +3873,71 @@ class Parser(metaclass=_Parser): expressions=whens, ) + def _parse_show(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore + if parser: + return parser(self) + self._advance() + return self.expression(exp.Show, this=self._prev.text.upper()) + + def _parse_set_item_assignment( + self, kind: t.Optional[str] = None + ) -> t.Optional[exp.Expression]: + index = self._index + + if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + return self._parse_set_transaction(global_=kind == "GLOBAL") + + left = self._parse_primary() or self._parse_id_var() + + if not self._match_texts(("=", "TO")): + self._retreat(index) + return None + + right = self._parse_statement() or self._parse_id_var() + this = self.expression( + exp.EQ, + this=left, + expression=right, + ) + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: + self._match_text_seq("TRANSACTION") + characteristics = self._parse_csv( + lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) + ) + return self.expression( + exp.SetItem, + expressions=characteristics, + kind="TRANSACTION", + **{"global": global_}, # type: ignore + ) + + def _parse_set_item(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore + return parser(self) if parser else self._parse_set_item_assignment(kind=None) + def _parse_set(self) -> exp.Expression: - return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + index = self._index + set_ = self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) + + if self._curr: + self._retreat(index) + return self._parse_as_command(self._prev) + + return set_ + + def _parse_var_from_options(self, options: t.Collection[str]) -> t.Optional[exp.Expression]: + for option in options: + if self._match_text_seq(*option.split(" ")): + return exp.Var(this=option) + return None def _parse_as_command(self, start: Token) -> exp.Command: while self._curr: @@ -3874,6 +3949,9 @@ class Parser(metaclass=_Parser): def _find_parser( self, parsers: t.Dict[str, t.Callable], trie: t.Dict ) -> t.Optional[t.Callable]: + if not self._curr: + return None + index = self._index this = [] while True: @@ -3973,7 +4051,16 @@ class Parser(metaclass=_Parser): return this def _replace_lambda(self, node, lambda_variables): - if isinstance(node, exp.Column): - if node.name in lambda_variables: - return node.this + for column in node.find_all(exp.Column): + if column.parts[0].name in lambda_variables: + dot_or_id = column.to_dot() if column.table else column.this + parent = column.parent + + while isinstance(parent, exp.Dot): + if not isinstance(parent.parent, exp.Dot): + parent.replace(dot_or_id) + break + parent = parent.parent + else: + column.replace(dot_or_id) return node diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 053bbdd..eb3c08f 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -502,6 +502,7 @@ class Tokenizer(metaclass=_Tokenizer): "CUBE": TokenType.CUBE, "CURRENT_DATE": TokenType.CURRENT_DATE, "CURRENT ROW": TokenType.CURRENT_ROW, + "CURRENT_TIME": TokenType.CURRENT_TIME, "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, "DATABASE": TokenType.DATABASE, "DEFAULT": TokenType.DEFAULT, @@ -725,7 +726,6 @@ class Tokenizer(metaclass=_Tokenizer): TokenType.COMMAND, TokenType.EXECUTE, TokenType.FETCH, - TokenType.SET, TokenType.SHOW, } @@ -851,8 +851,10 @@ class Tokenizer(metaclass=_Tokenizer): # If we have either a semicolon or a begin token before the command's token, we'll parse # whatever follows the command's token as a string - if token_type in self.COMMANDS and ( - len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS + if ( + token_type in self.COMMANDS + and self._peek != ";" + and (len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS) ): start = self._current tokens = len(self.tokens) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index aa7d240..2eafb0b 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -2,13 +2,12 @@ from __future__ import annotations import typing as t +from sqlglot import expressions as exp from sqlglot.helper import find_new_name if t.TYPE_CHECKING: from sqlglot.generator import Generator -from sqlglot import expressions as exp - def unalias_group(expression: exp.Expression) -> exp.Expression: """ @@ -61,8 +60,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: and expression.args["distinct"].args.get("on") and isinstance(expression.args["distinct"].args["on"], exp.Tuple) ): - distinct_cols = expression.args["distinct"].args["on"].expressions - expression.args["distinct"].pop() + distinct_cols = expression.args["distinct"].pop().args["on"].expressions outer_selects = expression.selects row_number = find_new_name(expression.named_selects, "_row_number") window = exp.Window( @@ -71,14 +69,49 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: ) order = expression.args.get("order") if order: - window.set("order", order.copy()) - order.pop() + window.set("order", order.pop().copy()) window = exp.alias_(window, row_number) expression.select(window, copy=False) return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') return expression +def eliminate_qualify(expression: exp.Expression) -> exp.Expression: + """ + Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. + + The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: + https://docs.snowflake.com/en/sql-reference/constructs/qualify + + Some dialects don't support window functions in the WHERE clause, so we need to include them as + projections in the subquery, in order to refer to them in the outer filter using aliases. Also, + if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, + otherwise we won't be able to refer to it in the outer query's WHERE clause. + """ + if isinstance(expression, exp.Select) and expression.args.get("qualify"): + taken = set(expression.named_selects) + for select in expression.selects: + if not select.alias_or_name: + alias = find_new_name(taken, "_c") + select.replace(exp.alias_(select.copy(), alias)) + taken.add(alias) + + outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) + qualify_filters = expression.args["qualify"].pop().this + + for expr in qualify_filters.find_all((exp.Window, exp.Column)): + if isinstance(expr, exp.Window): + alias = find_new_name(expression.named_selects, "_w") + expression.select(exp.alias_(expr.copy(), alias), copy=False) + expr.replace(exp.column(alias)) + elif expr.name not in expression.named_selects: + expression.select(expr.copy(), copy=False) + + return outer_selects.from_(expression.subquery(alias="_t")).where(qualify_filters) + + return expression + + def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: """ Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. @@ -139,6 +172,7 @@ def delegate(attr: str) -> t.Callable: UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} +ELIMINATE_QUALIFY = {exp.Select: preprocess([eliminate_qualify], delegate("select_sql"))} REMOVE_PRECISION_PARAMETERIZED_TYPES = { exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql")) } -- cgit v1.2.3