summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/qualify_columns.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r--sqlglot/optimizer/qualify_columns.py50
1 files changed, 34 insertions, 16 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ef8aeb1..8c3f599 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -170,9 +170,11 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not isinstance(expression, exp.Select):
return
- alias_to_expression: t.Dict[str, exp.Expression] = {}
+ alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}
- def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None:
+ def replace_columns(
+ node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
+ ) -> None:
if not node:
return
@@ -180,7 +182,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if not isinstance(column, exp.Column):
continue
table = resolver.get_table(column.name) if resolve_table and not column.table else None
- alias_expr = alias_to_expression.get(column.name)
+ alias_expr, i = alias_to_expression.get(column.name, (None, 1))
double_agg = (
(alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
if alias_expr
@@ -190,16 +192,20 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
if table and (not alias_expr or double_agg):
column.set("table", table)
elif not column.table and alias_expr and not double_agg:
- column.replace(alias_expr.copy())
+ if isinstance(alias_expr, exp.Literal):
+ if literal_index:
+ column.replace(exp.Literal.number(i))
+ else:
+ column.replace(alias_expr.copy())
- for projection in scope.selects:
+ for i, projection in enumerate(scope.selects):
replace_columns(projection)
if isinstance(projection, exp.Alias):
- alias_to_expression[projection.alias] = projection.this
+ alias_to_expression[projection.alias] = (projection.this, i + 1)
replace_columns(expression.args.get("where"))
- replace_columns(expression.args.get("group"))
+ replace_columns(expression.args.get("group"), literal_index=True)
replace_columns(expression.args.get("having"), resolve_table=True)
replace_columns(expression.args.get("qualify"), resolve_table=True)
scope.clear_cache()
@@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver):
selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects}
for ordered in ordereds:
- ordered.set("this", selects.get(ordered.this, ordered.this))
+ ordered = ordered.this
+
+ ordered.replace(
+ exp.to_identifier(_select_by_pos(scope, ordered).alias)
+ if ordered.is_int
+ else selects.get(ordered, ordered)
+ )
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
new_nodes = []
for node in expressions:
if node.is_int:
- try:
- select = scope.selects[int(node.name) - 1]
- except IndexError:
- raise OptimizeError(f"Unknown output column: {node.name}")
- if isinstance(select, exp.Alias):
- select = select.this
- new_nodes.append(select.copy())
- scope.clear_cache()
+ select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
+
+ if isinstance(select, exp.Literal):
+ new_nodes.append(node)
+ else:
+ new_nodes.append(select.copy())
+ scope.clear_cache()
else:
new_nodes.append(node)
return new_nodes
+def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
+ try:
+ return scope.selects[int(node.this) - 1].assert_is(exp.Alias)
+ except IndexError:
+ raise OptimizeError(f"Unknown output column: {node.name}")
+
+
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
"""Disambiguate columns, ensuring each column specifies a source"""
for column in scope.columns: