diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 50 |
1 files changed, 34 insertions, 16 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ef8aeb1..8c3f599 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -170,9 +170,11 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not isinstance(expression, exp.Select): return - alias_to_expression: t.Dict[str, exp.Expression] = {} + alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} - def replace_columns(node: t.Optional[exp.Expression], resolve_table: bool = False) -> None: + def replace_columns( + node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False + ) -> None: if not node: return @@ -180,7 +182,7 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if not isinstance(column, exp.Column): continue table = resolver.get_table(column.name) if resolve_table and not column.table else None - alias_expr = alias_to_expression.get(column.name) + alias_expr, i = alias_to_expression.get(column.name, (None, 1)) double_agg = ( (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)) if alias_expr @@ -190,16 +192,20 @@ def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: if table and (not alias_expr or double_agg): column.set("table", table) elif not column.table and alias_expr and not double_agg: - column.replace(alias_expr.copy()) + if isinstance(alias_expr, exp.Literal): + if literal_index: + column.replace(exp.Literal.number(i)) + else: + column.replace(alias_expr.copy()) - for projection in scope.selects: + for i, projection in enumerate(scope.selects): replace_columns(projection) if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = projection.this + alias_to_expression[projection.alias] = (projection.this, i + 1) replace_columns(expression.args.get("where")) - replace_columns(expression.args.get("group")) + replace_columns(expression.args.get("group"), literal_index=True) replace_columns(expression.args.get("having"), resolve_table=True) replace_columns(expression.args.get("qualify"), resolve_table=True) scope.clear_cache() @@ -255,27 +261,39 @@ def _expand_order_by(scope: Scope, resolver: Resolver): selects = {s.this: exp.column(s.alias_or_name) for s in scope.selects} for ordered in ordereds: - ordered.set("this", selects.get(ordered.this, ordered.this)) + ordered = ordered.this + + ordered.replace( + exp.to_identifier(_select_by_pos(scope, ordered).alias) + if ordered.is_int + else selects.get(ordered, ordered) + ) def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]: new_nodes = [] for node in expressions: if node.is_int: - try: - select = scope.selects[int(node.name) - 1] - except IndexError: - raise OptimizeError(f"Unknown output column: {node.name}") - if isinstance(select, exp.Alias): - select = select.this - new_nodes.append(select.copy()) - scope.clear_cache() + select = _select_by_pos(scope, t.cast(exp.Literal, node)).this + + if isinstance(select, exp.Literal): + new_nodes.append(node) + else: + new_nodes.append(select.copy()) + scope.clear_cache() else: new_nodes.append(node) return new_nodes +def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: + try: + return scope.selects[int(node.this) - 1].assert_is(exp.Alias) + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + + def _qualify_columns(scope: Scope, resolver: Resolver) -> None: """Disambiguate columns, ensuring each column specifies a source""" for column in scope.columns: |