diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/eliminate_joins.py | 4 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/isolate_table_selects.py | 7 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/optimize_joins.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_predicates.py | 5 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 10 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 12 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 37 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 55 |
10 files changed, 72 insertions, 65 deletions
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index cd8ba3b..3134e65 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -85,7 +85,7 @@ def _unique_outputs(scope): grouped_outputs = set() unique_outputs = set() - for select in scope.selects: + for select in scope.expression.selects: output = select.unalias() if output in grouped_expressions: grouped_outputs.add(output) @@ -105,7 +105,7 @@ def _unique_outputs(scope): def _has_single_output_row(scope): return isinstance(scope.expression, exp.Select) and ( - all(isinstance(e.unalias(), exp.AggFunc) for e in scope.selects) + all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) or _is_limit_1(scope) or not scope.expression.args.get("from") ) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 84f50e9..5ae1fa0 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -113,7 +113,7 @@ def _eliminate_union(scope, existing_ctes, taken): taken[alias] = scope # Try to maintain the selections - expressions = scope.selects + expressions = scope.expression.selects selects = [ exp.alias_(exp.column(e.alias_or_name, table=alias), alias=e.alias_or_name, copy=False) for e in expressions diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py index 79e3ed5..a6524b8 100644 --- a/sqlglot/optimizer/isolate_table_selects.py +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -12,7 +12,12 @@ def isolate_table_selects(expression, schema=None): continue for _, source in scope.selected_sources.values(): - if not isinstance(source, exp.Table) or not schema.column_names(source): + if ( + not isinstance(source, exp.Table) + or not schema.column_names(source) + or isinstance(source.parent, exp.Subquery) + or isinstance(source.parent.parent, exp.Table) + ): continue if not source.alias: diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index e156d5e..6ee057b 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -107,6 +107,7 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) outer_scope.clear_cache() + return expression @@ -166,7 +167,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): if not inner_from: return False inner_from_table = inner_from.alias_or_name - inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects} return any( col.table != inner_from_table for selection in selections diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py index d51276f..7b3b2b1 100644 --- a/sqlglot/optimizer/optimize_joins.py +++ b/sqlglot/optimizer/optimize_joins.py @@ -59,7 +59,7 @@ def reorder_joins(expression): dag = {name: other_table_names(join) for name, join in joins.items()} parent.set( "joins", - [joins[name] for name in tsort(dag) if name != from_.alias_or_name], + [joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins], ) return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index fb1662d..58b988d 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -42,7 +42,10 @@ def pushdown_predicates(expression): # so we limit the selected sources to only itself for join in select.args.get("joins") or []: name = join.alias_or_name - pushdown(join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count) + if name in scope.selected_sources: + pushdown( + join.args.get("on"), {name: scope.selected_sources[name]}, scope_ref_count + ) return expression diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index be3ddb2..97e8ff6 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -48,12 +48,12 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) left, right = scope.union_scopes referenced_columns[left] = parent_selections - if any(select.is_star for select in right.selects): + if any(select.is_star for select in right.expression.selects): referenced_columns[right] = parent_selections - elif not any(select.is_star for select in left.selects): + elif not any(select.is_star for select in left.expression.selects): referenced_columns[right] = [ - right.selects[i].alias_or_name - for i, select in enumerate(left.selects) + right.expression.selects[i].alias_or_name + for i, select in enumerate(left.expression.selects) if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections ] @@ -90,7 +90,7 @@ def _remove_unused_selections(scope, parent_selections, schema): removed = False star = False - for selection in scope.selects: + for selection in scope.expression.selects: name = selection.alias_or_name if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 435585c..7972b2b 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -192,13 +192,13 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if table and (not alias_expr or double_agg): column.set("table", table) elif not column.table and alias_expr and not double_agg: - if isinstance(alias_expr, exp.Literal): + if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): if literal_index: column.replace(exp.Literal.number(i)) else: column.replace(alias_expr.copy()) - for i, projection in enumerate(scope.selects): + for i, projection in enumerate(scope.expression.selects): replace_columns(projection) if isinstance(projection, exp.Alias): @@ -239,7 +239,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver): ordered.set("this", new_expression) if scope.expression.args.get("group"): - selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} + selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} for ordered in ordereds: ordered = ordered.this @@ -270,7 +270,7 @@ def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: try: - return scope.selects[int(node.this) - 1].assert_is(exp.Alias) + return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) except IndexError: raise OptimizeError(f"Unknown output column: {node.name}") @@ -347,7 +347,7 @@ def _expand_stars( if not pivot_output_columns: pivot_output_columns = [col.alias_or_name for col in pivot.expressions] - for expression in scope.selects: + for expression in scope.expression.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) @@ -446,7 +446,7 @@ def _qualify_outputs(scope: Scope): new_selections = [] for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.selects, scope.outer_column_list) + itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): if isinstance(selection, exp.Subquery): if not selection.output_name: diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index af8c716..31c9cc0 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -15,7 +15,8 @@ def qualify_tables( schema: t.Optional[Schema] = None, ) -> E: """ - Rewrite sqlglot AST to have fully qualified, unnested tables. + Rewrite sqlglot AST to have fully qualified tables. Join constructs such as + (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t. Examples: >>> import sqlglot @@ -23,18 +24,9 @@ def qualify_tables( >>> qualify_tables(expression, db="db").sql() 'SELECT 1 FROM db.tbl AS tbl' >>> - >>> expression = sqlglot.parse_one("SELECT * FROM (tbl)") + >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t") >>> qualify_tables(expression).sql() - 'SELECT * FROM tbl AS tbl' - >>> - >>> expression = sqlglot.parse_one("SELECT * FROM (tbl1 JOIN tbl2 ON id1 = id2)") - >>> qualify_tables(expression).sql() - 'SELECT * FROM tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2' - - Note: - This rule effectively enforces a left-to-right join order, since all joins - are unnested. This means that the optimizer doesn't necessarily preserve the - original join order, e.g. when parentheses are used to specify it explicitly. + 'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t' Args: expression: Expression to qualify @@ -49,6 +41,13 @@ def qualify_tables( for scope in traverse_scope(expression): for derived_table in itertools.chain(scope.ctes, scope.derived_tables): + if isinstance(derived_table, exp.Subquery): + unnested = derived_table.unnest() + if isinstance(unnested, exp.Table): + joins = unnested.args.pop("joins", None) + derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) + derived_table.this.set("joins", joins) + if not derived_table.args.get("alias"): alias_ = next_alias_name() derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) @@ -66,19 +65,9 @@ def qualify_tables( if not source.args.get("catalog"): source.set("catalog", exp.to_identifier(catalog)) - # Unnest joins attached in tables by appending them to the closest query - for join in source.args.get("joins") or []: - scope.expression.append("joins", join) - - source.set("joins", None) - source.set("wrapped", None) - if not source.alias: - source = source.replace( - alias( - source, name or source.name or next_alias_name(), copy=True, table=True - ) - ) + # Mutates the source by attaching an alias to it + alias(source, name or source.name or next_alias_name(), copy=False, table=True) pivots = source.args.get("pivots") if pivots and not pivots[0].alias: diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 7dcfb37..b2b4230 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -122,7 +122,11 @@ class Scope: self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) - elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): + elif ( + isinstance(node, exp.Subquery) + and isinstance(parent, (exp.From, exp.Join)) + and _is_subquery_scope(node) + ): self._derived_tables.append(node) elif isinstance(node, exp.Subqueryable): self._subqueries.append(node) @@ -274,6 +278,7 @@ class Scope: not ancestor or column.table or isinstance(ancestor, exp.Select) + or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) or ( isinstance(ancestor, exp.Order) and ( @@ -341,23 +346,6 @@ class Scope: } @property - def selects(self): - """ - Select expressions of this scope. - - For example, for the following expression: - SELECT 1 as a, 2 as b FROM x - - The outputs are the "1 as a" and "2 as b" expressions. - - Returns: - list[exp.Expression]: expressions - """ - if isinstance(self.expression, exp.Union): - return self.expression.unnest().selects - return self.expression.selects - - @property def external_columns(self): """ Columns that appear to reference sources in outer scopes. @@ -548,6 +536,8 @@ def _traverse_scope(scope): yield from _traverse_union(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.Table): + yield from _traverse_tables(scope) elif isinstance(scope.expression, exp.UDTF): pass else: @@ -620,6 +610,15 @@ def _traverse_ctes(scope): scope.sources.update(sources) +def _is_subquery_scope(expression: exp.Subquery) -> bool: + """ + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope. + If an alias is present, it shadows all names under the Subquery, so that's an + exception to this rule. + """ + return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias) + + def _traverse_tables(scope): sources = {} @@ -629,9 +628,8 @@ def _traverse_tables(scope): if from_: expressions.append(from_.this) - for expression in (scope.expression, *scope.find_all(exp.Table)): - for join in expression.args.get("joins") or []: - expressions.append(join.this) + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) if isinstance(scope.expression, exp.Table): expressions.append(scope.expression) @@ -655,6 +653,8 @@ def _traverse_tables(scope): sources[find_new_name(sources, table_name)] = expression else: sources[source_name] = expression + + expressions.extend(join.this for join in expression.args.get("joins") or []) continue if not isinstance(expression, exp.DerivedTable): @@ -664,10 +664,15 @@ def _traverse_tables(scope): lateral_sources = sources scope_type = ScopeType.UDTF scopes = scope.udtf_scopes - else: + elif _is_subquery_scope(expression): lateral_sources = None scope_type = ScopeType.DERIVED_TABLE scopes = scope.derived_table_scopes + else: + # Makes sure we check for possible sources in nested table constructs + expressions.append(expression.this) + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue for child_scope in _traverse_scope( scope.branch( @@ -728,7 +733,11 @@ def walk_in_scope(expression, bfs=True): continue if ( isinstance(node, exp.CTE) - or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) + or ( + isinstance(node, exp.Subquery) + and isinstance(parent, (exp.From, exp.Join)) + and _is_subquery_scope(node) + ) or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): |