From b38d717d5933fdae3fe85c87df7aee9a251fb58e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 3 Apr 2023 09:31:54 +0200 Subject: Merging upstream version 11.4.5. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 2 +- sqlglot/dialects/bigquery.py | 20 ++-- sqlglot/dialects/clickhouse.py | 2 + sqlglot/dialects/databricks.py | 2 + sqlglot/dialects/dialect.py | 5 + sqlglot/dialects/drill.py | 5 +- sqlglot/dialects/duckdb.py | 3 + sqlglot/dialects/hive.py | 26 ++++- sqlglot/dialects/mysql.py | 16 ++- sqlglot/dialects/oracle.py | 51 ++++----- sqlglot/dialects/postgres.py | 2 + sqlglot/dialects/snowflake.py | 11 +- sqlglot/dialects/sqlite.py | 5 +- sqlglot/dialects/teradata.py | 3 +- sqlglot/dialects/tsql.py | 11 +- sqlglot/diff.py | 24 ++--- sqlglot/executor/__init__.py | 2 +- sqlglot/executor/python.py | 42 +++----- sqlglot/expressions.py | 163 ++++++++++++++++------------- sqlglot/generator.py | 80 ++++++++++++-- sqlglot/helper.py | 11 +- sqlglot/optimizer/annotate_types.py | 12 +-- sqlglot/optimizer/canonicalize.py | 2 +- sqlglot/optimizer/eliminate_joins.py | 5 +- sqlglot/optimizer/eliminate_subqueries.py | 2 - sqlglot/optimizer/lower_identities.py | 8 +- sqlglot/optimizer/merge_subqueries.py | 5 +- sqlglot/optimizer/normalize.py | 104 ++++++++++++------- sqlglot/optimizer/optimize_joins.py | 2 - sqlglot/optimizer/optimizer.py | 4 +- sqlglot/optimizer/qualify_columns.py | 88 +++++++++++----- sqlglot/optimizer/scope.py | 2 +- sqlglot/optimizer/simplify.py | 166 ++++++++++++++++-------------- sqlglot/parser.py | 109 +++++++++++++++++--- sqlglot/planner.py | 2 +- sqlglot/schema.py | 4 +- sqlglot/tokens.py | 15 +-- 37 files changed, 650 insertions(+), 366 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 10046d1..b53b261 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -47,7 +47,7 @@ if t.TYPE_CHECKING: T = t.TypeVar("T", bound=Expression) -__version__ = "11.4.1" +__version__ = "11.4.5" pretty = False """Whether to format generated SQL by default.""" diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6a43846..a3f9e6d 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, inline_array_sql, + max_or_greatest, min_or_least, no_ilike_sql, rename_func, @@ -212,6 +213,9 @@ class BigQuery(Dialect): ), } + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore @@ -227,6 +231,7 @@ class BigQuery(Dialect): exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.Select: transforms.preprocess( [_unqualify_unnest], transforms.delegate("select_sql") @@ -253,17 +258,19 @@ class BigQuery(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.TINYINT: "INT64", - exp.DataType.Type.SMALLINT: "INT64", - exp.DataType.Type.INT: "INT64", exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.CHAR: "STRING", exp.DataType.Type.DECIMAL: "NUMERIC", - exp.DataType.Type.FLOAT: "FLOAT64", exp.DataType.Type.DOUBLE: "FLOAT64", - exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.NCHAR: "STRING", + exp.DataType.Type.NVARCHAR: "STRING", + exp.DataType.Type.SMALLINT: 
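# A hedged sketch of the max_or_greatest dispatch wired up above, assuming MAX,
# like MIN, may carry extra arguments in its `expressions` list: a one-argument
# MAX keeps its name, while a multi-argument MAX renders as GREATEST.
import sqlglot

sqlglot.transpile("SELECT MAX(a) FROM t", write="bigquery")
# -> ['SELECT MAX(a) FROM t']
sqlglot.transpile("SELECT MAX(a, b) FROM t", write="bigquery")
# -> expected as ['SELECT GREATEST(a, b) FROM t']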
"INT64", exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARCHAR: "STRING", - exp.DataType.Type.NVARCHAR: "STRING", } PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore @@ -271,6 +278,7 @@ class BigQuery(Dialect): } EXPLICIT_UNION = True + LIMIT_FETCH = "LIMIT" def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index b54a77d..89e2296 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -68,6 +68,8 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore + LOG_DEFAULTS_TO_LN = True + def _parse_in( self, this: t.Optional[exp.Expression], is_global: bool = False ) -> exp.Expression: diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 4268f1b..2f93ee7 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -16,6 +16,8 @@ class Databricks(Spark): "DATEDIFF": parse_date_delta(exp.DateDiff), } + LOG_DEFAULTS_TO_LN = True + class Generator(Spark.Generator): TRANSFORMS = { **Spark.Generator.TRANSFORMS, # type: ignore diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b267521..839589d 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -430,6 +430,11 @@ def min_or_least(self: Generator, expression: exp.Min) -> str: return rename_func(name)(self, expression) +def max_or_greatest(self: Generator, expression: exp.Max) -> str: + name = "GREATEST" if expression.expressions else "MAX" + return rename_func(name)(self, expression) + + def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index dc0e519..a33aadc 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import typing as t from sqlglot import exp, generator, parser, tokens @@ -102,6 +101,8 @@ class Drill(Dialect): "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -154,4 +155,4 @@ class Drill(Dialect): } def normalize_func(self, name: str) -> str: - return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" + return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`" diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f1d2266..c034208 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -80,6 +80,7 @@ class DuckDB(Dialect): class Tokenizer(tokens.Tokenizer): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "~": TokenType.RLIKE, ":=": TokenType.EQ, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, @@ -212,5 +213,7 @@ class DuckDB(Dialect): "except": "EXCLUDE", } + LIMIT_FETCH = "LIMIT" + def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: return super().tablesample_sql(expression, seed_prefix="REPEATABLE") diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0110eee..68137ae 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, if_sql, locate_to_strposition, + max_or_greatest, min_or_least, no_ilike_sql, 
no_recursive_cte_sql, @@ -34,6 +35,13 @@ DATE_DELTA_INTERVAL = { "DAY": ("DATE_ADD", 1), } +TIME_DIFF_FACTOR = { + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", +} + DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") @@ -51,6 +59,14 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() + + factor = TIME_DIFF_FACTOR.get(unit) + if factor is not None: + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" + return f"({sec_diff}){factor}" if factor else sec_diff + sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" @@ -237,11 +253,6 @@ class Hive(Dialect): "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "LOCATE": locate_to_strposition, - "LOG": ( - lambda args: exp.Log.from_arg_list(args) - if len(args) > 1 - else exp.Ln.from_arg_list(args) - ), "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, @@ -261,6 +272,8 @@ class Hive(Dialect): ), } + LOG_DEFAULTS_TO_LN = True + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -293,6 +306,7 @@ class Hive(Dialect): exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.Map: var_map_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, @@ -338,6 +352,8 @@ class Hive(Dialect): exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, } + LIMIT_FETCH = "LIMIT" + def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1e2cfa3..5dfa811 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -3,7 +3,9 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + arrow_json_extract_scalar_sql, locate_to_strposition, + max_or_greatest, min_or_least, no_ilike_sql, no_paren_current_date_sql, @@ -288,6 +290,8 @@ class MySQL(Dialect): "SWAPS", } + LOG_DEFAULTS_TO_LN = True + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): @@ -303,7 +307,13 @@ class MySQL(Dialect): db = None else: position = None - db = self._parse_id_var() if self._match_text_seq("FROM") else None + db = None + + if self._match(TokenType.FROM): + db = self._parse_id_var() + elif self._match(TokenType.DOT): + db = target_id + target_id = self._parse_id_var() channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None @@ -384,6 +394,8 @@ class MySQL(Dialect): exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.TableSample: no_tablesample_sql, exp.TryCast: no_trycast_sql, @@ -415,6 +427,8 @@ class MySQL(Dialect): exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, } + LIMIT_FETCH = 
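# Sketch of the new sub-day DATE_DIFF handling for Hive: units listed in
# TIME_DIFF_FACTOR are computed from UNIX_TIMESTAMP deltas instead of DATEDIFF.
# Built directly on the AST (illustrative column names) so no read dialect is assumed.
from sqlglot import exp

exp.DateDiff(
    this=exp.column("end_ts"),
    expression=exp.column("start_ts"),
    unit=exp.Literal.string("MINUTE"),
).sql(dialect="hive")
# -> '(UNIX_TIMESTAMP(end_ts) - UNIX_TIMESTAMP(start_ts)) / 60'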
"LIMIT" + def show_sql(self, expression): this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 7028a04..fad6c4a 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -4,7 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql -from sqlglot.helper import csv, seq_get +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { @@ -13,10 +13,6 @@ PASSING_TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { } -def _limit_sql(self, expression): - return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) - - def _parse_xml_table(self) -> exp.XMLTable: this = self._parse_string() @@ -89,6 +85,20 @@ class Oracle(Dialect): column.set("join_mark", self._match(TokenType.JOIN_MARKER)) return column + def _parse_hint(self) -> t.Optional[exp.Expression]: + if self._match(TokenType.HINT): + start = self._curr + while self._curr and not self._match_pair(TokenType.STAR, TokenType.SLASH): + self._advance() + + if not self._curr: + self.raise_error("Expected */ after HINT") + + end = self._tokens[self._index - 3] + return exp.Hint(expressions=[self._find_sql(start, end)]) + + return None + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True @@ -110,41 +120,20 @@ class Oracle(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, - exp.Limit: _limit_sql, - exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), + exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", - exp.Substring: rename_func("SUBSTR"), exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Trim: trim_sql, + exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", } - def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - return csv( - *sqls, - *[self.sql(sql) for sql in expression.args.get("joins") or []], - self.sql(expression, "match"), - *[self.sql(sql) for sql in expression.args.get("laterals") or []], - self.sql(expression, "where"), - self.sql(expression, "group"), - self.sql(expression, "having"), - self.sql(expression, "qualify"), - self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) - if expression.args.get("windows") - else "", - self.sql(expression, "distribute"), - self.sql(expression, "sort"), - self.sql(expression, "cluster"), - self.sql(expression, "order"), - self.sql(expression, "offset"), # offset before limit in oracle - self.sql(expression, "limit"), - self.sql(expression, "lock"), - sep="", - ) + LIMIT_FETCH = "FETCH" def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 5f556a5..31b7e45 100644 --- 
a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, arrow_json_extract_sql, format_time_lambda, + max_or_greatest, min_or_least, no_paren_current_date_sql, no_tablesample_sql, @@ -315,6 +316,7 @@ class Postgres(Dialect): exp.DateDiff: _date_diff_sql, exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 799e9a6..c50961c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -9,6 +9,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + max_or_greatest, min_or_least, rename_func, timestamptrunc_sql, @@ -275,6 +276,9 @@ class Snowflake(Dialect): exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.ArrayJoin: rename_func("ARRAY_TO_STRING"), exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this), + exp.DateDiff: lambda self, e: self.func( + "DATEDIFF", e.text("unit"), e.expression, e.this + ), exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), @@ -296,6 +300,7 @@ class Snowflake(Dialect): exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), exp.UnixToTime: _unix_to_time_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.Max: max_or_greatest, exp.Min: min_or_least, } @@ -314,12 +319,6 @@ class Snowflake(Dialect): exp.SetProperty: exp.Properties.Location.UNSUPPORTED, } - def ilikeany_sql(self, expression: exp.ILikeAny) -> str: - return self.binary(expression, "ILIKE ANY") - - def likeany_sql(self, expression: exp.LikeAny) -> str: - return self.binary(expression, "LIKE ANY") - def except_op(self, expression): if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index ab78b6e..4091dbb 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -82,6 +82,8 @@ class SQLite(Dialect): exp.TryCast: no_trycast_sql, } + LIMIT_FETCH = "LIMIT" + def cast_sql(self, expression: exp.Cast) -> str: if expression.to.this == exp.DataType.Type.DATE: return self.func("DATE", expression.this) @@ -115,9 +117,6 @@ class SQLite(Dialect): return f"CAST({sql} AS INTEGER)" - def fetch_sql(self, expression: exp.Fetch) -> str: - return self.limit_sql(exp.Limit(expression=expression.args.get("count"))) - # https://www.sqlite.org/lang_aggfunc.html#group_concat def groupconcat_sql(self, expression): this = expression.this diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 8bd0a0c..3d43793 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens -from sqlglot.dialects.dialect import Dialect, min_or_least +from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least from sqlglot.tokens import TokenType @@ -128,6 +128,7 @@ class Teradata(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Max: max_or_greatest, exp.Min: min_or_least, exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 7b52047..8e9b6c3 100644 --- 
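# Snowflake's DATEDIFF takes (unit, start, end), hence the operand swap in the
# new transform above. Built from the AST directly, with illustrative names:
from sqlglot import exp

exp.DateDiff(
    this=exp.column("end_d"),
    expression=exp.column("start_d"),
    unit=exp.Literal.string("DAY"),
).sql(dialect="snowflake")
# -> 'DATEDIFF(DAY, start_d, end_d)'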
a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -6,6 +6,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + max_or_greatest, min_or_least, parse_date_delta, rename_func, @@ -269,7 +270,6 @@ class TSQL(Dialect): # TSQL allows @, # to appear as a variable/identifier prefix SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("@") SINGLE_TOKENS.pop("#") class Parser(parser.Parser): @@ -313,6 +313,9 @@ class TSQL(Dialect): TokenType.END: lambda self: self._parse_command(), } + LOG_BASE_FIRST = False + LOG_DEFAULTS_TO_LN = True + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -435,11 +438,17 @@ class TSQL(Dialect): exp.NumberToStr: _format_sql, exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, + exp.Max: max_or_greatest, exp.Min: min_or_least, } TRANSFORMS.pop(exp.ReturnsProperty) + LIMIT_FETCH = "FETCH" + + def offset_sql(self, expression: exp.Offset) -> str: + return f"{super().offset_sql(expression)} ROWS" + def systemtime_sql(self, expression: exp.SystemTime) -> str: kind = expression.args["kind"] if kind == "ALL": diff --git a/sqlglot/diff.py b/sqlglot/diff.py index dddb9ad..86665e0 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from heapq import heappop, heappush from sqlglot import Dialect, expressions as exp -from sqlglot.helper import ensure_collection +from sqlglot.helper import ensure_list @dataclass(frozen=True) @@ -151,8 +151,8 @@ class ChangeDistiller: self._source = source self._target = target - self._source_index = {id(n[0]): n[0] for n in source.bfs()} - self._target_index = {id(n[0]): n[0] for n in target.bfs()} + self._source_index = {id(n): n for n, *_ in self._source.bfs()} + self._target_index = {id(n): n for n, *_ in self._target.bfs()} self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} @@ -199,10 +199,10 @@ class ChangeDistiller: matching_set = leaves_matching_set.copy() ordered_unmatched_source_nodes = { - id(n[0]): None for n in self._source.bfs() if id(n[0]) in self._unmatched_source_nodes + id(n): None for n, *_ in self._source.bfs() if id(n) in self._unmatched_source_nodes } ordered_unmatched_target_nodes = { - id(n[0]): None for n in self._target.bfs() if id(n[0]) in self._unmatched_target_nodes + id(n): None for n, *_ in self._target.bfs() if id(n) in self._unmatched_target_nodes } for source_node_id in ordered_unmatched_source_nodes: @@ -304,18 +304,18 @@ class ChangeDistiller: def _get_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: has_child_exprs = False - for a in expression.args.values(): - for node in ensure_collection(a): - if isinstance(node, exp.Expression): - has_child_exprs = True - yield from _get_leaves(node) + for _, node in expression.iter_expressions(): + has_child_exprs = True + yield from _get_leaves(node) if not has_child_exprs: yield expression def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: - if type(source) is type(target): + if type(source) is type(target) and ( + not isinstance(source, exp.Identifier) or type(source.parent) is type(target.parent) + ): if isinstance(source, exp.Join): return source.args.get("side") == target.args.get("side") @@ -331,7 +331,7 @@ 
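# The diff module now indexes nodes with `for n, *_ in bfs()`, matching the
# walker's (node, parent, key) tuples. Public usage is unchanged; a sketch:
from sqlglot import parse_one
from sqlglot.diff import diff

edits = diff(parse_one("SELECT a FROM t"), parse_one("SELECT a, b FROM t"))
# roughly: a list of Keep records plus an Insert for the new column `b`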
def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]: args: t.List[t.Union[exp.Expression, t.List]] = [] if expression: for a in expression.args.values(): - args.extend(ensure_collection(a)) + args.extend(ensure_list(a)) return [a for a in args if isinstance(a, exp.Expression)] diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index a676e7d..a67c155 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -57,7 +57,7 @@ def execute( for name, table in tables_.mapping.items() } - schema = ensure_schema(schema) + schema = ensure_schema(schema, dialect=read) if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args: raise ExecuteError("Tables must support the same table args as schema") diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index d417328..b71cc6a 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -94,13 +94,10 @@ class PythonExecutor: if source and isinstance(source, exp.Expression): source = source.name or source.alias - condition = self.generate(step.condition) - projections = self.generate_tuple(step.projections) - if source is None: context, table_iter = self.static() elif source in context: - if not projections and not condition: + if not step.projections and not step.condition: return self.context({step.name: context.tables[source]}) table_iter = context.table_iter(source) elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV): @@ -109,10 +106,12 @@ class PythonExecutor: else: context, table_iter = self.scan_table(step) - if projections: - sink = self.table(step.projections) - else: - sink = self.table(context.columns) + return self.context({step.name: self._project_and_filter(context, step, table_iter)}) + + def _project_and_filter(self, context, step, table_iter): + sink = self.table(step.projections if step.projections else context.columns) + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) for reader in table_iter: if len(sink) >= step.limit: @@ -126,7 +125,7 @@ class PythonExecutor: else: sink.append(reader.row) - return self.context({step.name: sink}) + return sink def static(self): return self.context({}), [RowReader(())] @@ -185,27 +184,16 @@ class PythonExecutor: if condition: source_context.filter(condition) - condition = self.generate(step.condition) - projections = self.generate_tuple(step.projections) - - if not condition and not projections: + if not step.condition and not step.projections: return source_context - sink = self.table(step.projections if projections else source_context.columns) - - for reader, ctx in source_context: - if condition and not ctx.eval(condition): - continue - - if projections: - sink.append(ctx.eval_tuple(projections)) - else: - sink.append(reader.row) - - if len(sink) >= step.limit: - break + sink = self._project_and_filter( + source_context, + step, + (reader for reader, _ in iter(source_context)), + ) - if projections: + if step.projections: return self.context({step.name: sink}) else: return self.context( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index b9da4cc..f4aae47 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -26,6 +26,7 @@ from sqlglot.helper import ( AutoName, camel_to_snake_case, ensure_collection, + ensure_list, seq_get, split_num_words, subclasses, @@ -84,7 +85,7 @@ class Expression(metaclass=_Expression): key = "expression" arg_types = {"this": True} 
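# The scan and aggregate paths of the Python executor now share one
# _project_and_filter loop; observable results should be unchanged. A quick
# end-to-end check with illustrative data:
from sqlglot.executor import execute

result = execute(
    "SELECT a FROM t WHERE a > 1",
    tables={"t": [{"a": 1}, {"a": 2}, {"a": 3}]},
)
# result.rows -> expected [(2,), (3,)]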
- __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta") + __slots__ = ("args", "parent", "arg_key", "comments", "_type", "_meta", "_hash") def __init__(self, **args: t.Any): self.args: t.Dict[str, t.Any] = args @@ -93,23 +94,31 @@ class Expression(metaclass=_Expression): self.comments: t.Optional[t.List[str]] = None self._type: t.Optional[DataType] = None self._meta: t.Optional[t.Dict[str, t.Any]] = None + self._hash: t.Optional[int] = None for arg_key, value in self.args.items(): self._set_parent(arg_key, value) def __eq__(self, other) -> bool: - return type(self) is type(other) and _norm_args(self) == _norm_args(other) + return type(self) is type(other) and hash(self) == hash(other) - def __hash__(self) -> int: - return hash( - ( - self.key, - tuple( - (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items() - ), - ) + @property + def hashable_args(self) -> t.Any: + args = (self.args.get(k) for k in self.arg_types) + + return tuple( + (tuple(_norm_arg(a) for a in arg) if arg else None) + if type(arg) is list + else (_norm_arg(arg) if arg is not None and arg is not False else None) + for arg in args ) + def __hash__(self) -> int: + if self._hash is not None: + return self._hash + + return hash((self.__class__, self.hashable_args)) + @property def this(self): """ @@ -247,9 +256,6 @@ class Expression(metaclass=_Expression): """ new = deepcopy(self) new.parent = self.parent - for item, parent, _ in new.bfs(): - if isinstance(item, Expression) and parent: - item.parent = parent return new def append(self, arg_key, value): @@ -277,12 +283,12 @@ class Expression(metaclass=_Expression): self._set_parent(arg_key, value) def _set_parent(self, arg_key, value): - if isinstance(value, Expression): + if hasattr(value, "parent"): value.parent = self value.arg_key = arg_key - elif isinstance(value, list): + elif type(value) is list: for v in value: - if isinstance(v, Expression): + if hasattr(v, "parent"): v.parent = self v.arg_key = arg_key @@ -295,6 +301,17 @@ class Expression(metaclass=_Expression): return self.parent.depth + 1 return 0 + def iter_expressions(self) -> t.Iterator[t.Tuple[str, Expression]]: + """Yields the key and expression for all arguments, exploding list args.""" + for k, vs in self.args.items(): + if type(vs) is list: + for v in vs: + if hasattr(v, "parent"): + yield k, v + else: + if hasattr(vs, "parent"): + yield k, vs + def find(self, *expression_types: t.Type[E], bfs=True) -> E | None: """ Returns the first node in this tree which matches at least one of @@ -319,7 +336,7 @@ class Expression(metaclass=_Expression): Returns: The generator object. """ - for expression, _, _ in self.walk(bfs=bfs): + for expression, *_ in self.walk(bfs=bfs): if isinstance(expression, expression_types): yield expression @@ -345,6 +362,11 @@ class Expression(metaclass=_Expression): """ return self.find_ancestor(Select) + @property + def same_parent(self): + """Returns if the parent is the same class as itself.""" + return type(self.parent) is self.__class__ + def root(self) -> Expression: """ Returns the root expression of this tree. 
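# Equality now rides on hash(self) over hashable_args, which lowercases
# unquoted identifiers but preserves case for quoted ones. An illustration:
from sqlglot import parse_one

assert parse_one("SELECT a + 1") == parse_one("select A + 1")
assert parse_one('SELECT "A"') != parse_one('SELECT "a"')  # quoted: case kept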
@@ -385,10 +407,8 @@ class Expression(metaclass=_Expression): if prune and prune(self, parent, key): return - for k, v in self.args.items(): - for node in ensure_collection(v): - if isinstance(node, Expression): - yield from node.dfs(self, k, prune) + for k, v in self.iter_expressions(): + yield from v.dfs(self, k, prune) def bfs(self, prune=None): """ @@ -407,18 +427,15 @@ class Expression(metaclass=_Expression): if prune and prune(item, parent, key): continue - if isinstance(item, Expression): - for k, v in item.args.items(): - for node in ensure_collection(v): - if isinstance(node, Expression): - queue.append((node, item, k)) + for k, v in item.iter_expressions(): + queue.append((v, item, k)) def unnest(self): """ Returns the first non parenthesis child or self. """ expression = self - while isinstance(expression, Paren): + while type(expression) is Paren: expression = expression.this return expression @@ -434,7 +451,7 @@ class Expression(metaclass=_Expression): """ Returns unnested operands as a tuple. """ - return tuple(arg.unnest() for arg in self.args.values() if arg) + return tuple(arg.unnest() for _, arg in self.iter_expressions()) def flatten(self, unnest=True): """ @@ -442,8 +459,8 @@ class Expression(metaclass=_Expression): A AND B AND C -> [A, B, C] """ - for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not isinstance(n, self.__class__)): - if not isinstance(node, self.__class__): + for node, _, _ in self.dfs(prune=lambda n, p, *_: p and not type(n) is self.__class__): + if not type(node) is self.__class__: yield node.unnest() if unnest else node def __str__(self): @@ -477,7 +494,7 @@ class Expression(metaclass=_Expression): v._to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "_to_s") else str(v) - for v in ensure_collection(vs) + for v in ensure_list(vs) if v is not None ) for k, vs in self.args.items() @@ -812,6 +829,10 @@ class Describe(Expression): arg_types = {"this": True, "kind": False} +class Pragma(Expression): + pass + + class Set(Expression): arg_types = {"expressions": False} @@ -1170,6 +1191,7 @@ class Drop(Expression): "temporary": False, "materialized": False, "cascade": False, + "constraints": False, } @@ -1232,11 +1254,11 @@ class Identifier(Expression): def quoted(self): return bool(self.args.get("quoted")) - def __eq__(self, other): - return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(other.this) - - def __hash__(self): - return hash((self.key, self.this.lower())) + @property + def hashable_args(self) -> t.Any: + if self.quoted and any(char.isupper() for char in self.this): + return (self.this, self.quoted) + return self.this.lower() @property def output_name(self): @@ -1322,15 +1344,9 @@ class Limit(Expression): class Literal(Condition): arg_types = {"this": True, "is_string": True} - def __eq__(self, other): - return ( - isinstance(other, Literal) - and self.this == other.this - and self.args["is_string"] == other.args["is_string"] - ) - - def __hash__(self): - return hash((self.key, self.this, self.args["is_string"])) + @property + def hashable_args(self) -> t.Any: + return (self.this, self.args.get("is_string")) @classmethod def number(cls, number) -> Literal: @@ -1784,7 +1800,7 @@ class Subqueryable(Unionable): instance = _maybe_copy(self, copy) return Subquery( this=instance, - alias=TableAlias(this=to_identifier(alias)), + alias=TableAlias(this=to_identifier(alias)) if alias else None, ) def limit(self, expression, dialect=None, copy=True, **opts) -> Select: @@ -2058,6 +2074,7 @@ class Lock(Expression): 
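# iter_expressions() is the new traversal primitive behind dfs/bfs: it yields
# (arg_key, child) pairs and explodes list-valued args. A quick look:
from sqlglot import parse_one

for key, child in parse_one("SELECT a, b FROM t").iter_expressions():
    print(key, type(child).__name__)
# roughly: expressions Column / expressions Column / from From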
class Select(Subqueryable): arg_types = { "with": False, + "kind": False, "expressions": False, "hint": False, "distinct": False, @@ -3595,6 +3612,21 @@ class Initcap(Func): pass +class JSONKeyValue(Expression): + arg_types = {"this": True, "expression": True} + + +class JSONObject(Func): + arg_types = { + "expressions": False, + "null_handling": False, + "unique_keys": False, + "return_type": False, + "format_json": False, + "encoding": False, + } + + class JSONBContains(Binary): _sql_names = ["JSONB_CONTAINS"] @@ -3766,8 +3798,10 @@ class RegexpILike(Func): arg_types = {"this": True, "expression": True, "flag": False} +# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html +# limit is the number of times a pattern is applied class RegexpSplit(Func): - arg_types = {"this": True, "expression": True} + arg_types = {"this": True, "expression": True, "limit": False} class Repeat(Func): @@ -3967,25 +4001,8 @@ class When(Func): arg_types = {"matched": True, "source": False, "condition": False, "then": True} -def _norm_args(expression): - args = {} - - for k, arg in expression.args.items(): - if isinstance(arg, list): - arg = [_norm_arg(a) for a in arg] - if not arg: - arg = None - else: - arg = _norm_arg(arg) - - if arg is not None and arg is not False: - args[k] = arg - - return args - - def _norm_arg(arg): - return arg.lower() if isinstance(arg, str) else arg + return arg.lower() if type(arg) is str else arg ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) @@ -4512,7 +4529,7 @@ def to_identifier(name, quoted=None): elif isinstance(name, str): identifier = Identifier( this=name, - quoted=not re.match(SAFE_IDENTIFIER_RE, name) if quoted is None else quoted, + quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted, ) else: raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}") @@ -4586,8 +4603,7 @@ def to_column(sql_path: str | Column, **kwargs) -> Column: return sql_path if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2)) - return Column(this=column_name, table=table_name, **kwargs) + return column(*reversed(sql_path.split(".")), **kwargs) # type: ignore def alias_( @@ -4672,7 +4688,8 @@ def subquery(expression, alias=None, dialect=None, **opts): def column( col: str | Identifier, table: t.Optional[str | Identifier] = None, - schema: t.Optional[str | Identifier] = None, + db: t.Optional[str | Identifier] = None, + catalog: t.Optional[str | Identifier] = None, quoted: t.Optional[bool] = None, ) -> Column: """ @@ -4681,7 +4698,8 @@ def column( Args: col: column name table: table name - schema: schema name + db: db name + catalog: catalog name quoted: whether or not to force quote each part Returns: Column: column instance @@ -4689,7 +4707,8 @@ def column( return Column( this=to_identifier(col, quoted=quoted), table=to_identifier(table, quoted=quoted), - schema=to_identifier(schema, quoted=quoted), + db=to_identifier(db, quoted=quoted), + catalog=to_identifier(catalog, quoted=quoted), ) @@ -4864,7 +4883,7 @@ def replace_children(expression, fun, *args, **kwargs): Replace children of an expression with the result of a lambda fun(child) -> exp. 
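# column() now mirrors Column's own arg names (db and catalog instead of
# schema), and to_column() delegates to it for dotted paths:
from sqlglot import exp

exp.column("c", table="t", db="d", catalog="cat").sql()
# -> 'cat.d.t.c'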
""" for k, v in expression.args.items(): - is_list_arg = isinstance(v, list) + is_list_arg = type(v) is list child_nodes = v if is_list_arg else [v] new_child_nodes = [] diff --git a/sqlglot/generator.py b/sqlglot/generator.py index a6f4772..6871dd8 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -110,6 +110,10 @@ class Generator: # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True + # Whether or not limit and fetch are supported + # "ALL", "LIMIT", "FETCH" + LIMIT_FETCH = "ALL" + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -209,6 +213,7 @@ class Generator: "_leading_comma", "_max_text_width", "_comments", + "_cache", ) def __init__( @@ -265,19 +270,28 @@ class Generator: self._leading_comma = leading_comma self._max_text_width = max_text_width self._comments = comments + self._cache = None - def generate(self, expression: t.Optional[exp.Expression]) -> str: + def generate( + self, + expression: t.Optional[exp.Expression], + cache: t.Optional[t.Dict[int, str]] = None, + ) -> str: """ Generates a SQL string by interpreting the given syntax tree. Args expression: the syntax tree. + cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node. Returns the SQL string. """ + if cache is not None: + self._cache = cache self.unsupported_messages = [] sql = self.sql(expression).strip() + self._cache = None if self.unsupported_level == ErrorLevel.IGNORE: return sql @@ -387,6 +401,12 @@ class Generator: if key: return self.sql(expression.args.get(key)) + if self._cache is not None: + expression_id = hash(expression) + + if expression_id in self._cache: + return self._cache[expression_id] + transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): @@ -407,7 +427,11 @@ class Generator: else: raise ValueError(f"Expected an Expression. 
Received {type(expression)}: {expression}") - return self.maybe_comment(sql, expression) if self._comments and comment else sql + sql = self.maybe_comment(sql, expression) if self._comments and comment else sql + + if self._cache is not None: + self._cache[expression_id] = sql + return sql def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") @@ -697,7 +721,8 @@ class Generator: temporary = " TEMPORARY" if expression.args.get("temporary") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" cascade = " CASCADE" if expression.args.get("cascade") else "" - return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" + constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" + return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}" def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( @@ -733,9 +758,9 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name - text = text.lower() if self.normalize else text + text = text.lower() if self.normalize and not expression.quoted else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.args.get("quoted") or should_identify(text, self.identify): + if expression.quoted or should_identify(text, self.identify): text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1191,6 +1216,9 @@ class Generator: ) return f"SET{expressions}" + def pragma_sql(self, expression: exp.Pragma) -> str: + return f"PRAGMA {self.sql(expression, 'this')}" + def lock_sql(self, expression: exp.Lock) -> str: if self.LOCKING_READS_SUPPORTED: lock_type = "UPDATE" if expression.args["update"] else "SHARE" @@ -1299,6 +1327,15 @@ class Generator: return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: + limit = expression.args.get("limit") + + if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): + limit = exp.Limit(expression=limit.args.get("count")) + elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): + limit = exp.Fetch(direction="FIRST", count=limit.expression) + + fetch = isinstance(limit, exp.Fetch) + return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("joins") or []], @@ -1315,14 +1352,16 @@ class Generator: self.sql(expression, "sort"), self.sql(expression, "cluster"), self.sql(expression, "order"), - self.sql(expression, "limit"), - self.sql(expression, "offset"), + self.sql(expression, "offset") if fetch else self.sql(limit), + self.sql(limit) if fetch else self.sql(expression, "offset"), self.sql(expression, "lock"), self.sql(expression, "sample"), sep="", ) def select_sql(self, expression: exp.Select) -> str: + kind = expression.args.get("kind") + kind = f" AS {kind}" if kind else "" hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" @@ -1330,7 +1369,7 @@ class Generator: expressions = f"{self.sep()}{expressions}" if expressions else expressions sql = self.query_modifiers( expression, - f"SELECT{hint}{distinct}{expressions}", + f"SELECT{kind}{hint}{distinct}{expressions}", self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) @@ -1552,6 +1591,25 @@ class Generator: exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) + def jsonkeyvalue_sql(self, expression: 
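# LIMIT/FETCH conversion now happens generically in query_modifiers, keyed on
# each dialect's LIMIT_FETCH setting. A hedged example going FETCH -> LIMIT:
import sqlglot

sqlglot.transpile(
    "SELECT a FROM t FETCH FIRST 10 ROWS ONLY", read="oracle", write="duckdb"
)
# -> expected to fold into a plain LIMIT, e.g. ['SELECT a FROM t LIMIT 10']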
exp.JSONKeyValue) -> str: + return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" + + def jsonobject_sql(self, expression: exp.JSONObject) -> str: + expressions = self.expressions(expression) + null_handling = expression.args.get("null_handling") + null_handling = f" {null_handling}" if null_handling else "" + unique_keys = expression.args.get("unique_keys") + if unique_keys is not None: + unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS" + else: + unique_keys = "" + return_type = self.sql(expression, "return_type") + return_type = f" RETURNING {return_type}" if return_type else "" + format_json = " FORMAT JSON" if expression.args.get("format_json") else "" + encoding = self.sql(expression, "encoding") + encoding = f" ENCODING {encoding}" if encoding else "" + return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") @@ -1808,12 +1866,18 @@ class Generator: def ilike_sql(self, expression: exp.ILike) -> str: return self.binary(expression, "ILIKE") + def ilikeany_sql(self, expression: exp.ILikeAny) -> str: + return self.binary(expression, "ILIKE ANY") + def is_sql(self, expression: exp.Is) -> str: return self.binary(expression, "IS") def like_sql(self, expression: exp.Like) -> str: return self.binary(expression, "LIKE") + def likeany_sql(self, expression: exp.LikeAny) -> str: + return self.binary(expression, "LIKE ANY") + def similarto_sql(self, expression: exp.SimilarTo) -> str: return self.binary(expression, "SIMILAR TO") diff --git a/sqlglot/helper.py b/sqlglot/helper.py index 6eff974..d44d7dd 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -59,7 +59,7 @@ def ensure_list(value): """ if value is None: return [] - elif isinstance(value, (list, tuple)): + if isinstance(value, (list, tuple)): return list(value) return [value] @@ -162,9 +162,7 @@ def camel_to_snake_case(name: str) -> str: return CAMEL_CASE_PATTERN.sub("_", name).upper() -def while_changing( - expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E] -) -> E: +def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: """ Applies a transformation to a given expression until a fix point is reached. @@ -176,8 +174,13 @@ def while_changing( The transformed expression. 
""" while True: + for n, *_ in reversed(tuple(expression.walk())): + n._hash = hash(n) start = hash(expression) expression = func(expression) + + for n, *_ in expression.walk(): + n._hash = None if start == hash(expression): break return expression diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index c2d6655..99888c6 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.helper import ensure_collection, ensure_list, subclasses +from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import ensure_schema @@ -108,6 +108,7 @@ class TypeAnnotator: exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), + exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), exp.GroupConcat: lambda self, expr: self._annotate_with_type( expr, exp.DataType.Type.VARCHAR @@ -116,6 +117,7 @@ class TypeAnnotator: expr, exp.DataType.Type.VARCHAR ), exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), @@ -296,9 +298,6 @@ class TypeAnnotator: return self._maybe_annotate(expression) # This takes care of non-traversable expressions def _maybe_annotate(self, expression): - if not isinstance(expression, exp.Expression): - return None - if expression.type: return expression # We've already inferred the expression's type @@ -311,9 +310,8 @@ class TypeAnnotator: ) def _annotate_args(self, expression): - for value in expression.args.values(): - for v in ensure_collection(value): - self._maybe_annotate(v) + for _, value in expression.iter_expressions(): + self._maybe_annotate(value) return expression diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index c5c780d..ef929ac 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -75,7 +75,7 @@ def _coerce_date(a: exp.Expression, b: exp.Expression) -> None: a.type and a.type.this == exp.DataType.Type.DATE and b.type - and b.type.this != exp.DataType.Type.DATE + and b.type.this not in (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL) ): _replace_cast(b, "date") diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 8e6a520..e0ddfa2 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -1,7 +1,6 @@ from sqlglot import expressions as exp from sqlglot.optimizer.normalize import normalized from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.optimizer.simplify import simplify def eliminate_joins(expression): @@ -179,6 +178,4 @@ def join_condition(join): for condition in conditions: extract_condition(condition) - on = simplify(on) - remaining_condition = None if on == exp.true() else on - return source_key, join_key, 
remaining_condition + return source_key, join_key, on diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 6f9db82..a39fe96 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -3,7 +3,6 @@ import itertools from sqlglot import expressions as exp from sqlglot.helper import find_new_name from sqlglot.optimizer.scope import build_scope -from sqlglot.optimizer.simplify import simplify def eliminate_subqueries(expression): @@ -31,7 +30,6 @@ def eliminate_subqueries(expression): eliminate_subqueries(expression.this) return expression - expression = simplify(expression) root = build_scope(expression) # Map of alias->Scope|Table diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py index 1cc76cf..fae1726 100644 --- a/sqlglot/optimizer/lower_identities.py +++ b/sqlglot/optimizer/lower_identities.py @@ -1,5 +1,4 @@ from sqlglot import exp -from sqlglot.helper import ensure_collection def lower_identities(expression): @@ -40,13 +39,10 @@ def lower_identities(expression): lower_identities(expression.right) traversed |= {"this", "expression"} - for k, v in expression.args.items(): + for k, v in expression.iter_expressions(): if k in traversed: continue - - for child in ensure_collection(v): - if isinstance(child, exp.Expression): - child.transform(_lower, copy=False) + v.transform(_lower, copy=False) return expression diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 70172f4..c3467b2 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -3,7 +3,6 @@ from collections import defaultdict from sqlglot import expressions as exp from sqlglot.helper import find_new_name from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.optimizer.simplify import simplify def merge_subqueries(expression, leave_tables_isolated=False): @@ -330,11 +329,11 @@ def _merge_where(outer_scope, inner_scope, from_or_join): if set(exp.column_table_names(where.this)) <= sources: from_or_join.on(where.this, copy=False) - from_or_join.set("on", simplify(from_or_join.args.get("on"))) + from_or_join.set("on", from_or_join.args.get("on")) return expression.where(where.this, copy=False) - expression.set("where", simplify(expression.args.get("where"))) + expression.set("where", expression.args.get("where")) def _merge_order(outer_scope, inner_scope): diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index f16f519..f2df230 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,29 +1,63 @@ +from __future__ import annotations + +import logging +import typing as t + from sqlglot import exp +from sqlglot.errors import OptimizeError from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort +from sqlglot.optimizer.simplify import flatten, uniq_sort + +logger = logging.getLogger("sqlglot") -def normalize(expression, dnf=False, max_distance=128): +def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): """ - Rewrite sqlglot AST into conjunctive normal form. + Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 
Example: >>> import sqlglot >>> expression = sqlglot.parse_one("(x AND y) OR z") - >>> normalize(expression).sql() + >>> normalize(expression, dnf=False).sql() '(x OR z) AND (y OR z)' Args: - expression (sqlglot.Expression): expression to normalize - dnf (bool): rewrite in disjunctive normal form instead - max_distance (int): the maximal estimated distance from cnf to attempt conversion + expression: expression to normalize + dnf: rewrite in disjunctive normal form instead. + max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion Returns: sqlglot.Expression: normalized expression """ - expression = simplify(expression) + cache: t.Dict[int, str] = {} + + for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): + if isinstance(node, exp.Connector): + if normalized(node, dnf=dnf): + continue + + distance = normalization_distance(node, dnf=dnf) + + if distance > max_distance: + logger.info( + f"Skipping normalization because distance {distance} exceeds max {max_distance}" + ) + return expression + + root = node is expression + original = node.copy() + try: + node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + except OptimizeError as e: + logger.info(e) + node.replace(original) + if root: + return original + return expression + + if root: + expression = node - expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) - return simplify(expression) + return expression def normalized(expression, dnf=False): @@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False): int: difference """ return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 + sum(1 for _ in expression.find_all(exp.Connector)) + 1 ) @@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf): expression = expression.unnest() if not isinstance(expression, exp.Connector): - return [1] + return (1,) left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - return [ + return tuple( a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) - ] + ) return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance): +def distributive_law(expression, dnf, max_distance, cache=None): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) """ - if isinstance(expression.unnest(), exp.Connector): - if normalization_distance(expression, dnf) > max_distance: - return expression + if normalized(expression, dnf=dnf): + return expression - to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + distance = normalization_distance(expression, dnf=dnf) - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + if distance > max_distance: + raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): a, b = expression.unnest_operands() @@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func) - return _distribute(b, a, from_func, 
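# normalize() now works connector-by-connector with a cache and, instead of
# forcing a result, logs and returns the original tree when the distance bound
# is exceeded. The docstring example above still holds:
from sqlglot import parse_one
from sqlglot.optimizer.normalize import normalize

normalize(parse_one("(x AND y) OR z"), dnf=False).sql()
# -> '(x OR z) AND (y OR z)'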
to_func) + return _distribute(a, b, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, cache) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func) + return _distribute(b, a, from_func, to_func, cache) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func) + return _distribute(a, b, from_func, to_func, cache) return expression -def _distribute(a, b, from_func, to_func): +def _distribute(a, b, from_func, to_func, cache): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - exp.paren(from_func(c, b.left)), - exp.paren(from_func(c, b.right)), + uniq_sort(flatten(from_func(c, b.left)), cache), + uniq_sort(flatten(from_func(c, b.right)), cache), ), ) else: - a = to_func(from_func(a, b.left), from_func(a, b.right)) - - return _simplify(a) - + a = to_func( + uniq_sort(flatten(from_func(a, b.left)), cache), + uniq_sort(flatten(from_func(a, b.right)), cache), + ) -def _simplify(node): - node = uniq_sort(flatten(node)) - exp.replace_children(node, _simplify) - return node + return a diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index dc5ce44..8589657 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,6 +1,5 @@ from sqlglot import exp from sqlglot.helper import tsort -from sqlglot.optimizer.simplify import simplify def optimize_joins(expression): @@ -29,7 +28,6 @@ def optimize_joins(expression): for name, join in cross_joins: for dep in references.get(name, []): on = dep.args["on"] - on = on.replace(simplify(on)) if isinstance(on, exp.Connector): for predicate in on.flatten(): diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index d9d04be..62eb11e 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -21,6 +21,7 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema @@ -43,6 +44,7 @@ RULES = ( eliminate_ctes, annotate_types, canonicalize, + simplify, ) @@ -78,7 +80,7 @@ def optimize( Returns: sqlglot.Expression: optimized expression """ - schema = ensure_schema(schema or sqlglot.schema) + schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 66b3170..5e40cf3 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -30,11 +30,12 @@ def qualify_columns(expression, schema): resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) - _expand_using(scope, resolver) + using_column_tables = _expand_using(scope, resolver) _qualify_columns(scope, resolver) if not isinstance(scope.expression, exp.UDTF): - _expand_stars(scope, resolver) + _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) + _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) @@ -69,11 +70,11 @@ def 
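# _expand_using rewrites USING into explicit ON equalities and now also returns
# the column -> source-tables mapping for star expansion below. A sketch with
# an illustrative schema:
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

qualify_columns(
    parse_one("SELECT x.a FROM x JOIN y USING (id)"),
    schema={"x": {"id": "int", "a": "int"}, "y": {"id": "int"}},
).sql()
# -> the join is expected to come back as ... JOIN y ON x.id = y.id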
_pop_table_column_aliases(derived_tables): def _expand_using(scope, resolver): - joins = list(scope.expression.find_all(exp.Join)) + joins = list(scope.find_all(exp.Join)) names = {join.this.alias for join in joins} ordered = [key for key in scope.selected_sources if key not in names] - # Mapping of automatically joined column names to source names + # Mapping of automatically joined column names to an ordered set of source names (dict). column_tables = {} for join in joins: @@ -112,11 +113,12 @@ def _expand_using(scope, resolver): ) ) - tables = column_tables.setdefault(identifier, []) + # Set all values in the dict to None, because we only care about the key ordering + tables = column_tables.setdefault(identifier, {}) if table not in tables: - tables.append(table) + tables[table] = None if join_table not in tables: - tables.append(join_table) + tables[join_table] = None join.args.pop("using") join.set("on", exp.and_(*conditions)) @@ -134,11 +136,11 @@ def _expand_using(scope, resolver): scope.replace(column, replacement) + return column_tables -def _expand_group_by(scope, resolver): - group = scope.expression.args.get("group") - if not group: - return + +def _expand_alias_refs(scope, resolver): + selects = {} # Replace references to select aliases def transform(node, *_): @@ -150,9 +152,11 @@ def _expand_group_by(scope, resolver): node.set("table", table) return node - selects = {s.alias_or_name: s for s in scope.selects} - + if not selects: + for s in scope.selects: + selects[s.alias_or_name] = s select = selects.get(node.name) + if select: scope.clear_cache() if isinstance(select, exp.Alias): @@ -161,7 +165,21 @@ def _expand_group_by(scope, resolver): return node - group.transform(transform, copy=False) + for select in scope.expression.selects: + select.transform(transform, copy=False) + + for modifier in ("where", "group"): + part = scope.expression.args.get(modifier) + + if part: + part.transform(transform, copy=False) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + group.set("expressions", _expand_positional_references(scope, group.expressions)) scope.expression.set("group", group) @@ -231,18 +249,24 @@ def _qualify_columns(scope, resolver): column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) columns_missing_from_scope = [] + # Determine whether each reference in the order by clause is to a column or an alias. - for ordered in scope.find_all(exp.Ordered): - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) + order = scope.expression.args.get("order") + + if order: + for ordered in order.expressions: + for column in ordered.find_all(exp.Column): + if ( + not column.table + and column.parent is not ordered + and column.name in resolver.all_columns + ): + columns_missing_from_scope.append(column) # Determine whether each reference in the having clause is to a column or an alias. 
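# _expand_alias_refs generalizes the old group-by-only alias substitution to
# the select list, WHERE and GROUP BY. A hedged sketch of the rewrite:
from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

qualify_columns(
    parse_one("SELECT a + 1 AS b FROM x WHERE b > 0"),
    schema={"x": {"a": "int"}},
).sql()
# -> roughly 'SELECT x.a + 1 AS b FROM x WHERE x.a + 1 > 0'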
- for having in scope.find_all(exp.Having): + having = scope.expression.args.get("having") + + if having: for column in having.find_all(exp.Column): if ( not column.table @@ -258,12 +282,13 @@ def _qualify_columns(scope, resolver): column.set("table", column_table) -def _expand_stars(scope, resolver): +def _expand_stars(scope, resolver, using_column_tables): """Expand stars to lists of column selections""" new_selections = [] except_columns = {} replace_columns = {} + coalesced_columns = set() for expression in scope.selects: if isinstance(expression, exp.Star): @@ -286,7 +311,20 @@ def _expand_stars(scope, resolver): if columns and "*" not in columns: table_id = id(table) for name in columns: - if name not in except_columns.get(table_id, set()): + if name in using_column_tables and table in using_column_tables[name]: + if name in coalesced_columns: + continue + + coalesced_columns.add(name) + tables = using_column_tables[name] + coalesce = [exp.column(name, table=table) for table in tables] + + new_selections.append( + exp.alias_( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + ) + ) + elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table) new_selections.append(alias(column, alias_) if alias_ != name else column) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 9c0768c..b582eb0 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -160,7 +160,7 @@ class Scope: Yields: exp.Expression: nodes """ - for expression, _, _ in self.walk(bfs=bfs): + for expression, *_ in self.walk(bfs=bfs): if isinstance(expression, expression_types): yield expression diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index f80484d..1ed3ca2 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,11 +5,10 @@ from collections import deque from decimal import Decimal from sqlglot import exp -from sqlglot.expressions import FALSE, NULL, TRUE from sqlglot.generator import Generator from sqlglot.helper import first, while_changing -GENERATOR = Generator(normalize=True, identify=True) +GENERATOR = Generator(normalize=True, identify="safe") def simplify(expression): @@ -28,18 +27,20 @@ def simplify(expression): sqlglot.Expression: simplified expression """ + cache = {} + def _simplify(expression, root=True): node = expression node = rewrite_between(node) - node = uniq_sort(node) - node = absorb_and_eliminate(node) + node = uniq_sort(node, cache, root) + node = absorb_and_eliminate(node, root) exp.replace_children(node, lambda e: _simplify(e, False)) node = simplify_not(node) node = flatten(node) - node = simplify_connectors(node) - node = remove_compliments(node) + node = simplify_connectors(node, root) + node = remove_compliments(node, root) node.parent = expression.parent - node = simplify_literals(node) + node = simplify_literals(node, root) node = simplify_parens(node) if root: expression.replace(node) @@ -70,7 +71,7 @@ def simplify_not(expression): NOT (x AND y) -> NOT x OR NOT y """ if isinstance(expression, exp.Not): - if isinstance(expression.this, exp.Null): + if is_null(expression.this): return exp.null() if isinstance(expression.this, exp.Paren): condition = expression.this.unnest() @@ -78,11 +79,11 @@ def simplify_not(expression): return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) if isinstance(condition, exp.Or): return exp.and_(exp.not_(condition.left), 
exp.not_(condition.right)) - if isinstance(condition, exp.Null): + if is_null(condition): return exp.null() if always_true(expression.this): return exp.false() - if expression.this == FALSE: + if is_false(expression.this): return exp.true() if isinstance(expression.this, exp.Not): # double negation @@ -104,42 +105,42 @@ def flatten(expression): return expression -def simplify_connectors(expression): +def simplify_connectors(expression, root=True): def _simplify_connectors(expression, left, right): - if isinstance(expression, exp.Connector): - if left == right: + if left == right: + return left + if isinstance(expression, exp.And): + if is_false(left) or is_false(right): + return exp.false() + if is_null(left) or is_null(right): + return exp.null() + if always_true(left) and always_true(right): + return exp.true() + if always_true(left): + return right + if always_true(right): return left - if isinstance(expression, exp.And): - if FALSE in (left, right): - return exp.false() - if NULL in (left, right): - return exp.null() - if always_true(left) and always_true(right): - return exp.true() - if always_true(left): - return right - if always_true(right): - return left - return _simplify_comparison(expression, left, right) - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return exp.true() - if left == FALSE and right == FALSE: - return exp.false() - if ( - (left == NULL and right == NULL) - or (left == NULL and right == FALSE) - or (left == FALSE and right == NULL) - ): - return exp.null() - if left == FALSE: - return right - if right == FALSE: - return left - return _simplify_comparison(expression, left, right, or_=True) - return None + return _simplify_comparison(expression, left, right) + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return exp.true() + if is_false(left) and is_false(right): + return exp.false() + if ( + (is_null(left) and is_null(right)) + or (is_null(left) and is_false(right)) + or (is_false(left) and is_null(right)) + ): + return exp.null() + if is_false(left): + return right + if is_false(right): + return left + return _simplify_comparison(expression, left, right, or_=True) - return _flat_simplify(expression, _simplify_connectors) + if isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_connectors, root) + return expression LT_LTE = (exp.LT, exp.LTE) @@ -220,14 +221,14 @@ def _simplify_comparison(expression, left, right, or_=False): return None -def remove_compliments(expression): +def remove_compliments(expression, root=True): """ Removing compliments. A AND NOT A -> FALSE A OR NOT A -> TRUE """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): compliment = exp.false() if isinstance(expression, exp.And) else exp.true() for a, b in itertools.permutations(expression.flatten(), 2): @@ -236,23 +237,23 @@ def remove_compliments(expression): return expression -def uniq_sort(expression): +def uniq_sort(expression, cache=None, root=True): """ Uniq and sort a connector. 
C AND A AND B AND B -> A AND B AND C """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {GENERATOR.generate(e): e for e in flattened} + deduped = {GENERATOR.generate(e, cache): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them # A AND C AND B -> A AND B AND C for i, (sql, e) in enumerate(arr[1:]): if sql < arr[i][0]: - expression = result_func(*(deduped[sql] for sql in sorted(deduped))) + expression = result_func(*(e for _, e in sorted(arr))) break else: # we didn't have to sort but maybe we need to dedup @@ -262,7 +263,7 @@ def uniq_sort(expression): return expression -def absorb_and_eliminate(expression): +def absorb_and_eliminate(expression, root=True): """ absorption: A AND (A OR B) -> A @@ -273,7 +274,7 @@ def absorb_and_eliminate(expression): (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A """ - if isinstance(expression, exp.Connector): + if isinstance(expression, exp.Connector) and (root or not expression.same_parent): kind = exp.Or if isinstance(expression, exp.And) else exp.And for a, b in itertools.permutations(expression.flatten(), 2): @@ -302,9 +303,9 @@ def absorb_and_eliminate(expression): return expression -def simplify_literals(expression): - if isinstance(expression, exp.Binary): - return _flat_simplify(expression, _simplify_binary) +def simplify_literals(expression, root=True): + if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): + return _flat_simplify(expression, _simplify_binary, root) elif isinstance(expression, exp.Neg): this = expression.this if this.is_number: @@ -325,14 +326,14 @@ def _simplify_binary(expression, a, b): c = b not_ = False - if c == NULL: + if is_null(c): if isinstance(a, exp.Literal): return exp.true() if not_ else exp.false() - if a == NULL: + if is_null(a): return exp.false() if not_ else exp.true() elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)): return None - elif NULL in (a, b): + elif is_null(a) or is_null(b): return exp.null() if a.is_number and b.is_number: @@ -355,7 +356,7 @@ def _simplify_binary(expression, a, b): if boolean: return boolean elif a.is_string and b.is_string: - boolean = eval_boolean(expression, a, b) + boolean = eval_boolean(expression, a.this, b.this) if boolean: return boolean @@ -381,7 +382,7 @@ def simplify_parens(expression): and not isinstance(expression.this, exp.Select) and ( not isinstance(expression.parent, (exp.Condition, exp.Binary)) - or isinstance(expression.this, (exp.Is, exp.Like)) + or isinstance(expression.this, exp.Predicate) or not isinstance(expression.this, exp.Binary) ) ): @@ -400,13 +401,23 @@ def remove_where_true(expression): def always_true(expression): - return expression == TRUE or isinstance(expression, exp.Literal) + return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( + expression, exp.Literal + ) def is_complement(a, b): return isinstance(b, exp.Not) and b.this == a +def is_false(a: exp.Expression) -> bool: + return type(a) is exp.Boolean and not a.this + + +def is_null(a: exp.Expression) -> bool: + return type(a) is exp.Null + + def eval_boolean(expression, a, b): if isinstance(expression, (exp.EQ, exp.Is)): return boolean_literal(a == b) @@ -466,24 +477,27 @@ def boolean_literal(condition): return exp.true() if condition else 
exp.false() -def _flat_simplify(expression, simplifier): - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) +def _flat_simplify(expression, simplifier, root=True): + if root or not expression.same_parent: + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) - while queue: - a = queue.popleft() + while queue: + a = queue.popleft() - for b in queue: - result = simplifier(expression, a, b) + for b in queue: + result = simplifier(expression, a, b) - if result: - queue.remove(b) - queue.append(result) - break - else: - operands.append(a) + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) - if len(operands) < size: - return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands) + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index a36251e..8269525 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -19,7 +19,7 @@ from sqlglot.trie import in_trie, new_trie logger = logging.getLogger("sqlglot") -def parse_var_map(args): +def parse_var_map(args: t.Sequence) -> exp.Expression: keys = [] values = [] for i in range(0, len(args), 2): @@ -31,6 +31,11 @@ def parse_var_map(args): ) +def parse_like(args): + like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) + return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like + + def binary_range_parser( expr_type: t.Type[exp.Expression], ) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]: @@ -77,6 +82,9 @@ class Parser(metaclass=_Parser): this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), ), + "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), + "IFNULL": exp.Coalesce.from_arg_list, + "LIKE": parse_like, "TIME_TO_TIME_STR": lambda args: exp.Cast( this=seq_get(args, 0), to=exp.DataType(this=exp.DataType.Type.TEXT), @@ -90,7 +98,6 @@ class Parser(metaclass=_Parser): length=exp.Literal.number(10), ), "VAR_MAP": parse_var_map, - "IFNULL": exp.Coalesce.from_arg_list, } NO_PAREN_FUNCTIONS = { @@ -211,6 +218,7 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FOLLOWING, TokenType.FORMAT, + TokenType.FULL, TokenType.IF, TokenType.ISNULL, TokenType.INTERVAL, @@ -226,8 +234,10 @@ class Parser(metaclass=_Parser): TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, + TokenType.PARTITION, TokenType.PERCENT, TokenType.PIVOT, + TokenType.PRAGMA, TokenType.PRECEDING, TokenType.RANGE, TokenType.REFERENCES, @@ -257,6 +267,7 @@ class Parser(metaclass=_Parser): TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { TokenType.APPLY, + TokenType.FULL, TokenType.LEFT, TokenType.NATURAL, TokenType.OFFSET, @@ -277,6 +288,7 @@ class Parser(metaclass=_Parser): TokenType.FILTER, TokenType.FIRST, TokenType.FORMAT, + TokenType.GLOB, TokenType.IDENTIFIER, TokenType.INDEX, TokenType.ISNULL, @@ -461,6 +473,7 @@ class Parser(metaclass=_Parser): TokenType.INSERT: lambda self: self._parse_insert(), TokenType.LOAD_DATA: lambda self: self._parse_load_data(), TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()), TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.SET: lambda self: self._parse_set(), TokenType.UNCACHE: lambda self: self._parse_uncache(), @@ -662,6 +675,8 @@ class 
Parser(metaclass=_Parser): "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), "EXTRACT": lambda self: self._parse_extract(), + "JSON_OBJECT": lambda self: self._parse_json_object(), + "LOG": lambda self: self._parse_logarithm(), "POSITION": lambda self: self._parse_position(), "STRING_AGG": lambda self: self._parse_string_agg(), "SUBSTRING": lambda self: self._parse_substring(), @@ -719,6 +734,9 @@ class Parser(metaclass=_Parser): CONVERT_TYPE_FIRST = False + LOG_BASE_FIRST = True + LOG_DEFAULTS_TO_LN = False + __slots__ = ( "error_level", "error_message_context", @@ -1032,6 +1050,7 @@ class Parser(metaclass=_Parser): temporary=temporary, materialized=materialized, cascade=self._match(TokenType.CASCADE), + constraints=self._match_text_seq("CONSTRAINTS"), ) def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: @@ -1221,7 +1240,7 @@ class Parser(metaclass=_Parser): if not identified_property: break - for p in ensure_collection(identified_property): + for p in ensure_list(identified_property): properties.append(p) if properties: @@ -1704,6 +1723,11 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.SELECT): comments = self._prev_comments + kind = ( + self._match(TokenType.ALIAS) + and self._match_texts(("STRUCT", "VALUE")) + and self._prev.text + ) hint = self._parse_hint() all_ = self._match(TokenType.ALL) distinct = self._match(TokenType.DISTINCT) @@ -1722,6 +1746,7 @@ class Parser(metaclass=_Parser): this = self.expression( exp.Select, + kind=kind, hint=hint, distinct=distinct, expressions=expressions, @@ -2785,7 +2810,6 @@ class Parser(metaclass=_Parser): this = seq_get(expressions, 0) self._parse_query_modifiers(this) - self._match_r_paren() if isinstance(this, exp.Subqueryable): this = self._parse_set_operations( @@ -2794,7 +2818,9 @@ class Parser(metaclass=_Parser): elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: - this = self.expression(exp.Paren, this=this) + this = self.expression(exp.Paren, this=self._parse_set_operations(this)) + + self._match_r_paren() if this and comments: this.comments = comments @@ -3318,6 +3344,60 @@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_json_key_value(self) -> t.Optional[exp.Expression]: + self._match_text_seq("KEY") + key = self._parse_field() + self._match(TokenType.COLON) + self._match_text_seq("VALUE") + value = self._parse_field() + if not key and not value: + return None + return self.expression(exp.JSONKeyValue, this=key, expression=value) + + def _parse_json_object(self) -> exp.Expression: + expressions = self._parse_csv(self._parse_json_key_value) + + null_handling = None + if self._match_text_seq("NULL", "ON", "NULL"): + null_handling = "NULL ON NULL" + elif self._match_text_seq("ABSENT", "ON", "NULL"): + null_handling = "ABSENT ON NULL" + + unique_keys = None + if self._match_text_seq("WITH", "UNIQUE"): + unique_keys = True + elif self._match_text_seq("WITHOUT", "UNIQUE"): + unique_keys = False + + self._match_text_seq("KEYS") + + return_type = self._match_text_seq("RETURNING") and self._parse_type() + format_json = self._match_text_seq("FORMAT", "JSON") + encoding = self._match_text_seq("ENCODING") and self._parse_var() + + return self.expression( + exp.JSONObject, + expressions=expressions, + null_handling=null_handling, + unique_keys=unique_keys, + return_type=return_type, + format_json=format_json, + encoding=encoding, + ) + 
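# --- editorial sketch, not part of the patch: the expected behavior of the
# new JSON_OBJECT parser above, assuming the patched sqlglot 11.4.5 is
# installed. Each KEY ... VALUE pair becomes an exp.JSONKeyValue node; the
# trailing clauses land in dedicated args rather than in the expression list.
#
#     import sqlglot
#     from sqlglot import exp
#
#     ast = sqlglot.parse_one(
#         "SELECT JSON_OBJECT(KEY 'id' VALUE 1, 'name' VALUE 'a' ABSENT ON NULL WITH UNIQUE KEYS)"
#     )
#     obj = ast.find(exp.JSONObject)
#
#     assert all(isinstance(kv, exp.JSONKeyValue) for kv in obj.expressions)
#     assert obj.args["null_handling"] == "ABSENT ON NULL"
#     assert obj.args["unique_keys"] is True
# --- end sketch ---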
+ def _parse_logarithm(self) -> exp.Expression: + # Default argument order is base, expression + args = self._parse_csv(self._parse_range) + + if len(args) > 1: + if not self.LOG_BASE_FIRST: + args.reverse() + return exp.Log.from_arg_list(args) + + return self.expression( + exp.Ln if self.LOG_DEFAULTS_TO_LN else exp.Log, this=seq_get(args, 0) + ) + def _parse_position(self, haystack_first: bool = False) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) @@ -3654,7 +3734,7 @@ class Parser(metaclass=_Parser): return parse_result def _parse_select_or_expression(self) -> t.Optional[exp.Expression]: - return self._parse_select() or self._parse_expression() + return self._parse_select() or self._parse_set_operations(self._parse_expression()) def _parse_ddl_select(self) -> t.Optional[exp.Expression]: return self._parse_set_operations( @@ -3741,6 +3821,8 @@ class Parser(metaclass=_Parser): expression = self._parse_foreign_key() elif kind == TokenType.PRIMARY_KEY or self._match(TokenType.PRIMARY_KEY): expression = self._parse_primary_key() + else: + expression = None return self.expression(exp.AddConstraint, this=this, expression=expression) @@ -3799,12 +3881,15 @@ class Parser(metaclass=_Parser): parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None if parser: - return self.expression( - exp.AlterTable, - this=this, - exists=exists, - actions=ensure_list(parser(self)), - ) + actions = ensure_list(parser(self)) + + if not self._curr: + return self.expression( + exp.AlterTable, + this=this, + exists=exists, + actions=actions, + ) return self._parse_as_command(start) def _parse_merge(self) -> exp.Expression: diff --git a/sqlglot/planner.py b/sqlglot/planner.py index 40df39f..5fd96ef 100644 --- a/sqlglot/planner.py +++ b/sqlglot/planner.py @@ -175,7 +175,7 @@ class Step: } for projection in projections: for i, e in aggregate.group.items(): - for child, _, _ in projection.walk(): + for child, *_ in projection.walk(): if child == e: child.replace(exp.column(i, step.name)) aggregate.add_dependency(step) diff --git a/sqlglot/schema.py b/sqlglot/schema.py index f5d9f2b..8e39c7f 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -306,11 +306,11 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return self._type_mapping_cache[schema_type] -def ensure_schema(schema: t.Any) -> Schema: +def ensure_schema(schema: t.Any, dialect: DialectType = None) -> Schema: if isinstance(schema, Schema): return schema - return MappingSchema(schema) + return MappingSchema(schema, dialect=dialect) def ensure_column_mapping(mapping: t.Optional[ColumnMapping]): diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index eb3c08f..e5b44e7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -252,6 +252,7 @@ class TokenType(AutoName): PERCENT = auto() PIVOT = auto() PLACEHOLDER = auto() + PRAGMA = auto() PRECEDING = auto() PRIMARY_KEY = auto() PROCEDURE = auto() @@ -346,7 +347,8 @@ class Token: self.token_type = token_type self.text = text self.line = line - self.col = max(col - len(text), 1) + self.col = col - len(text) + self.col = self.col if self.col > 1 else 1 self.comments = comments def __repr__(self) -> str: @@ -586,6 +588,7 @@ class Tokenizer(metaclass=_Tokenizer): "PARTITIONED_BY": TokenType.PARTITION_BY, "PERCENT": TokenType.PERCENT, "PIVOT": TokenType.PIVOT, + "PRAGMA": TokenType.PRAGMA, "PRECEDING": TokenType.PRECEDING, "PRIMARY KEY": TokenType.PRIMARY_KEY, "PROCEDURE": TokenType.PROCEDURE, @@ -654,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer): 
"LONG": TokenType.BIGINT, "BIGINT": TokenType.BIGINT, "INT8": TokenType.BIGINT, + "DEC": TokenType.DECIMAL, "DECIMAL": TokenType.DECIMAL, "MAP": TokenType.MAP, "NULLABLE": TokenType.NULLABLE, @@ -714,7 +718,7 @@ class Tokenizer(metaclass=_Tokenizer): "VACUUM": TokenType.COMMAND, } - WHITE_SPACE: t.Dict[str, TokenType] = { + WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = { " ": TokenType.SPACE, "\t": TokenType.SPACE, "\n": TokenType.BREAK, @@ -813,11 +817,8 @@ class Tokenizer(metaclass=_Tokenizer): return self.sql[start:end] return "" - def _line_break(self, char: t.Optional[str]) -> bool: - return self.WHITE_SPACE.get(char) == TokenType.BREAK # type: ignore - def _advance(self, i: int = 1) -> None: - if self._line_break(self._char): + if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: self._set_new_line() self._col += i @@ -939,7 +940,7 @@ class Tokenizer(metaclass=_Tokenizer): self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) # type: ignore self._advance(comment_end_size - 1) else: - while not self._end and not self._line_break(self._peek): + while not self._end and not self.WHITE_SPACE.get(self._peek) is TokenType.BREAK: self._advance() self._comments.append(self._text[comment_start_size:]) # type: ignore -- cgit v1.2.3