summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- sqlglot/optimizer/qualify_columns.py | 47
1 file changed, 32 insertions, 15 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index b06ea1d..742cdf5 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -8,7 +8,7 @@ from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
-from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
+from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope
from sqlglot.optimizer.simplify import simplify_parens
from sqlglot.schema import Schema, ensure_schema
@@ -58,7 +58,7 @@ def qualify_columns(
if not isinstance(scope.expression, exp.UDTF):
_expand_stars(scope, resolver, using_column_tables, pseudocolumns)
- _qualify_outputs(scope)
+ qualify_outputs(scope)
_expand_group_by(scope)
_expand_order_by(scope, resolver)
@@ -237,7 +237,7 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
ordereds = order.expressions
for ordered, new_expression in zip(
ordereds,
- _expand_positional_references(scope, (o.this for o in ordereds)),
+ _expand_positional_references(scope, (o.this for o in ordereds), alias=True),
):
for agg in ordered.find_all(exp.AggFunc):
for col in agg.find_all(exp.Column):
@@ -259,17 +259,23 @@ def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
)
-def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
- new_nodes = []
+def _expand_positional_references(
+ scope: Scope, expressions: t.Iterable[exp.Expression], alias: bool = False
+) -> t.List[exp.Expression]:
+ new_nodes: t.List[exp.Expression] = []
for node in expressions:
if node.is_int:
- select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+ select = _select_by_pos(scope, t.cast(exp.Literal, node))
- if isinstance(select, exp.Literal):
- new_nodes.append(node)
+ if alias:
+ new_nodes.append(exp.column(select.args["alias"].copy()))
else:
- new_nodes.append(select.copy())
- scope.clear_cache()
+ select = select.this
+
+ if isinstance(select, exp.Literal):
+ new_nodes.append(node)
+ else:
+ new_nodes.append(select.copy())
else:
new_nodes.append(node)
@@ -307,7 +313,9 @@ def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
if column_table:
column.set("table", column_table)
elif column_table not in scope.sources and (
- not scope.parent or column_table not in scope.parent.sources
+ not scope.parent
+ or column_table not in scope.parent.sources
+ or not scope.is_correlated_subquery
):
# structs are used like tables (e.g. "struct"."field"), so they need to be qualified
# separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
@@ -381,15 +389,18 @@ def _expand_stars(
columns = [name for name in columns if name.upper() not in pseudocolumns]
if columns and "*" not in columns:
+ table_id = id(table)
+ columns_to_exclude = except_columns.get(table_id) or set()
+
if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
implicit_columns = [col for col in columns if col not in pivot_columns]
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
for name in implicit_columns + pivot_output_columns
+ if name not in columns_to_exclude
)
continue
- table_id = id(table)
for name in columns:
if name in using_column_tables and table in using_column_tables[name]:
if name in coalesced_columns:
@@ -406,7 +417,7 @@ def _expand_stars(
copy=False,
)
)
- elif name not in except_columns.get(table_id, set()):
+ elif name not in columns_to_exclude:
alias_ = replace_columns.get(table_id, {}).get(name, name)
column = exp.column(name, table=table)
new_selections.append(
@@ -448,10 +459,16 @@ def _add_replace_columns(
replace_columns[id(table)] = columns
-def _qualify_outputs(scope: Scope) -> None:
+def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:
"""Ensure all output columns are aliased"""
- new_selections = []
+ if isinstance(scope_or_expression, exp.Expression):
+ scope = build_scope(scope_or_expression)
+ if not isinstance(scope, Scope):
+ return
+ else:
+ scope = scope_or_expression
+ new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
):