from __future__ import annotations

import itertools
import typing as t

from sqlglot import alias, exp
from sqlglot._typing import E
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.errors import OptimizeError
from sqlglot.helper import seq_get
from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
from sqlglot.schema import Schema, ensure_schema


def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    The given expression is rewritten scope by scope and the same object is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When `None`,
            the schema is inferred only if the provided schema is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    infer_schema = schema.empty if infer_schema is None else infer_schema
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Without a schema, alias references are expanded before columns get qualified;
        # with a schema, they are expanded afterwards.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression


def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Args:
        expression: Expression whose scopes are checked.

    Returns:
        The same expression, if every column is qualified.

    Raises:
        OptimizeError: If a SELECT scope has an external column that cannot be resolved
            (and the scope is neither a correlated subquery nor pivoted), or if any
            unqualified columns remain.
    """
    unqualified_columns = []
    for scope in traverse_scope(expression):
        if isinstance(scope.expression, exp.Select):
            unqualified_columns.extend(scope.unqualified_columns)
            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
                column = scope.external_columns[0]
                raise OptimizeError(
                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
                )

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression


def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
    """
    Remove table column aliases.

    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2))

    Args:
        derived_tables: CTEs or subqueries whose alias column lists are dropped in place.
    """
    for derived_table in derived_tables:
        table_alias = derived_table.args.get("alias")
        if table_alias:
            table_alias.args.pop("columns", None)


def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
    """
    Rewrite `JOIN ... USING (...)` clauses of the scope into explicit `ON` conditions.

    Each USING column becomes an equality between the joined table and a previously
    seen source. Unqualified references to a USING column elsewhere in the scope are
    replaced with a COALESCE over every table that participates in the join on it.

    Args:
        scope: Scope whose joins are rewritten in place.
        resolver: Used to look up the columns of each source.

    Returns:
        Mapping of USING column name to a dict whose keys are the participating
        source names, in join order (the values are always `None`).

    Raises:
        OptimizeError: If a USING column cannot be found on both sides of the join
            while the columns of both sides are known.
    """
    joins = list(scope.find_all(exp.Join))
    names = {join.alias_or_name for join in joins}
    ordered = [key for key in scope.selected_sources if key not in names]

    # Mapping of automatically joined column names to an ordered set of source names (dict).
    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}

    for join in joins:
        using = join.args.get("using")

        if not using:
            continue

        join_table = join.alias_or_name

        # First source (in selection order) that provides each column name, restricted
        # to the sources seen so far.
        columns = {}

        for k in scope.selected_sources:
            if k in ordered:
                for column in resolver.get_source_columns(k):
                    if column not in columns:
                        columns[column] = k

        # Fallback left-hand table when the column's owner is unknown.
        source_table = ordered[-1]
        ordered.append(join_table)
        join_columns = resolver.get_source_columns(join_table)
        conditions = []

        for identifier in using:
            identifier = identifier.name
            table = columns.get(identifier)

            if not table or identifier not in join_columns:
                # Only fail when column information is available for both sides.
                if columns and join_columns:
                    raise OptimizeError(f"Cannot automatically join: {identifier}")

            table = table or source_table
            conditions.append(
                exp.condition(
                    exp.EQ(
                        this=exp.column(identifier, table=table),
                        expression=exp.column(identifier, table=join_table),
                    )
                )
            )

            # Set all values in the dict to None, because we only care about the key ordering
            tables = column_tables.setdefault(identifier, {})
            if table not in tables:
                tables[table] = None
            if join_table not in tables:
                tables[join_table] = None

        join.args.pop("using")
        join.set("on", exp.and_(*conditions, copy=False))

    if column_tables:
        for column in scope.columns:
            if not column.table and column.name in column_tables:
                tables = column_tables[column.name]
                coalesce = [exp.column(column.name, table=table) for table in tables]
                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])

                # Ensure selects keep their output name
                if isinstance(column.parent, exp.Select):
                    replacement = alias(replacement, alias=column.name, copy=False)

                scope.replace(column, replacement)

    return column_tables


def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases within a SELECT scope.

    Columns in later projections, WHERE, GROUP BY, HAVING and QUALIFY whose name
    matches an alias defined in the projection list are replaced by a copy of the
    aliased expression. Non-SELECT scopes are left untouched.

    Args:
        scope: Scope to rewrite in place; its cache is cleared afterwards.
        resolver: Used to find the source table of a column in HAVING / QUALIFY.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position in the projection list)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            alias_expr, i = alias_to_expression.get(column.name, (None, 1))
            # Expanding an aggregate alias inside another aggregate would nest aggregates.
            double_agg = (
                (alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                column.set("table", table)
            elif not column.table and alias_expr and not double_agg:
                if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                    # A literal alias becomes its projection position when literal_index
                    # is set; otherwise (resolve_table only) the column is left as is.
                    if literal_index:
                        column.replace(exp.Literal.number(i))
                else:
                    column.replace(alias_expr.copy())

    for i, projection in enumerate(scope.expression.selects):
        replace_columns(projection)

        # Register the alias only after processing the projection, so a projection
        # can reference aliases defined before it but not its own.
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, i + 1)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()


def _expand_group_by(scope: Scope) -> None:
    """Replace positional references in the scope's GROUP BY, if it has one."""
    expression = scope.expression
    group = expression.args.get("group")
    if not group:
        return

    group.set("expressions", _expand_positional_references(scope, group.expressions))
    expression.set("group", group)


def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Expand positional references in the scope's ORDER BY, if it has one.

    Unqualified columns inside aggregate functions of an ordered term are qualified
    via the resolver. When the query also has a GROUP BY, ordered expressions that
    are equal to a projection are replaced by a column referencing that projection's
    output name, and remaining integer positions by the alias at that position.

    Args:
        scope: Scope whose ORDER BY is rewritten in place.
        resolver: Used to find the source table of a column.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    for ordered, new_expression in zip(
        ordereds,
        _expand_positional_references(scope, (o.this for o in ordereds)),
    ):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_expression)

    if scope.expression.args.get("group"):
        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

        for ordered in ordereds:
            ordered = ordered.this

            ordered.replace(
                exp.to_identifier(_select_by_pos(scope, ordered).alias)
                if ordered.is_int
                else selects.get(ordered, ordered)
            )


def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Resolve integer positions (e.g. `GROUP BY 1`) against the scope's projections.

    Args:
        scope: Scope providing the projection list; its cache is cleared whenever a
            position is substituted.
        expressions: Nodes to inspect.

    Returns:
        A new list in which each integer node is replaced by a copy of the aliased
        expression at that position. The integer node is kept when the aliased
        expression is itself a literal; non-integer nodes are passed through.
    """
    new_nodes = []
    for node in expressions:
        if node.is_int:
            select = _select_by_pos(scope, t.cast(exp.Literal, node)).this

            if isinstance(select, exp.Literal):
                new_nodes.append(node)
            else:
                new_nodes.append(select.copy())
                scope.clear_cache()
        else:
            new_nodes.append(node)

    return new_nodes


def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    Args:
        scope: Scope providing the projection list.
        node: Integer literal holding the 1-based position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # Positions are 1-based. Check the range explicitly: position 0 would otherwise
    # become index -1 and silently pick the last projection.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)


def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """Disambiguate columns, ensuring each column specifies a source"""
    for column in scope.columns:
        column_table = column.table
        column_name = column.name

        if column_table and column_table in scope.sources:
            source_columns = resolver.get_source_columns(column_table)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")

        if not column_table:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # If the column is under the Pivot expression, we need to qualify it
                # using the name of the pivoted source instead of the pivot's alias
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            column_table = resolver.get_table(column_name)

            # column_table can be a '' because bigquery unnest
has no table alias if column_table: column.set("table", column_table) elif column_table not in scope.sources and ( not scope.parent or column_table not in scope.parent.sources ): # structs are used like tables (e.g. "struct"."field"), so they need to be qualified # separately and represented as dot(dot(...(