From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 3 Jun 2023 01:59:40 +0200 Subject: Merging upstream version 15.0.0. Signed-off-by: Daniel Baumann --- sqlglot/optimizer/canonicalize.py | 10 +- sqlglot/optimizer/eliminate_ctes.py | 39 +++-- sqlglot/optimizer/eliminate_subqueries.py | 19 +- sqlglot/optimizer/expand_laterals.py | 34 ---- sqlglot/optimizer/expand_multi_table_selects.py | 24 --- sqlglot/optimizer/isolate_table_selects.py | 2 +- sqlglot/optimizer/lower_identities.py | 88 ---------- sqlglot/optimizer/merge_subqueries.py | 17 +- sqlglot/optimizer/normalize.py | 35 ++-- sqlglot/optimizer/normalize_identifiers.py | 36 ++++ sqlglot/optimizer/optimize_joins.py | 7 +- sqlglot/optimizer/optimizer.py | 37 ++-- sqlglot/optimizer/pushdown_predicates.py | 42 ++--- sqlglot/optimizer/pushdown_projections.py | 9 +- sqlglot/optimizer/qualify.py | 80 +++++++++ sqlglot/optimizer/qualify_columns.py | 221 ++++++++++++++---------- sqlglot/optimizer/qualify_tables.py | 45 +++-- sqlglot/optimizer/scope.py | 43 ++++- sqlglot/optimizer/simplify.py | 32 ++-- sqlglot/optimizer/unnest_subqueries.py | 23 +-- 20 files changed, 451 insertions(+), 392 deletions(-) delete mode 100644 sqlglot/optimizer/expand_laterals.py delete mode 100644 sqlglot/optimizer/expand_multi_table_selects.py delete mode 100644 sqlglot/optimizer/lower_identities.py create mode 100644 sqlglot/optimizer/normalize_identifiers.py create mode 100644 sqlglot/optimizer/qualify.py (limited to 'sqlglot/optimizer') diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index ef929ac..da2fce8 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -3,10 +3,9 @@ from __future__ import annotations import itertools from sqlglot import exp -from sqlglot.helper import should_identify -def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expression: +def canonicalize(expression: exp.Expression) -> exp.Expression: """Converts a sql expression into a standard form. This method relies on annotate_types because many of the @@ -14,19 +13,14 @@ def canonicalize(expression: exp.Expression, identify: str = "safe") -> exp.Expr Args: expression: The expression to canonicalize. - identify: Whether or not to force identify identifier. """ - exp.replace_children(expression, canonicalize, identify=identify) + exp.replace_children(expression, canonicalize) expression = add_text_to_concat(expression) expression = coerce_type(expression) expression = remove_redundant_casts(expression) expression = ensure_bool_predicates(expression) - if isinstance(expression, exp.Identifier): - if should_identify(expression.this, identify): - expression.set("quoted", True) - return expression diff --git a/sqlglot/optimizer/eliminate_ctes.py b/sqlglot/optimizer/eliminate_ctes.py index 7b862c6..6f1865c 100644 --- a/sqlglot/optimizer/eliminate_ctes.py +++ b/sqlglot/optimizer/eliminate_ctes.py @@ -19,24 +19,25 @@ def eliminate_ctes(expression): """ root = build_scope(expression) - ref_count = root.ref_count() - - # Traverse the scope tree in reverse so we can remove chains of unused CTEs - for scope in reversed(list(root.traverse())): - if scope.is_cte: - count = ref_count[id(scope)] - if count <= 0: - cte_node = scope.expression.parent - with_node = cte_node.parent - cte_node.pop() - - # Pop the entire WITH clause if this is the last CTE - if len(with_node.expressions) <= 0: - with_node.pop() - - # Decrement the ref count for all sources this CTE selects from - for _, source in scope.selected_sources.values(): - if isinstance(source, Scope): - ref_count[id(source)] -= 1 + if root: + ref_count = root.ref_count() + + # Traverse the scope tree in reverse so we can remove chains of unused CTEs + for scope in reversed(list(root.traverse())): + if scope.is_cte: + count = ref_count[id(scope)] + if count <= 0: + cte_node = scope.expression.parent + with_node = cte_node.parent + cte_node.pop() + + # Pop the entire WITH clause if this is the last CTE + if len(with_node.expressions) <= 0: + with_node.pop() + + # Decrement the ref count for all sources this CTE selects from + for _, source in scope.selected_sources.values(): + if isinstance(source, Scope): + ref_count[id(source)] -= 1 return expression diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index a39fe96..84f50e9 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -16,9 +16,9 @@ def eliminate_subqueries(expression): 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' This also deduplicates common subqueries: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y JOIN (SELECT * FROM x) AS z") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") >>> eliminate_subqueries(expression).sql() - 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y JOIN y AS z' + 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' Args: expression (sqlglot.Expression): expression @@ -32,6 +32,9 @@ def eliminate_subqueries(expression): root = build_scope(expression) + if not root: + return expression + # Map of alias->Scope|Table # These are all aliases that are already used in the expression. # We don't want to create new CTEs that conflict with these names. @@ -112,7 +115,7 @@ def _eliminate_union(scope, existing_ctes, taken): # Try to maintain the selections expressions = scope.selects selects = [ - exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name) + exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) for e in expressions if e.alias_or_name ] @@ -120,7 +123,9 @@ def _eliminate_union(scope, existing_ctes, taken): if len(selects) != len(expressions): selects = ["*"] - scope.expression.replace(exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias))) + scope.expression.replace( + exp.select(*selects).from_(exp.alias_(exp.table_(alias), alias=alias, copy=False)) + ) if not duplicate_cte_alias: existing_ctes[scope.expression] = alias @@ -131,6 +136,10 @@ def _eliminate_union(scope, existing_ctes, taken): def _eliminate_derived_table(scope, existing_ctes, taken): + # This ensures we don't drop the "pivot" arg from a pivoted subquery + if scope.parent.pivots: + return None + parent = scope.expression.parent name, cte = _new_cte(scope, existing_ctes, taken) @@ -153,7 +162,7 @@ def _eliminate_cte(scope, existing_ctes, taken): for child_scope in scope.parent.traverse(): for table, source in child_scope.selected_sources.values(): if source is scope: - new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name) + new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False) table.replace(new_table) return cte diff --git a/sqlglot/optimizer/expand_laterals.py b/sqlglot/optimizer/expand_laterals.py deleted file mode 100644 index 5b2f706..0000000 --- a/sqlglot/optimizer/expand_laterals.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp - - -def expand_laterals(expression: exp.Expression) -> exp.Expression: - """ - Expand lateral column alias references. - - This assumes `qualify_columns` as already run. - - Example: - >>> import sqlglot - >>> sql = "SELECT x.a + 1 AS b, b + 1 AS c FROM x" - >>> expression = sqlglot.parse_one(sql) - >>> expand_laterals(expression).sql() - 'SELECT x.a + 1 AS b, x.a + 1 + 1 AS c FROM x' - - Args: - expression: expression to optimize - Returns: - optimized expression - """ - for select in expression.find_all(exp.Select): - alias_to_expression: t.Dict[str, exp.Expression] = {} - for projection in select.expressions: - for column in projection.find_all(exp.Column): - if not column.table and column.name in alias_to_expression: - column.replace(alias_to_expression[column.name].copy()) - if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = projection.this - return expression diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py deleted file mode 100644 index 86f0c2d..0000000 --- a/sqlglot/optimizer/expand_multi_table_selects.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlglot import exp - - -def expand_multi_table_selects(expression): - """ - Replace multiple FROM expressions with JOINs. - - Example: - >>> from sqlglot import parse_one - >>> expand_multi_table_selects(parse_one("SELECT * FROM x, y")).sql() - 'SELECT * FROM x CROSS JOIN y' - """ - for from_ in expression.find_all(exp.From): - parent = from_.parent - - for query in from_.expressions[1:]: - parent.join( - query, - join_type="CROSS", - copy=False, - ) - from_.expressions.remove(query) - - return expression diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 5d78353..5dfa4aa 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -21,7 +21,7 @@ def isolate_table_selects(expression, schema=None): source.replace( exp.select("*") .from_( - alias(source.copy(), source.name or source.alias, table=True), + alias(source, source.name or source.alias, table=True), copy=False, ) .subquery(source.alias, copy=False) diff --git a/sqlglot/optimizer/lower_identities.py b/sqlglot/optimizer/lower_identities.py deleted file mode 100644 index fae1726..0000000 --- a/sqlglot/optimizer/lower_identities.py +++ /dev/null @@ -1,88 +0,0 @@ -from sqlglot import exp - - -def lower_identities(expression): - """ - Convert all unquoted identifiers to lower case. - - Assuming the schema is all lower case, this essentially makes identifiers case-insensitive. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') - >>> lower_identities(expression).sql() - 'SELECT bar.a AS A FROM "Foo".bar' - - Args: - expression (sqlglot.Expression): expression to quote - Returns: - sqlglot.Expression: quoted expression - """ - # We need to leave the output aliases unchanged, so the selects need special handling - _lower_selects(expression) - - # These clauses can reference output aliases and also need special handling - _lower_order(expression) - _lower_having(expression) - - # We've already handled these args, so don't traverse into them - traversed = {"expressions", "order", "having"} - - if isinstance(expression, exp.Subquery): - # Root subquery, e.g. (SELECT A AS A FROM X) LIMIT 1 - lower_identities(expression.this) - traversed |= {"this"} - - if isinstance(expression, exp.Union): - # Union, e.g. SELECT A AS A FROM X UNION SELECT A AS A FROM X - lower_identities(expression.left) - lower_identities(expression.right) - traversed |= {"this", "expression"} - - for k, v in expression.iter_expressions(): - if k in traversed: - continue - v.transform(_lower, copy=False) - - return expression - - -def _lower_selects(expression): - for e in expression.expressions: - # Leave output aliases as-is - e.unalias().transform(_lower, copy=False) - - -def _lower_order(expression): - order = expression.args.get("order") - - if not order: - return - - output_aliases = {e.alias for e in expression.expressions if isinstance(e, exp.Alias)} - - for ordered in order.expressions: - # Don't lower references to output aliases - if not ( - isinstance(ordered.this, exp.Column) - and not ordered.this.table - and ordered.this.name in output_aliases - ): - ordered.transform(_lower, copy=False) - - -def _lower_having(expression): - having = expression.args.get("having") - - if not having: - return - - # Don't lower references to output aliases - for agg in having.find_all(exp.AggFunc): - agg.transform(_lower, copy=False) - - -def _lower(node): - if isinstance(node, exp.Identifier) and not node.quoted: - node.set("this", node.this.lower()) - return node diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index c3467b2..f9c9664 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -13,15 +13,15 @@ def merge_subqueries(expression, leave_tables_isolated=False): Example: >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression).sql() - 'SELECT x.a FROM x JOIN y' + 'SELECT x.a FROM x CROSS JOIN y' If `leave_tables_isolated` is True, this will not merge inner queries into outer queries if it would result in multiple table selects in a single query: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) JOIN y") + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") >>> merge_subqueries(expression, leave_tables_isolated=True).sql() - 'SELECT a FROM (SELECT x.a FROM x) JOIN y' + 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y' Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html @@ -154,7 +154,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): inner_from = inner_scope.expression.args.get("from") if not inner_from: return False - inner_from_table = inner_from.expressions[0].alias_or_name + inner_from_table = inner_from.alias_or_name inner_projections = {s.alias_or_name: s for s in inner_scope.selects} return any( col.table != inner_from_table @@ -167,6 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") + and not outer_scope.pivots and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) and not ( @@ -210,7 +211,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): elif isinstance(source, exp.Table) and source.alias: source.set("alias", new_alias) elif isinstance(source, exp.Table): - source.replace(exp.alias_(source.copy(), new_alias)) + source.replace(exp.alias_(source, new_alias)) for column in inner_scope.source_columns(conflict): column.set("table", exp.to_identifier(new_name)) @@ -228,7 +229,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): node_to_replace (exp.Subquery|exp.Table) alias (str) """ - new_subquery = inner_scope.expression.args.get("from").expressions[0] + new_subquery = inner_scope.expression.args["from"].this node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: tables = join_hint.find_all(exp.Table) @@ -319,7 +320,7 @@ def _merge_where(outer_scope, inner_scope, from_or_join): # Merge predicates from an outer join to the ON clause # if it only has columns that are already joined from_ = expression.args.get("from") - sources = {table.alias_or_name for table in from_.expressions} if from_ else {} + sources = {from_.alias_or_name} if from_ else {} for join in expression.args["joins"]: source = join.alias_or_name diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index b013312..1db094e 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,12 +1,12 @@ from __future__ import annotations import logging -import typing as t from sqlglot import exp from sqlglot.errors import OptimizeError +from sqlglot.generator import cached_generator from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, uniq_sort +from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - cache: t.Dict[int, str] = {} + generate = cached_generator() for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): continue + root = node is expression + original = node.copy() + node.transform(rewrite_between, copy=False) distance = normalization_distance(node, dnf=dnf) if distance > max_distance: @@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = ) return expression - root = node is expression - original = node.copy() try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) ) except OptimizeError as e: logger.info(e) @@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, cache=None): +def distributive_law(expression, dnf, max_distance, generate): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, cache) - return _distribute(b, a, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func, generate) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, generate) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) return expression -def _distribute(a, b, from_func, to_func, cache): +def _distribute(a, b, from_func, to_func, generate): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), cache), - uniq_sort(flatten(from_func(c, b.right)), cache), + uniq_sort(flatten(from_func(c, b.left)), generate), + uniq_sort(flatten(from_func(c, b.right)), generate), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), cache), - uniq_sort(flatten(from_func(a, b.right)), cache), + uniq_sort(flatten(from_func(a, b.left)), generate), + uniq_sort(flatten(from_func(a, b.right)), generate), copy=False, ) diff --git a/sqlglot/optimizer/normalize_identifiers.py b/sqlglot/optimizer/normalize_identifiers.py new file mode 100644 index 0000000..1e5c104 --- /dev/null +++ b/sqlglot/optimizer/normalize_identifiers.py @@ -0,0 +1,36 @@ +from sqlglot import exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import RESOLVES_IDENTIFIERS_AS_UPPERCASE, DialectType + + +def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: + """ + Normalize all unquoted identifiers to either lower or upper case, depending on + the dialect. This essentially makes those identifiers case-insensitive. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') + >>> normalize_identifiers(expression).sql() + 'SELECT bar.a AS a FROM "Foo".bar' + + Args: + expression: The expression to transform. + dialect: The dialect to use in order to decide how to normalize identifiers. + + Returns: + The transformed expression. + """ + return expression.transform(_normalize, dialect, copy=False) + + +def _normalize(node: exp.Expression, dialect: DialectType = None) -> exp.Expression: + if isinstance(node, exp.Identifier) and not node.quoted: + node.set( + "this", + node.this.upper() + if dialect in RESOLVES_IDENTIFIERS_AS_UPPERCASE + else node.this.lower(), + ) + + return node diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index 8589657..43436cb 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -1,6 +1,8 @@ from sqlglot import exp from sqlglot.helper import tsort +JOIN_ATTRS = ("on", "side", "kind", "using", "natural") + def optimize_joins(expression): """ @@ -45,7 +47,7 @@ def reorder_joins(expression): Reorder joins by topological sort order based on predicate references. """ for from_ in expression.find_all(exp.From): - head = from_.expressions[0] + head = from_.this parent = from_.parent joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} dag = {head.alias_or_name: []} @@ -65,6 +67,9 @@ def normalize(expression): Remove INNER and OUTER from joins as they are optional. """ for join in expression.find_all(exp.Join): + if not any(join.args.get(k) for k in JOIN_ATTRS): + join.set("kind", "CROSS") + if join.kind != "CROSS": join.set("kind", None) return expression diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index c165ffe..dbe33a2 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -10,36 +10,29 @@ from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes from sqlglot.optimizer.eliminate_joins import eliminate_joins from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries -from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects -from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.lower_identities import lower_identities from sqlglot.optimizer.merge_subqueries import merge_subqueries from sqlglot.optimizer.normalize import normalize from sqlglot.optimizer.optimize_joins import optimize_joins from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns -from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.qualify import qualify +from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import simplify from sqlglot.optimizer.unnest_subqueries import unnest_subqueries from sqlglot.schema import ensure_schema RULES = ( - lower_identities, - qualify_tables, - isolate_table_selects, - qualify_columns, + qualify, pushdown_projections, - validate_qualify_columns, normalize, unnest_subqueries, - expand_multi_table_selects, pushdown_predicates, optimize_joins, eliminate_subqueries, merge_subqueries, eliminate_joins, eliminate_ctes, + quote_identifiers, annotate_types, canonicalize, simplify, @@ -54,7 +47,7 @@ def optimize( dialect: DialectType = None, rules: t.Sequence[t.Callable] = RULES, **kwargs, -): +) -> exp.Expression: """ Rewrite a sqlglot AST into an optimized form. @@ -72,14 +65,23 @@ def optimize( dialect: The dialect to parse the sql string. rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. - Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know - what you're doing! + Do not remove `qualify` from the sequence of rules unless you know what you're doing! **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. + Returns: - sqlglot.Expression: optimized expression + The optimized expression. """ schema = ensure_schema(schema or sqlglot.schema, dialect=dialect) - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} + possible_kwargs = { + "db": db, + "catalog": catalog, + "schema": schema, + "dialect": dialect, + "isolate_tables": True, # needed for other optimizations to perform well + "quote_identifiers": False, # this happens in canonicalize + **kwargs, + } + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` @@ -88,4 +90,5 @@ def optimize( param: possible_kwargs[param] for param in rule_params if param in possible_kwargs } expression = rule(expression, **rule_kwargs) - return expression + + return t.cast(exp.Expression, expression) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index ba5c8b5..96dda33 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -21,26 +21,28 @@ def pushdown_predicates(expression): sqlglot.Expression: optimized expression """ root = build_scope(expression) - scope_ref_count = root.ref_count() - - for scope in reversed(list(root.traverse())): - select = scope.expression - where = select.args.get("where") - if where: - selected_sources = scope.selected_sources - # a right join can only push down to itself and not the source FROM table - for k, (node, source) in selected_sources.items(): - parent = node.find_ancestor(exp.Join, exp.From) - if isinstance(parent, exp.Join) and parent.side == "RIGHT": - selected_sources = {k: (node, source)} - break - pushdown(where.this, selected_sources, scope_ref_count) - - # joins should only pushdown into itself, not to other joins - # so we limit the selected sources to only itself - for join in select.args.get("joins") or []: - name = join.this.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) + + if root: + scope_ref_count = root.ref_count() + + for scope in reversed(list(root.traverse())): + select = scope.expression + where = select.args.get("where") + if where: + selected_sources = scope.selected_sources + # a right join can only push down to itself and not the source FROM table + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join) and parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + pushdown(where.this, selected_sources, scope_ref_count) + + # joins should only pushdown into itself, not to other joins + # so we limit the selected sources to only itself + for join in select.args.get("joins") or []: + name = join.this.alias_or_name + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 2e51117..be3ddb2 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -39,8 +39,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) for scope in reversed(traverse_scope(expression)): parent_selections = referenced_columns.get(scope, {SELECT_ALL}) - if scope.expression.args.get("distinct"): - # We can't remove columns SELECT DISTINCT nor UNION DISTINCT + if scope.expression.args.get("distinct") or scope.parent and scope.parent.pivots: + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. The same holds if + # we select from a pivoted source in the parent scope. parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): @@ -105,7 +106,9 @@ def _remove_unused_selections(scope, parent_selections, schema): for name in sorted(parent_selections): if name not in names: - new_selections.append(alias(exp.column(name, table=resolver.get_table(name)), name)) + new_selections.append( + alias(exp.column(name, table=resolver.get_table(name)), name, copy=False) + ) # If there are no remaining selections, just select a single constant if not new_selections: diff --git a/sqlglot/optimizer/qualify.py b/sqlglot/optimizer/qualify.py new file mode 100644 index 0000000..5fdbde8 --- /dev/null +++ b/sqlglot/optimizer/qualify.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.dialects.dialect import DialectType +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import ( + qualify_columns as qualify_columns_func, + quote_identifiers as quote_identifiers_func, + validate_qualify_columns as validate_qualify_columns_func, +) +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.schema import Schema, ensure_schema + + +def qualify( + expression: exp.Expression, + dialect: DialectType = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[dict | Schema] = None, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, + isolate_tables: bool = False, + qualify_columns: bool = True, + validate_qualify_columns: bool = True, + quote_identifiers: bool = True, + identify: bool = True, +) -> exp.Expression: + """ + Rewrite sqlglot AST to have normalized and qualified tables and columns. + + This step is necessary for all further SQLGlot optimizations. + + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify(expression, schema=schema).sql() + 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"' + + Args: + expression: Expression to qualify. + db: Default database name for tables. + catalog: Default catalog name for tables. + schema: Schema to infer column names and types. + expand_alias_refs: Whether or not to expand references to aliases. + infer_schema: Whether or not to infer the schema if missing. + isolate_tables: Whether or not to isolate table selects. + qualify_columns: Whether or not to qualify columns. + validate_qualify_columns: Whether or not to validate columns. + quote_identifiers: Whether or not to run the quote_identifiers step. + This step is necessary to ensure correctness for case sensitive queries. + But this flag is provided in case this step is performed at a later time. + identify: If True, quote all identifiers, else only necessary ones. + + Returns: + The qualified expression. + """ + schema = ensure_schema(schema, dialect=dialect) + expression = normalize_identifiers(expression, dialect=dialect) + expression = qualify_tables(expression, db=db, catalog=catalog, schema=schema) + + if isolate_tables: + expression = isolate_table_selects(expression, schema=schema) + + if qualify_columns: + expression = qualify_columns_func( + expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema + ) + + if quote_identifiers: + expression = quote_identifiers_func(expression, dialect=dialect, identify=identify) + + if validate_qualify_columns: + validate_qualify_columns_func(expression) + + return expression diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 6ac39f0..4a31171 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,14 +1,23 @@ +from __future__ import annotations + import itertools import typing as t from sqlglot import alias, exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError -from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.helper import case_sensitive, seq_get +from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.schema import Schema, ensure_schema -def qualify_columns(expression, schema, expand_laterals=True): +def qualify_columns( + expression: exp.Expression, + schema: dict | Schema, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, +) -> exp.Expression: """ Rewrite sqlglot AST to have fully qualified columns. @@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True): 'SELECT tbl.col AS col FROM tbl' Args: - expression (sqlglot.Expression): expression to qualify - schema (dict|sqlglot.optimizer.Schema): Database schema + expression: expression to qualify + schema: Database schema + expand_alias_refs: whether or not to expand references to aliases + infer_schema: whether or not to infer the schema if missing Returns: sqlglot.Expression: qualified expression """ schema = ensure_schema(schema) - - if not schema.mapping and expand_laterals: - expression = _expand_laterals(expression) + infer_schema = schema.empty if infer_schema is None else infer_schema for scope in traverse_scope(expression): - resolver = Resolver(scope, schema) + resolver = Resolver(scope, schema, infer_schema=infer_schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) using_column_tables = _expand_using(scope, resolver) + + if schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + _qualify_columns(scope, resolver) + + if not schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) - _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) - if schema.mapping and expand_laterals: - expression = _expand_laterals(expression) - return expression @@ -55,9 +68,11 @@ def validate_qualify_columns(expression): for scope in traverse_scope(expression): if isinstance(scope.expression, exp.Select): unqualified_columns.extend(scope.unqualified_columns) - if scope.external_columns and not scope.is_correlated_subquery: + if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: column = scope.external_columns[0] - raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'") + raise OptimizeError( + f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" + ) if unqualified_columns: raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") @@ -142,52 +157,48 @@ def _expand_using(scope, resolver): # Ensure selects keep their output name if isinstance(column.parent, exp.Select): - replacement = exp.alias_(replacement, alias=column.name) + replacement = alias(replacement, alias=column.name, copy=False) scope.replace(column, replacement) return column_tables -def _expand_alias_refs(scope, resolver): - selects = {} - - # Replace references to select aliases - def transform(node, source_first=True): - if isinstance(node, exp.Column) and not node.table: - table = resolver.get_table(node.name) +def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: + expression = scope.expression - # Source columns get priority over select aliases - if source_first and table: - node.set("table", table) - return node + if not isinstance(expression, exp.Select): + return - if not selects: - for s in scope.selects: - selects[s.alias_or_name] = s - select = selects.get(node.name) + alias_to_expression: t.Dict[str, exp.Expression] = {} - if select: - scope.clear_cache() - if isinstance(select, exp.Alias): - select = select.this - return select.copy() + def replace_columns( + node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False + ): + if not node: + return - node.set("table", table) - elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable): - exp.replace_children(node, transform, source_first) + for column, *_ in walk_in_scope(node): + if not isinstance(column, exp.Column): + continue + table = resolver.get_table(column.name) if resolve_agg and not column.table else None + if table and column.find_ancestor(exp.AggFunc): + column.set("table", table) + elif expand and not column.table and column.name in alias_to_expression: + column.replace(alias_to_expression[column.name].copy()) - return node + for projection in scope.selects: + replace_columns(projection) - for select in scope.expression.selects: - transform(select) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this - for modifier, source_first in ( - ("where", True), - ("group", True), - ("having", False), - ): - transform(scope.expression.args.get(modifier), source_first=source_first) + replace_columns(expression.args.get("where")) + replace_columns(expression.args.get("group")) + replace_columns(expression.args.get("having"), resolve_agg=True) + replace_columns(expression.args.get("qualify"), resolve_agg=True) + replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) + scope.clear_cache() def _expand_group_by(scope, resolver): @@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver): raise OptimizeError(f"Unknown column: {column_name}") if not column_table: + if scope.pivots and not column.find_ancestor(exp.Pivot): + # If the column is under the Pivot expression, we need to qualify it + # using the name of the pivoted source instead of the pivot's alias + column.set("table", exp.to_identifier(scope.pivots[0].alias)) + continue + column_table = resolver.get_table(column_name) # column_table can be a '' because bigquery unnest has no table alias @@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver): if column_table: column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) - columns_missing_from_scope = [] - - # Determine whether each reference in the order by clause is to a column or an alias. - order = scope.expression.args.get("order") - - if order: - for ordered in order.expressions: - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - # Determine whether each reference in the having clause is to a column or an alias. - having = scope.expression.args.get("having") - - if having: - for column in having.find_all(exp.Column): - if ( - not column.table - and column.find_ancestor(exp.AggFunc) - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - for column in columns_missing_from_scope: - column_table = resolver.get_table(column.name) - - if column_table: - column.set("table", column_table) + for pivot in scope.pivots: + for column in pivot.find_all(exp.Column): + if not column.table and column.name in resolver.all_columns: + column_table = resolver.get_table(column.name) + if column_table: + column.set("table", column_table) def _expand_stars(scope, resolver, using_column_tables): @@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables): replace_columns = {} coalesced_columns = set() + # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future + pivot_columns = None + pivot_output_columns = None + pivot = seq_get(scope.pivots, 0) + + has_pivoted_source = pivot and not pivot.args.get("unpivot") + if has_pivoted_source: + pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) + + pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + for expression in scope.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) @@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table, only_visible=True) if columns and "*" not in columns: + if has_pivoted_source: + implicit_columns = [col for col in columns if col not in pivot_columns] + new_selections.extend( + exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + for name in implicit_columns + pivot_output_columns + ) + continue + table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: @@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables): coalesce = [exp.column(name, table=table) for table in tables] new_selections.append( - exp.alias_( - exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + alias( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), + alias=name, + copy=False, ) ) elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table) - new_selections.append(alias(column, alias_) if alias_ != name else column) + column = exp.column(name, table=table) + new_selections.append( + alias(column, alias_, copy=False) if alias_ != name else column + ) else: return + scope.expression.set("expressions", new_selections) @@ -388,9 +406,6 @@ def _qualify_outputs(scope): selection = alias( selection, alias=selection.output_name or f"_col_{i}", - quoted=True - if isinstance(selection, exp.Column) and selection.this.quoted - else None, ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) @@ -400,6 +415,23 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) +def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=False) + + class Resolver: """ Helper for resolving columns. @@ -407,12 +439,13 @@ class Resolver: This is a class so we can lazily load some things and easily share them across functions. """ - def __init__(self, scope, schema): + def __init__(self, scope, schema, infer_schema: bool = True): self.scope = scope self.schema = schema self._source_columns = None - self._unambiguous_columns = None + self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None self._all_columns = None + self._infer_schema = infer_schema def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: """ @@ -430,7 +463,7 @@ class Resolver: table_name = self._unambiguous_columns.get(column_name) - if not table_name: + if not table_name and self._infer_schema: sources_without_schema = tuple( source for source, columns in self._get_all_source_columns().items() @@ -450,11 +483,9 @@ class Resolver: node_alias = node.args.get("alias") if node_alias: - return node_alias.this + return exp.to_identifier(node_alias.this) - return exp.to_identifier( - table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None - ) + return exp.to_identifier(table_name) @property def all_columns(self): diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 1b451a6..fcc5f26 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -1,11 +1,19 @@ import itertools +import typing as t from sqlglot import alias, exp -from sqlglot.helper import csv_reader +from sqlglot._typing import E +from sqlglot.helper import csv_reader, name_sequence from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import Schema -def qualify_tables(expression, db=None, catalog=None, schema=None): +def qualify_tables( + expression: E, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + schema: t.Optional[Schema] = None, +) -> E: """ Rewrite sqlglot AST to have fully qualified tables. Additionally, this replaces "join constructs" (*) by equivalent SELECT * subqueries. @@ -21,19 +29,17 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): 'SELECT * FROM (SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2) AS _q_0' Args: - expression (sqlglot.Expression): expression to qualify - db (str): Database name - catalog (str): Catalog name + expression: Expression to qualify + db: Database name + catalog: Catalog name schema: A schema to populate Returns: - sqlglot.Expression: qualified expression + The qualified expression. (*) See section 7.2.1.2 in https://www.postgresql.org/docs/current/queries-table-expressions.html """ - sequence = itertools.count() - - next_name = lambda: f"_q_{next(sequence)}" + next_alias_name = name_sequence("_q_") for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): @@ -44,10 +50,14 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) if not derived_table.args.get("alias"): - alias_ = f"_q_{next(sequence)}" + alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) scope.rename_source(None, alias_) + pivots = derived_table.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) + for name, source in scope.sources.items(): if isinstance(source, exp.Table): if isinstance(source.this, exp.Identifier): @@ -59,12 +69,19 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): if not source.alias: source = source.replace( alias( - source.copy(), - name if name else next_name(), + source, + name or source.name or next_alias_name(), + copy=True, table=True, ) ) + pivots = source.args.get("pivots") + if pivots and not pivots[0].alias: + pivots[0].set( + "alias", exp.TableAlias(this=exp.to_identifier(next_alias_name())) + ) + if schema and isinstance(source.this, exp.ReadCSV): with csv_reader(source.this) as reader: header = next(reader) @@ -74,11 +91,11 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): ) elif isinstance(source, Scope) and source.is_udtf: udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_name()) + table_alias = udtf.args.get("alias") or exp.TableAlias(this=next_alias_name()) udtf.set("alias", table_alias) if not table_alias.name: - table_alias.set("this", next_name()) + table_alias.set("this", next_alias_name()) if isinstance(udtf, exp.Values) and not table_alias.columns: for i, e in enumerate(udtf.expressions[0].expressions): table_alias.append("columns", exp.to_identifier(f"_col_{i}")) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index e00b3c9..9ffb4d6 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -1,4 +1,5 @@ import itertools +import typing as t from collections import defaultdict from enum import Enum, auto @@ -83,6 +84,7 @@ class Scope: self._columns = None self._external_columns = None self._join_hints = None + self._pivots = None def branch(self, expression, scope_type, chain_sources=None, **kwargs): """Branch from the current scope to a new, inner scope""" @@ -261,12 +263,14 @@ class Scope: self._columns = [] for column in columns + external_columns: - ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) + ancestor = column.find_ancestor( + exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint + ) if ( not ancestor - # Window functions can have an ORDER BY clause - or not isinstance(ancestor.parent, exp.Select) or column.table + or isinstance(ancestor, exp.Select) + or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window)) or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) ): self._columns.append(column) @@ -370,6 +374,17 @@ class Scope: return [] return self._join_hints + @property + def pivots(self): + if not self._pivots: + self._pivots = [ + pivot + for node in self.tables + self.derived_tables + for pivot in node.args.get("pivots") or [] + ] + + return self._pivots + def source_columns(self, source_name): """ Get all columns in the current scope for a particular source. @@ -463,7 +478,7 @@ class Scope: return scope_ref_count -def traverse_scope(expression): +def traverse_scope(expression: exp.Expression) -> t.List[Scope]: """ Traverse an expression by it's "scopes". @@ -488,10 +503,12 @@ def traverse_scope(expression): Returns: list[Scope]: scope instances """ + if not isinstance(expression, exp.Unionable): + return [] return list(_traverse_scope(Scope(expression))) -def build_scope(expression): +def build_scope(expression: exp.Expression) -> t.Optional[Scope]: """ Build a scope tree. @@ -500,7 +517,10 @@ def build_scope(expression): Returns: Scope: root scope """ - return traverse_scope(expression)[-1] + scopes = traverse_scope(expression) + if scopes: + return scopes[-1] + return None def _traverse_scope(scope): @@ -585,7 +605,7 @@ def _traverse_tables(scope): expressions = [] from_ = scope.expression.args.get("from") if from_: - expressions.extend(from_.expressions) + expressions.append(from_.this) for join in scope.expression.args.get("joins") or []: expressions.append(join.this) @@ -601,8 +621,13 @@ def _traverse_tables(scope): source_name = expression.alias_or_name if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - sources[source_name] = scope.sources[table_name] + # This is a reference to a parent source (e.g. a CTE), not an actual table, unless + # it is pivoted, because then we get back a new table and hence a new source. + pivots = expression.args.get("pivots") + if pivots: + sources[pivots[0].alias] = expression + else: + sources[source_name] = scope.sources[table_name] elif source_name in sources: sources[find_new_name(sources, table_name)] = expression else: diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 0904189..e2772a0 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -5,11 +5,9 @@ from collections import deque from decimal import Decimal from sqlglot import exp -from sqlglot.generator import Generator +from sqlglot.generator import cached_generator from sqlglot.helper import first, while_changing -GENERATOR = Generator(normalize=True, identify="safe") - def simplify(expression): """ @@ -27,12 +25,12 @@ def simplify(expression): sqlglot.Expression: simplified expression """ - cache = {} + generate = cached_generator() def _simplify(expression, root=True): node = expression node = rewrite_between(node) - node = uniq_sort(node, cache, root) + node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) exp.replace_children(node, lambda e: _simplify(e, False)) node = simplify_not(node) @@ -247,7 +245,7 @@ def remove_compliments(expression, root=True): return expression -def uniq_sort(expression, cache=None, root=True): +def uniq_sort(expression, generate, root=True): """ Uniq and sort a connector. @@ -256,7 +254,7 @@ def uniq_sort(expression, cache=None, root=True): if isinstance(expression, exp.Connector) and (root or not expression.same_parent): result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ flattened = tuple(expression.flatten()) - deduped = {GENERATOR.generate(e, cache): e for e in flattened} + deduped = {generate(e): e for e in flattened} arr = tuple(deduped.items()) # check if the operands are already sorted, if not sort them @@ -388,14 +386,18 @@ def _simplify_binary(expression, a, b): def simplify_parens(expression): - if ( - isinstance(expression, exp.Paren) - and not isinstance(expression.this, exp.Select) - and ( - not isinstance(expression.parent, (exp.Condition, exp.Binary)) - or isinstance(expression.this, exp.Predicate) - or not isinstance(expression.this, exp.Binary) - ) + if not isinstance(expression, exp.Paren): + return expression + + this = expression.this + parent = expression.parent + + if not isinstance(this, exp.Select) and ( + not isinstance(parent, (exp.Condition, exp.Binary)) + or isinstance(this, exp.Predicate) + or not isinstance(this, exp.Binary) + or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) + or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) ): return expression.this return expression diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index a515489..09e3f2a 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,6 +1,5 @@ -import itertools - from sqlglot import exp +from sqlglot.helper import name_sequence from sqlglot.optimizer.scope import ScopeType, traverse_scope @@ -22,7 +21,7 @@ def unnest_subqueries(expression): Returns: sqlglot.Expression: unnested expression """ - sequence = itertools.count() + next_alias_name = name_sequence("_u_") for scope in traverse_scope(expression): select = scope.expression @@ -30,19 +29,19 @@ def unnest_subqueries(expression): if not parent: continue if scope.external_columns: - decorrelate(select, parent, scope.external_columns, sequence) + decorrelate(select, parent, scope.external_columns, next_alias_name) elif scope.scope_type == ScopeType.SUBQUERY: - unnest(select, parent, sequence) + unnest(select, parent, next_alias_name) return expression -def unnest(select, parent_select, sequence): +def unnest(select, parent_select, next_alias_name): if len(select.selects) > 1: return predicate = select.find_ancestor(exp.Condition) - alias = _alias(sequence) + alias = next_alias_name() if not predicate or parent_select is not predicate.parent_select: return @@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence): ) -def decorrelate(select, parent_select, external_columns, sequence): +def decorrelate(select, parent_select, external_columns, next_alias_name): where = select.args.get("where") if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): return - table_alias = _alias(sequence) + table_alias = next_alias_name() keys = [] # for all external columns in the where statement, find the relevant predicate @@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence): group_by.append(key) else: if key not in key_aliases: - key_aliases[key] = _alias(sequence) + key_aliases[key] = next_alias_name() # all predicates that are equalities must also be in the unique # so that we don't do a many to many join if isinstance(predicate, exp.EQ) and key not in group_by: @@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence): ) -def _alias(sequence): - return f"_u_{next(sequence)}" - - def _replace(expression, condition): return expression.replace(exp.condition(condition)) -- cgit v1.2.3