Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 12from sqlglot.schema import Schema, ensure_schema
 13
 14
 15def qualify_columns(
 16    expression: exp.Expression,
 17    schema: t.Dict | Schema,
 18    expand_alias_refs: bool = True,
 19    infer_schema: t.Optional[bool] = None,
 20) -> exp.Expression:
 21    """
 22    Rewrite sqlglot AST to have fully qualified columns.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"tbl": {"col": "INT"}}
 27        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 28        >>> qualify_columns(expression, schema).sql()
 29        'SELECT tbl.col AS col FROM tbl'
 30
 31    Args:
 32        expression: Expression to qualify.
 33        schema: Database schema.
 34        expand_alias_refs: Whether or not to expand references to aliases.
 35        infer_schema: Whether or not to infer the schema if missing.
 36
 37    Returns:
 38        The qualified expression.
 39    """
 40    schema = ensure_schema(schema)
 41    infer_schema = schema.empty if infer_schema is None else infer_schema
 42
 43    for scope in traverse_scope(expression):
 44        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 45        _pop_table_column_aliases(scope.ctes)
 46        _pop_table_column_aliases(scope.derived_tables)
 47        using_column_tables = _expand_using(scope, resolver)
 48
 49        if schema.empty and expand_alias_refs:
 50            _expand_alias_refs(scope, resolver)
 51
 52        _qualify_columns(scope, resolver)
 53
 54        if not schema.empty and expand_alias_refs:
 55            _expand_alias_refs(scope, resolver)
 56
 57        if not isinstance(scope.expression, exp.UDTF):
 58            _expand_stars(scope, resolver, using_column_tables)
 59            _qualify_outputs(scope)
 60        _expand_group_by(scope)
 61        _expand_order_by(scope, resolver)
 62
 63    return expression
 64
 65
 66def validate_qualify_columns(expression: E) -> E:
 67    """Raise an `OptimizeError` if any columns aren't qualified"""
 68    unqualified_columns = []
 69    for scope in traverse_scope(expression):
 70        if isinstance(scope.expression, exp.Select):
 71            unqualified_columns.extend(scope.unqualified_columns)
 72            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 73                column = scope.external_columns[0]
 74                raise OptimizeError(
 75                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 76                )
 77
 78    if unqualified_columns:
 79        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 80    return expression
 81
 82
 83def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 84    """
 85    Remove table column aliases.
 86
 87    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 88    """
 89    for derived_table in derived_tables:
 90        table_alias = derived_table.args.get("alias")
 91        if table_alias:
 92            table_alias.args.pop("columns", None)
 93
 94
 95def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 96    joins = list(scope.find_all(exp.Join))
 97    names = {join.alias_or_name for join in joins}
 98    ordered = [key for key in scope.selected_sources if key not in names]
 99
100    # Mapping of automatically joined column names to an ordered set of source names (dict).
101    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
102
103    for join in joins:
104        using = join.args.get("using")
105
106        if not using:
107            continue
108
109        join_table = join.alias_or_name
110
111        columns = {}
112
113        for k in scope.selected_sources:
114            if k in ordered:
115                for column in resolver.get_source_columns(k):
116                    if column not in columns:
117                        columns[column] = k
118
119        source_table = ordered[-1]
120        ordered.append(join_table)
121        join_columns = resolver.get_source_columns(join_table)
122        conditions = []
123
124        for identifier in using:
125            identifier = identifier.name
126            table = columns.get(identifier)
127
128            if not table or identifier not in join_columns:
129                if columns and join_columns:
130                    raise OptimizeError(f"Cannot automatically join: {identifier}")
131
132            table = table or source_table
133            conditions.append(
134                exp.condition(
135                    exp.EQ(
136                        this=exp.column(identifier, table=table),
137                        expression=exp.column(identifier, table=join_table),
138                    )
139                )
140            )
141
142            # Set all values in the dict to None, because we only care about the key ordering
143            tables = column_tables.setdefault(identifier, {})
144            if table not in tables:
145                tables[table] = None
146            if join_table not in tables:
147                tables[join_table] = None
148
149        join.args.pop("using")
150        join.set("on", exp.and_(*conditions, copy=False))
151
152    if column_tables:
153        for column in scope.columns:
154            if not column.table and column.name in column_tables:
155                tables = column_tables[column.name]
156                coalesce = [exp.column(column.name, table=table) for table in tables]
157                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
158
159                # Ensure selects keep their output name
160                if isinstance(column.parent, exp.Select):
161                    replacement = alias(replacement, alias=column.name, copy=False)
162
163                scope.replace(column, replacement)
164
165    return column_tables
166
167
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to projection aliases within a SELECT scope.

    An unqualified column whose name matches a projection alias is replaced by a copy of the
    aliased expression. Inside the projection list, only aliases defined earlier are visible;
    WHERE, GROUP BY, HAVING and QUALIFY see them all.

    Special cases:
      - GROUP BY: an alias of a literal becomes its 1-based position.
      - HAVING / QUALIFY: aliases of literals are left alone, and columns that aren't inlined
        are qualified with their table.
      - A reference inside an aggregate to an alias whose expression contains an aggregate is
        never inlined.

    Non-SELECT scopes are ignored.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position in the projection list)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def inline_aliases(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = resolver.get_table(column.name) if resolve_table and not column.table else None
            aliased, position = alias_to_expression.get(column.name, (None, 1))
            # Inlining an aggregate alias inside another aggregate would nest aggregates.
            double_agg = (
                (aliased.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc))
                if aliased
                else False
            )

            if table and (not aliased or double_agg):
                column.set("table", table)
                continue

            if column.table or not aliased or double_agg:
                continue

            if isinstance(aliased, exp.Literal) and (literal_index or resolve_table):
                if literal_index:
                    column.replace(exp.Literal.number(position))
                continue

            column.replace(aliased.copy())

    for position, projection in enumerate(select.selects, start=1):
        inline_aliases(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    inline_aliases(select.args.get("where"))
    inline_aliases(select.args.get("group"), literal_index=True)
    inline_aliases(select.args.get("having"), resolve_table=True)
    inline_aliases(select.args.get("qualify"), resolve_table=True)
    scope.clear_cache()
213
214
def _expand_group_by(scope: Scope):
    """Replace positional references in GROUP BY with the projections they point to."""
    select = scope.expression
    group = select.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        select.set("group", group)
223
224
def _expand_order_by(scope: Scope, resolver: Resolver):
    """
    Normalize ORDER BY keys.

    Positional keys are replaced with copies of the projections they reference, unless the
    projection is a literal. Unqualified columns inside aggregates of the original keys are
    qualified.

    If the query also has a GROUP BY, keys are then rewritten to refer to projections by name:
      - positional keys that survived become the referenced projection's alias;
      - keys equal to a projection's `this` become a column named after that projection.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    # Fully computed before any ORDER BY key is modified below.
    expanded_keys = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, new_key in zip(ordereds, expanded_keys):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", new_key)

    if not scope.expression.args.get("group"):
        return

    output_by_expression = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

    for ordered in ordereds:
        key = ordered.this
        if key.is_int:
            replacement = exp.to_identifier(_select_by_pos(scope, key).alias)
        else:
            replacement = output_by_expression.get(key, key)
        key.replace(replacement)
253
254
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Return a new list where integer positional references are resolved to projections.

    An integer node is replaced by a copy of the projection it points to, and the scope cache
    is cleared. If that projection is itself a literal, the positional node is kept as-is.
    Non-integer nodes are passed through unchanged.
    """
    resolved: t.List[E] = []

    for expression in expressions:
        if not expression.is_int:
            resolved.append(expression)
            continue

        target = _select_by_pos(scope, t.cast(exp.Literal, expression)).this
        if isinstance(target, exp.Literal):
            resolved.append(expression)
            continue

        resolved.append(target.copy())
        scope.clear_cache()

    return resolved
270
271
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal (e.g. `ORDER BY 2`).

    Args:
        scope: Scope whose SELECT list is indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The referenced projection, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: if the position is outside `1..len(selects)`.
    """
    selects = scope.expression.selects
    position = int(node.this)

    # SQL positions are 1-based. Relying on IndexError alone lets 0 or negative positions
    # wrap around through Python's negative indexing and silently pick a projection counted
    # from the end, so check the range explicitly.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
277
278
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    - Unqualified columns get the table `resolver` attributes them to. In a pivoted scope,
      columns outside the PIVOT clause get the pivot's alias instead.
    - Columns qualified with a source of this scope are checked against that source's columns.
    - Columns qualified with something that is a source of neither this scope nor its parent
      are treated as struct field accesses and rebuilt as `Dot(<table>.<column>, field, ...)`.
    - Finally, unqualified columns inside PIVOT clauses are qualified.

    Raises:
        OptimizeError: if a column is qualified with a known source that lacks it.
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT clause are qualified with the pivot's alias.
                # Columns under the Pivot expression need the pivoted source's name instead;
                # they are handled by the loop at the end of this function.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(name)
            # The resolved table can be '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            known_columns = resolver.get_source_columns(table_name)
            if known_columns and name not in known_columns and "*" not in known_columns:
                raise OptimizeError(f"Unknown column: {name}")
            continue

        if scope.parent and table_name in scope.parent.sources:
            # Qualified with a source of the enclosing scope: leave it as is.
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            struct_table = root
            root, *parts = parts
        else:
            struct_table = resolver.get_table(root.name)

        if struct_table:
            column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue
            resolved = resolver.get_table(column.name)
            if resolved:
                column.set("table", resolved)
326
327
def _expand_stars(
    scope: Scope, resolver: Resolver, using_column_tables: t.Dict[str, t.Any]
) -> None:
    """
    Expand stars to lists of column selections.

    Each `*` / `tbl.*` projection is replaced by the visible columns of the source(s) it covers:
      - EXCEPT / REPLACE modifiers are honored;
      - columns joined via USING become a single COALESCE (emitted once);
      - when the scope has a non-UNPIVOT pivot, the pivot's output columns are projected.

    If any covered source has unknown columns (empty, or containing "*"), the function returns
    without modifying the projection list.

    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.
    """
    expanded: t.List[exp.Expression] = []
    # Per-source EXCEPT / REPLACE modifiers, keyed by id() of the source-name objects that are
    # passed to the helpers below and then iterated.
    excluded: t.Dict[int, t.Set[str]] = {}
    renamed: t.Dict[int, t.Dict[str, str]] = {}
    coalesced = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    pivot_columns = None
    pivot_output_columns = None

    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}
        pivot_output_columns = [
            col.output_name for col in pivot.args.get("columns", [])
        ] or [col.alias_or_name for col in pivot.expressions]

    for projection in scope.expression.selects:
        if isinstance(projection, exp.Star):
            source_names = list(scope.selected_sources)
            star = projection
        elif projection.is_star:
            source_names = [projection.table]
            star = projection.this
        else:
            expanded.append(projection)
            continue

        _add_except_columns(star, source_names, excluded)
        _add_replace_columns(star, source_names, renamed)

        for table in source_names:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement
            # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
            if resolver.schema.dialect == "bigquery":
                columns = [
                    col for col in columns if col.upper() not in ("_PARTITIONTIME", "_PARTITIONDATE")
                ]

            if not columns or "*" in columns:
                # The source's columns are unknown, so the star can't be expanded.
                return

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                expanded.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                )
                continue

            source_id = id(table)
            excluded_here = excluded.get(source_id, set())
            renamed_here = renamed.get(source_id, {})

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced:
                        continue
                    coalesced.add(name)
                    operands = [
                        exp.column(name, table=joined) for joined in using_column_tables[name]
                    ]
                    expanded.append(
                        alias(
                            exp.Coalesce(this=operands[0], expressions=operands[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded_here:
                    output_name = renamed_here.get(name, name)
                    column = exp.column(name, table=table)
                    expanded.append(
                        column if output_name == name else alias(column, output_name, copy=False)
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if expanded:
        scope.expression.set("expressions", expanded)
417
418
def _add_except_columns(
    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
) -> None:
    """
    Record a star's EXCEPT column names for every source in `tables`.

    Entries are keyed by `id()` of each source-name object, and all of them share one set.
    Nothing is recorded when the star has no EXCEPT clause.
    """
    except_ = expression.args.get("except")
    if except_:
        excluded_names = {e.name for e in except_}
        except_columns.update(dict.fromkeys(map(id, tables), excluded_names))
431
432
def _add_replace_columns(
    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
) -> None:
    """
    Record a star's REPLACE entries for every source in `tables`.

    Each entry maps the replaced column's name to its alias. Entries are keyed by `id()` of
    each source-name object, and all of them share one dict. Nothing is recorded when the star
    has no REPLACE clause.
    """
    replace = expression.args.get("replace")
    if replace:
        renames = {e.this.name: e.alias for e in replace}
        replace_columns.update(dict.fromkeys(map(id, tables), renames))
445
446
def _qualify_outputs(scope: Scope):
    """
    Ensure all output columns are aliased.

    Rules applied to each projection, by position `i`:
      - Non-aliased projections (other than stars) are wrapped in an alias named after their
        output name, falling back to `_col_<i>`.
      - Subquery projections without an output name get a `_col_<i>` table alias.
      - If `scope.outer_column_list` supplies a name for position `i`, that name overrides
        the projection's alias.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads the shorter input with None. When the outer column list names more
        # columns than are projected, there is nothing left to alias. Stop here rather than
        # crash with AttributeError on `None.is_star`.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
468
469
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Applies the dialect's `quote_identifier` to every node via `transform(..., copy=False)`,
    so the tree is modified in place. `identify` is forwarded to `quote_identifier`.
    Presumably True quotes every identifier rather than only those that require it;
    confirm against `Dialect.quote_identifier`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
475
476
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches filled on first use by _get_all_source_columns, get_table and all_columns.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first attributed to the single source that exposes it unambiguously.
        Failing that, and only when schema inference is enabled, it goes to the one source whose
        columns are unknown (empty or containing "*"), provided exactly one such source exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree without finding a matching alias.
        # In that case fall back to the plain source name instead of failing on `None.args`.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it, and provides it exactly
        once. Columns repeated within one source, or shared by several sources, are omitted.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated inside a single source.
        # Otherwise a later source reusing such a name would wrongly be treated as its sole owner.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Any overlap with earlier names is ambiguous, whether or not the name is unique
            # within the current source.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. Its scopes are rewritten in place and the same
            object is returned.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When left as None it
            defaults to whether `schema` is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column-alias lists (e.g. `AS foo(col1, col2)`) from CTEs and derived tables.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before column qualification when there is no
        # schema, and after it when there is one.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. A scope with external columns that is neither a
    correlated subquery nor pivoted fails immediately, naming the first such column. Otherwise,
    all unqualified columns are collected across scopes and reported together as ambiguous.

    Returns:
        The same expression, unchanged, when validation passes.
    """
    unqualified_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        unresolved = scope.external_columns
        if unresolved and not scope.is_correlated_subquery and not scope.pivots:
            column = unresolved[0]
            table_hint = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
    return expression

Raise an OptimizeError if any columns aren't qualified or can't be resolved.

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Applies the dialect's `quote_identifier` to every node via `transform(..., copy=False)`,
    so the tree is modified in place. `identify` is forwarded to `quote_identifier`.
    Presumably True quotes every identifier rather than only those that require it;
    confirm against `Dialect.quote_identifier`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches filled on first use by _get_all_source_columns, get_table and all_columns.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first attributed to the single source that exposes it unambiguously.
        Failing that, and only when schema inference is enabled, it goes to the one source whose
        columns are unknown (empty or containing "*"), provided exactly one such source exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the top of the tree without finding a matching alias.
        # In that case fall back to the plain source name instead of failing on `None.args`.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """Resolve the source columns for a given source `name`"""
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it, and provides it exactly
        once. Columns repeated within one source, or shared by several sources, are omitted.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
        # Track every name seen so far, including names duplicated inside a single source.
        # Otherwise a later source reusing such a name would wrongly be treated as its sole owner.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            # Any overlap with earlier names is ambiguous, whether or not the name is unique
            # within the current source.
            ambiguous = all_columns.intersection(columns)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
    # Inputs: the scope being resolved, the schema, and whether to infer missing schemas.
    self.scope = scope
    self.schema = schema
    self._infer_schema = infer_schema
    # Caches filled on first use by the resolver's lookup methods.
    self._source_columns = None
    self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
    self._all_columns = None
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
    """
    Get the table for a column name.

    The column is first looked up among the columns that unambiguously belong
    to a single source of this scope. If that fails and schema inference is
    enabled, the column is attributed to the one source whose columns are
    unknown (no columns, or containing "*"), provided exactly one such source
    exists.

    Args:
        column_name: The column name to find the table for.
    Returns:
        The table identifier if it can be found/inferred. When the resolved
        source node carries an alias, the alias is returned instead of the
        raw source name.
    """
    if self._unambiguous_columns is None:
        self._unambiguous_columns = self._get_unambiguous_columns(
            self._get_all_source_columns()
        )

    table_name = self._unambiguous_columns.get(column_name)

    if not table_name and self._infer_schema:
        # Only guess when there is exactly one source with unknown columns;
        # with several candidates the attribution would be ambiguous.
        sources_without_schema = tuple(
            source
            for source, columns in self._get_all_source_columns().items()
            if not columns or "*" in columns
        )
        if len(sources_without_schema) == 1:
            table_name = sources_without_schema[0]

    if table_name not in self.scope.selected_sources:
        return exp.to_identifier(table_name)

    node, _ = self.scope.selected_sources[table_name]

    if isinstance(node, exp.Subqueryable):
        # Climb to the ancestor that actually carries the alias we resolved.
        while node and node.alias != table_name:
            node = node.parent

    # The climb above can run off the top of the tree without finding the
    # alias; fall back to the plain name instead of dereferencing None.
    node_alias = node.args.get("alias") if node is not None else None
    if node_alias:
        return exp.to_identifier(node_alias.this)

    return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table identifier if the table can be found or inferred.

all_columns

All available columns of all sources in this scope.

def get_source_columns(self, name, only_visible=False):
def get_source_columns(self, name, only_visible=False):
    """
    Resolve the source columns for a given source `name`.

    Args:
        name: Name of a source registered in this scope.
        only_visible: Passed through to the schema lookup when the source is
            a table.

    Returns:
        The column names exposed by the source: the schema's columns for a
        table, the alias column names for a VALUES scope, or the named
        selects of any other source.

    Raises:
        OptimizeError: If `name` is not a source of this scope.
    """
    sources = self.scope.sources
    if name not in sources:
        raise OptimizeError(f"Unknown table: {name}")

    source = sources[name]

    # Physical tables get their columns from the schema.
    if isinstance(source, exp.Table):
        return self.schema.column_names(source, only_visible)

    expression = source.expression
    is_values_scope = isinstance(source, Scope) and isinstance(expression, exp.Values)

    # VALUES lists expose the column names given in their alias; every other
    # source exposes whatever it selects.
    return expression.alias_column_names if is_values_scope else expression.named_selects

Resolve the source columns for a given source `name`.