diff options
Diffstat (limited to 'sqlglot/optimizer/qualify_columns.py')
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 221 |
1 files changed, 126 insertions, 95 deletions
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 6ac39f0..4a31171 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -1,14 +1,23 @@ +from __future__ import annotations + import itertools import typing as t from sqlglot import alias, exp +from sqlglot._typing import E +from sqlglot.dialects.dialect import DialectType from sqlglot.errors import OptimizeError -from sqlglot.optimizer.expand_laterals import expand_laterals as _expand_laterals -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.helper import case_sensitive, seq_get +from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope +from sqlglot.schema import Schema, ensure_schema -def qualify_columns(expression, schema, expand_laterals=True): +def qualify_columns( + expression: exp.Expression, + schema: dict | Schema, + expand_alias_refs: bool = True, + infer_schema: t.Optional[bool] = None, +) -> exp.Expression: """ Rewrite sqlglot AST to have fully qualified columns. @@ -20,32 +29,36 @@ def qualify_columns(expression, schema, expand_laterals=True): 'SELECT tbl.col AS col FROM tbl' Args: - expression (sqlglot.Expression): expression to qualify - schema (dict|sqlglot.optimizer.Schema): Database schema + expression: expression to qualify + schema: Database schema + expand_alias_refs: whether or not to expand references to aliases + infer_schema: whether or not to infer the schema if missing Returns: sqlglot.Expression: qualified expression """ schema = ensure_schema(schema) - - if not schema.mapping and expand_laterals: - expression = _expand_laterals(expression) + infer_schema = schema.empty if infer_schema is None else infer_schema for scope in traverse_scope(expression): - resolver = Resolver(scope, schema) + resolver = Resolver(scope, schema, infer_schema=infer_schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) using_column_tables = _expand_using(scope, resolver) + + if schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + _qualify_columns(scope, resolver) + + if not schema.empty and expand_alias_refs: + _expand_alias_refs(scope, resolver) + if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver, using_column_tables) _qualify_outputs(scope) - _expand_alias_refs(scope, resolver) _expand_group_by(scope, resolver) _expand_order_by(scope) - if schema.mapping and expand_laterals: - expression = _expand_laterals(expression) - return expression @@ -55,9 +68,11 @@ def validate_qualify_columns(expression): for scope in traverse_scope(expression): if isinstance(scope.expression, exp.Select): unqualified_columns.extend(scope.unqualified_columns) - if scope.external_columns and not scope.is_correlated_subquery: + if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: column = scope.external_columns[0] - raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'") + raise OptimizeError( + f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}""" + ) if unqualified_columns: raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") @@ -142,52 +157,48 @@ def _expand_using(scope, resolver): # Ensure selects keep their output name if isinstance(column.parent, exp.Select): - replacement = exp.alias_(replacement, alias=column.name) + replacement = alias(replacement, alias=column.name, copy=False) scope.replace(column, replacement) return column_tables -def _expand_alias_refs(scope, resolver): - selects = {} - - # Replace references to select aliases - def transform(node, source_first=True): - if isinstance(node, exp.Column) and not node.table: - table = resolver.get_table(node.name) +def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None: + expression = scope.expression - # Source columns get priority over select aliases - if source_first and table: - node.set("table", table) - return node + if not isinstance(expression, exp.Select): + return - if not selects: - for s in scope.selects: - selects[s.alias_or_name] = s - select = selects.get(node.name) + alias_to_expression: t.Dict[str, exp.Expression] = {} - if select: - scope.clear_cache() - if isinstance(select, exp.Alias): - select = select.this - return select.copy() + def replace_columns( + node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False + ): + if not node: + return - node.set("table", table) - elif isinstance(node, exp.Expression) and not isinstance(node, exp.Subqueryable): - exp.replace_children(node, transform, source_first) + for column, *_ in walk_in_scope(node): + if not isinstance(column, exp.Column): + continue + table = resolver.get_table(column.name) if resolve_agg and not column.table else None + if table and column.find_ancestor(exp.AggFunc): + column.set("table", table) + elif expand and not column.table and column.name in alias_to_expression: + column.replace(alias_to_expression[column.name].copy()) - return node + for projection in scope.selects: + replace_columns(projection) - for select in scope.expression.selects: - transform(select) + if isinstance(projection, exp.Alias): + alias_to_expression[projection.alias] = projection.this - for modifier, source_first in ( - ("where", True), - ("group", True), - ("having", False), - ): - transform(scope.expression.args.get(modifier), source_first=source_first) + replace_columns(expression.args.get("where")) + replace_columns(expression.args.get("group")) + replace_columns(expression.args.get("having"), resolve_agg=True) + replace_columns(expression.args.get("qualify"), resolve_agg=True) + replace_columns(expression.args.get("order"), expand=False, resolve_agg=True) + scope.clear_cache() def _expand_group_by(scope, resolver): @@ -242,6 +253,12 @@ def _qualify_columns(scope, resolver): raise OptimizeError(f"Unknown column: {column_name}") if not column_table: + if scope.pivots and not column.find_ancestor(exp.Pivot): + # If the column is under the Pivot expression, we need to qualify it + # using the name of the pivoted source instead of the pivot's alias + column.set("table", exp.to_identifier(scope.pivots[0].alias)) + continue + column_table = resolver.get_table(column_name) # column_table can be a '' because bigquery unnest has no table alias @@ -265,38 +282,12 @@ def _qualify_columns(scope, resolver): if column_table: column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) - columns_missing_from_scope = [] - - # Determine whether each reference in the order by clause is to a column or an alias. - order = scope.expression.args.get("order") - - if order: - for ordered in order.expressions: - for column in ordered.find_all(exp.Column): - if ( - not column.table - and column.parent is not ordered - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - # Determine whether each reference in the having clause is to a column or an alias. - having = scope.expression.args.get("having") - - if having: - for column in having.find_all(exp.Column): - if ( - not column.table - and column.find_ancestor(exp.AggFunc) - and column.name in resolver.all_columns - ): - columns_missing_from_scope.append(column) - - for column in columns_missing_from_scope: - column_table = resolver.get_table(column.name) - - if column_table: - column.set("table", column_table) + for pivot in scope.pivots: + for column in pivot.find_all(exp.Column): + if not column.table and column.name in resolver.all_columns: + column_table = resolver.get_table(column.name) + if column_table: + column.set("table", column_table) def _expand_stars(scope, resolver, using_column_tables): @@ -307,6 +298,19 @@ def _expand_stars(scope, resolver, using_column_tables): replace_columns = {} coalesced_columns = set() + # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future + pivot_columns = None + pivot_output_columns = None + pivot = seq_get(scope.pivots, 0) + + has_pivoted_source = pivot and not pivot.args.get("unpivot") + if has_pivoted_source: + pivot_columns = set(col.output_name for col in pivot.find_all(exp.Column)) + + pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] + if not pivot_output_columns: + pivot_output_columns = [col.alias_or_name for col in pivot.expressions] + for expression in scope.selects: if isinstance(expression, exp.Star): tables = list(scope.selected_sources) @@ -323,9 +327,18 @@ def _expand_stars(scope, resolver, using_column_tables): for table in tables: if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table, only_visible=True) if columns and "*" not in columns: + if has_pivoted_source: + implicit_columns = [col for col in columns if col not in pivot_columns] + new_selections.extend( + exp.alias_(exp.column(name, table=pivot.alias), name, copy=False) + for name in implicit_columns + pivot_output_columns + ) + continue + table_id = id(table) for name in columns: if name in using_column_tables and table in using_column_tables[name]: @@ -337,16 +350,21 @@ def _expand_stars(scope, resolver, using_column_tables): coalesce = [exp.column(name, table=table) for table in tables] new_selections.append( - exp.alias_( - exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), alias=name + alias( + exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]), + alias=name, + copy=False, ) ) elif name not in except_columns.get(table_id, set()): alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table) - new_selections.append(alias(column, alias_) if alias_ != name else column) + column = exp.column(name, table=table) + new_selections.append( + alias(column, alias_, copy=False) if alias_ != name else column + ) else: return + scope.expression.set("expressions", new_selections) @@ -388,9 +406,6 @@ def _qualify_outputs(scope): selection = alias( selection, alias=selection.output_name or f"_col_{i}", - quoted=True - if isinstance(selection, exp.Column) and selection.this.quoted - else None, ) if aliased_column: selection.set("alias", exp.to_identifier(aliased_column)) @@ -400,6 +415,23 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) +def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: + """Makes sure all identifiers that need to be quoted are quoted.""" + + def _quote(expression: E) -> E: + if isinstance(expression, exp.Identifier): + name = expression.this + expression.set( + "quoted", + identify + or case_sensitive(name, dialect=dialect) + or not exp.SAFE_IDENTIFIER_RE.match(name), + ) + return expression + + return expression.transform(_quote, copy=False) + + class Resolver: """ Helper for resolving columns. @@ -407,12 +439,13 @@ class Resolver: This is a class so we can lazily load some things and easily share them across functions. """ - def __init__(self, scope, schema): + def __init__(self, scope, schema, infer_schema: bool = True): self.scope = scope self.schema = schema self._source_columns = None - self._unambiguous_columns = None + self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None self._all_columns = None + self._infer_schema = infer_schema def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: """ @@ -430,7 +463,7 @@ class Resolver: table_name = self._unambiguous_columns.get(column_name) - if not table_name: + if not table_name and self._infer_schema: sources_without_schema = tuple( source for source, columns in self._get_all_source_columns().items() @@ -450,11 +483,9 @@ class Resolver: node_alias = node.args.get("alias") if node_alias: - return node_alias.this + return exp.to_identifier(node_alias.this) - return exp.to_identifier( - table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None - ) + return exp.to_identifier(table_name) @property def all_columns(self): |