diff options
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/merge_subqueries.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 15 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 32 |
5 files changed, 37 insertions, 18 deletions
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 5ae1fa0..728493d 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -144,8 +144,9 @@ def _eliminate_derived_table(scope, existing_ctes, taken): name, cte = _new_cte(scope, existing_ctes, taken) table = exp.alias_(exp.table_(name), alias=parent.alias or name) - parent.replace(table) + table.set("joins", parent.args.get("joins")) + parent.replace(table) return cte diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 6ee057b..7322424 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -176,6 +176,7 @@ def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): return ( isinstance(outer_scope.expression, exp.Select) + and not outer_scope.expression.is_star and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") @@ -242,6 +243,7 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias): alias (str) """ new_subquery = inner_scope.expression.args["from"].this + new_subquery.set("joins", node_to_replace.args.get("joins")) node_to_replace.replace(new_subquery) for join_hint in outer_scope.join_hints: tables = join_hint.find_all(exp.Table) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 97e8ff6..c81fd00 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -61,6 +61,9 @@ def pushdown_projections(expression, schema=None, remove_unused_selections=True) if remove_unused_selections: _remove_unused_selections(scope, parent_selections, schema) + if scope.expression.is_star: + continue + # Group columns by source name selects = defaultdict(set) for col in scope.columns: diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 7972b2b..2657188 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -29,12 +29,13 @@ def qualify_columns( 'SELECT tbl.col AS col FROM tbl' Args: - expression: expression to qualify - schema: Database schema - expand_alias_refs: whether or not to expand references to aliases - infer_schema: whether or not to infer the schema if missing + expression: Expression to qualify. + schema: Database schema. + expand_alias_refs: Whether or not to expand references to aliases. + infer_schema: Whether or not to infer the schema if missing. + Returns: - sqlglot.Expression: qualified expression + The qualified expression. """ schema = ensure_schema(schema) infer_schema = schema.empty if infer_schema is None else infer_schema @@ -410,7 +411,9 @@ def _expand_stars( else: return - scope.expression.set("expressions", new_selections) + # Ensures we don't overwrite the initial selections with an empty list + if new_selections: + scope.expression.set("expressions", new_selections) def _add_except_columns( diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b2b4230..a7dab35 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -124,8 +124,8 @@ class Scope: self._ctes.append(node) elif ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ): self._derived_tables.append(node) elif isinstance(node, exp.Subqueryable): @@ -610,13 +610,13 @@ def _traverse_ctes(scope): scope.sources.update(sources) -def _is_subquery_scope(expression: exp.Subquery) -> bool: +def _is_derived_table(expression: exp.Subquery) -> bool: """ - We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope. - If an alias is present, it shadows all names under the Subquery, so that's an - exception to this rule. + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", + as it doesn't introduce a new scope. If an alias is present, it shadows all names + under the Subquery, so that's one exception to this rule. """ - return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias) + return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) def _traverse_tables(scope): @@ -654,7 +654,10 @@ def _traverse_tables(scope): else: sources[source_name] = expression - expressions.extend(join.this for join in expression.args.get("joins") or []) + # Make sure to not include the joins twice + if expression is not scope.expression: + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue if not isinstance(expression, exp.DerivedTable): @@ -664,10 +667,11 @@ def _traverse_tables(scope): lateral_sources = sources scope_type = ScopeType.UDTF scopes = scope.udtf_scopes - elif _is_subquery_scope(expression): + elif _is_derived_table(expression): lateral_sources = None scope_type = ScopeType.DERIVED_TABLE scopes = scope.derived_table_scopes + expressions.extend(join.this for join in expression.args.get("joins") or []) else: # Makes sure we check for possible sources in nested table constructs expressions.append(expression.this) @@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True): isinstance(node, exp.CTE) or ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ) or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): prune = True + + if isinstance(node, (exp.Subquery, exp.UDTF)): + # The following args are not actually in the inner scope, so we should visit them + for key in ("joins", "laterals", "pivots"): + for arg in node.args.get(key) or []: + yield from walk_in_scope(arg, bfs=bfs) |