diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 47 |
1 files changed, 32 insertions, 15 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index b06ea1d..742cdf5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -8,7 +8,7 @@ from sqlglot._typing import E from sqlglot.dialects.dialect import Dialect, DialectType from sqlglot.errors import OptimizeError from sqlglot.helper import seq_get -from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope from sqlglot.optimizer.simplify import simplify_parens from sqlglot.schema import Schema, ensure_schema @@ -58,7 +58,7 @@ def qualify_columns( if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables, pseudocolumns) - _qualify_outputs(scope) + qualify_outputs(scope) _expand_group_by(scope) _expand_order_by(scope, resolver) @@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ordereds = order.expressions for ordered, new_expression in zip( ordereds, - _expand_positional_references(scope, (o.this for o in ordereds)), + _expand_positional_references(scope, (o.this for o in ordereds), alias=True), ): for agg in ordered.find_all(exp.AggFunc): for col in agg.find_all(exp.Column): @@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None: ) -def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: - new_nodes = [] +def _expand_positional_references( + scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False +) -> t.List[exp.Expression]: + new_nodes: t.List[exp.Expression] = [] for node in expressions: if node.is_int: - select = _select_by_pos(scope, t.cast(exp.Literal, node)).this + select = _select_by_pos(scope, t.cast(exp.Literal, node)) - if isinstance(select, exp.Literal): - new_nodes.append(node) + if alias: + new_nodes.append(exp.column(select.args["alias"].copy())) else: - new_nodes.append(select.copy()) - scope.clear_cache() + select = select.this + + if isinstance(select, exp.Literal): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) else: new_nodes.append(node) @@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None: if column_table: column.set("table", column_table) elif column_table not in scope.sources and ( - not scope.parent or column_table not in scope.parent.sources + not scope.parent + or column_table not in scope.parent.sources + or not scope.is_correlated_subquery ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...)) @@ -381,15 +389,18 @@ def _expand_stars( columns = [name for name in columns if name.upper() not in pseudocolumns] if columns and "*" not in columns: + table_id = id(table) + columns_to_exclude = except_columns.get(table_id) or set() + if pivot and has_pivoted_source and pivot_columns and pivot_output_columns: implicit_columns = [col for col in columns if col not in pivot_columns] new_selections.extend( exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) for name in implicit_columns + pivot_output_columns + if name not in columns_to_exclude ) continue - table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: if name in coalesced_columns: @@ -406,7 +417,7 @@ def _expand_stars( copy=False, ) ) - elif name not in except_columns.get(table_id, set()): + elif name not in columns_to_exclude: alias_ = replace_columns.get(table_id, {}).get(name, name) column = exp.column(name, table=table) new_selections.append( @@ -448,10 +459,16 @@ def _add_replace_columns( replace_columns[id(table)] = columns -def _qualify_outputs(scope: Scope) -> None: +def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: """Ensure all output columns are aliased""" - new_selections = [] + if isinstance(scope_or_expression, exp.Expression): + scope = build_scope(scope_or_expression) + if not isinstance(scope, Scope): + return + else: + scope = scope_or_expression + new_selections = [] for i, (selection, aliased_column) in enumerate( itertools.zip_longest(scope.expression.selects, scope.outer_column_list) ): |