Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 12from sqlglot.schema import Schema, ensure_schema
 13
 14
 15def qualify_columns(
 16    expression: exp.Expression,
 17    schema: t.Dict | Schema,
 18    expand_alias_refs: bool = True,
 19    infer_schema: t.Optional[bool] = None,
 20) -> exp.Expression:
 21    """
 22    Rewrite sqlglot AST to have fully qualified columns.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"tbl": {"col": "INT"}}
 27        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 28        >>> qualify_columns(expression, schema).sql()
 29        'SELECT tbl.col AS col FROM tbl'
 30
 31    Args:
 32        expression: Expression to qualify.
 33        schema: Database schema.
 34        expand_alias_refs: Whether or not to expand references to aliases.
 35        infer_schema: Whether or not to infer the schema if missing.
 36
 37    Returns:
 38        The qualified expression.
 39    """
 40    schema = ensure_schema(schema)
 41    infer_schema = schema.empty if infer_schema is None else infer_schema
 42    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 43
 44    for scope in traverse_scope(expression):
 45        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 46        _pop_table_column_aliases(scope.ctes)
 47        _pop_table_column_aliases(scope.derived_tables)
 48        using_column_tables = _expand_using(scope, resolver)
 49
 50        if schema.empty and expand_alias_refs:
 51            _expand_alias_refs(scope, resolver)
 52
 53        _qualify_columns(scope, resolver)
 54
 55        if not schema.empty and expand_alias_refs:
 56            _expand_alias_refs(scope, resolver)
 57
 58        if not isinstance(scope.expression, exp.UDTF):
 59            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 60            _qualify_outputs(scope)
 61        _expand_group_by(scope)
 62        _expand_order_by(scope, resolver)
 63
 64    return expression
 65
 66
 67def validate_qualify_columns(expression: E) -> E:
 68    """Raise an `OptimizeError` if any columns aren't qualified"""
 69    unqualified_columns = []
 70    for scope in traverse_scope(expression):
 71        if isinstance(scope.expression, exp.Select):
 72            unqualified_columns.extend(scope.unqualified_columns)
 73            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 74                column = scope.external_columns[0]
 75                raise OptimizeError(
 76                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 77                )
 78
 79    if unqualified_columns:
 80        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 81    return expression
 82
 83
 84def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 85    """
 86    Remove table column aliases.
 87
 88    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 89    """
 90    for derived_table in derived_tables:
 91        table_alias = derived_table.args.get("alias")
 92        if table_alias:
 93            table_alias.args.pop("columns", None)
 94
 95
 96def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 97    joins = list(scope.find_all(exp.Join))
 98    names = {join.alias_or_name for join in joins}
 99    ordered = [key for key in scope.selected_sources if key not in names]
100
101    # Mapping of automatically joined column names to an ordered set of source names (dict).
102    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
103
104    for join in joins:
105        using = join.args.get("using")
106
107        if not using:
108            continue
109
110        join_table = join.alias_or_name
111
112        columns = {}
113
114        for k in scope.selected_sources:
115            if k in ordered:
116                for column in resolver.get_source_columns(k):
117                    if column not in columns:
118                        columns[column] = k
119
120        source_table = ordered[-1]
121        ordered.append(join_table)
122        join_columns = resolver.get_source_columns(join_table)
123        conditions = []
124
125        for identifier in using:
126            identifier = identifier.name
127            table = columns.get(identifier)
128
129            if not table or identifier not in join_columns:
130                if columns and join_columns:
131                    raise OptimizeError(f"Cannot automatically join: {identifier}")
132
133            table = table or source_table
134            conditions.append(
135                exp.condition(
136                    exp.EQ(
137                        this=exp.column(identifier, table=table),
138                        expression=exp.column(identifier, table=join_table),
139                    )
140                )
141            )
142
143            # Set all values in the dict to None, because we only care about the key ordering
144            tables = column_tables.setdefault(identifier, {})
145            if table not in tables:
146                tables[table] = None
147            if join_table not in tables:
148                tables[join_table] = None
149
150        join.args.pop("using")
151        join.set("on", exp.and_(*conditions, copy=False))
152
153    if column_tables:
154        for column in scope.columns:
155            if not column.table and column.name in column_tables:
156                tables = column_tables[column.name]
157                coalesce = [exp.column(column.name, table=table) for table in tables]
158                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
159
160                # Ensure selects keep their output name
161                if isinstance(column.parent, exp.Select):
162                    replacement = alias(replacement, alias=column.name, copy=False)
163
164                scope.replace(column, replacement)
165
166    return column_tables
167
168
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Expand references to projection aliases inside a SELECT scope, in place.

    Projections are processed left to right, so a projection can only reference the
    aliases defined before it. WHERE, GROUP BY, HAVING and QUALIFY are processed
    afterwards and see every alias. Scopes whose expression is not a SELECT are
    left untouched.

    Args:
        scope: Scope to rewrite.
        resolver: Resolver used to find the table of a column in HAVING / QUALIFY.
    """
    select = scope.expression
    if not isinstance(select, exp.Select):
        return

    # alias name -> (aliased expression, 1-based position of its projection)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for column, *_ in walk_in_scope(node):
            if not isinstance(column, exp.Column):
                continue

            table = None
            if resolve_table and not column.table:
                table = resolver.get_table(column.name)

            alias_expr, position = alias_to_expression.get(column.name, (None, 1))

            # Expanding an aggregate alias inside another aggregate would nest aggregates.
            double_agg = False
            if alias_expr:
                double_agg = alias_expr.find(exp.AggFunc) and column.find_ancestor(exp.AggFunc)

            if table and (not alias_expr or double_agg):
                column.set("table", table)
                continue

            if column.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                # Literal aliases become positional references when `literal_index` is set
                # and are otherwise left as they are.
                if literal_index:
                    column.replace(exp.Literal.number(position))
            else:
                column.replace(alias_expr.copy())

    for position, projection in enumerate(select.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    clause_options = (
        ("where", {}),
        ("group", {"literal_index": True}),
        ("having", {"resolve_table": True}),
        ("qualify", {"resolve_table": True}),
    )
    for clause, options in clause_options:
        replace_columns(select.args.get(clause), **options)

    scope.clear_cache()
214
215
216def _expand_group_by(scope: Scope):
217    expression = scope.expression
218    group = expression.args.get("group")
219    if not group:
220        return
221
222    group.set("expressions", _expand_positional_references(scope, group.expressions))
223    expression.set("group", group)
224
225
226def _expand_order_by(scope: Scope, resolver: Resolver):
227    order = scope.expression.args.get("order")
228    if not order:
229        return
230
231    ordereds = order.expressions
232    for ordered, new_expression in zip(
233        ordereds,
234        _expand_positional_references(scope, (o.this for o in ordereds)),
235    ):
236        for agg in ordered.find_all(exp.AggFunc):
237            for col in agg.find_all(exp.Column):
238                if not col.table:
239                    col.set("table", resolver.get_table(col.name))
240
241        ordered.set("this", new_expression)
242
243    if scope.expression.args.get("group"):
244        selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}
245
246        for ordered in ordereds:
247            ordered = ordered.this
248
249            ordered.replace(
250                exp.to_identifier(_select_by_pos(scope, ordered).alias)
251                if ordered.is_int
252                else selects.get(ordered, ordered)
253            )
254
255
256def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
257    new_nodes = []
258    for node in expressions:
259        if node.is_int:
260            select = _select_by_pos(scope, t.cast(exp.Literal, node)).this
261
262            if isinstance(select, exp.Literal):
263                new_nodes.append(node)
264            else:
265                new_nodes.append(select.copy())
266                scope.clear_cache()
267        else:
268            new_nodes.append(node)
269
270    return new_nodes
271
272
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by a 1-based positional literal.

    Args:
        scope: Scope whose projections are indexed.
        node: Integer literal holding the 1-based position.

    Returns:
        The projection at that position, asserted to be an `exp.Alias`.

    Raises:
        OptimizeError: If the position does not refer to an existing projection.
    """
    selects = scope.expression.selects
    index = int(node.this) - 1

    # Reject position 0 (and anything below) explicitly: a negative index would
    # otherwise wrap around and silently pick a projection from the end of the list.
    if index < 0 or index >= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[index].assert_is(exp.Alias)
278
279
280def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
281    """Disambiguate columns, ensuring each column specifies a source"""
282    for column in scope.columns:
283        column_table = column.table
284        column_name = column.name
285
286        if column_table and column_table in scope.sources:
287            source_columns = resolver.get_source_columns(column_table)
288            if source_columns and column_name not in source_columns and "*" not in source_columns:
289                raise OptimizeError(f"Unknown column: {column_name}")
290
291        if not column_table:
292            if scope.pivots and not column.find_ancestor(exp.Pivot):
293                # If the column is under the Pivot expression, we need to qualify it
294                # using the name of the pivoted source instead of the pivot's alias
295                column.set("table", exp.to_identifier(scope.pivots[0].alias))
296                continue
297
298            column_table = resolver.get_table(column_name)
299
300            # column_table can be a '' because bigquery unnest has no table alias
301            if column_table:
302                column.set("table", column_table)
303        elif column_table not in scope.sources and (
304            not scope.parent or column_table not in scope.parent.sources
305        ):
306            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
307            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
308
309            root, *parts = column.parts
310
311            if root.name in scope.sources:
312                # struct is already qualified, but we still need to change the AST representation
313                column_table = root
314                root, *parts = parts
315            else:
316                column_table = resolver.get_table(root.name)
317
318            if column_table:
319                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
320
321    for pivot in scope.pivots:
322        for column in pivot.find_all(exp.Column):
323            if not column.table and column.name in resolver.all_columns:
324                column_table = resolver.get_table(column.name)
325                if column_table:
326                    column.set("table", column_table)
327
328
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    If the columns of any starred source are unknown (empty, or containing "*"), the
    function returns without modifying the projections at all.

    Args:
        scope: Scope whose projections are rewritten in place.
        resolver: Resolver used to look up the visible columns of each source.
        using_column_tables: Mapping of USING column names to the tables joined on them;
            such columns are emitted once, as a COALESCE over those tables.
        pseudocolumns: Upper-cased column names that are left out of the expansion.

    Raises:
        OptimizeError: If a star refers to a table that is not a source of the scope.
    """
    new_selections = []
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot_columns = None
    pivot_output_columns = None
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))

    has_pivoted_source = pivot and not pivot.args.get("unpivot")
    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}

        # Prefer the pivot's "columns" argument; fall back to the names of its expressions.
        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])] or [
            col.alias_or_name for col in pivot.expressions
        ]

    expand_via_pivot = bool(
        pivot and has_pivoted_source and pivot_columns and pivot_output_columns
    )

    for projection in scope.expression.selects:
        if isinstance(projection, exp.Star):
            star, tables = projection, list(scope.selected_sources)
        elif projection.is_star:
            star, tables = projection.this, [projection.table]
        else:
            new_selections.append(projection)
            continue

        _add_except_columns(star, tables, except_columns)
        _add_replace_columns(star, tables, replace_columns)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            if not columns or "*" in columns:
                # The star cannot be expanded, so the projections are left as they are.
                return

            if expand_via_pivot:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                new_selections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                )
                continue

            # EXCEPT / REPLACE modifiers are registered under the id() of the table name.
            table_id = id(table)
            excluded = except_columns.get(table_id, set())
            renamed = replace_columns.get(table_id, {})

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted only once, for the first source that has it.
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce = [
                        exp.column(name, table=joined_table)
                        for joined_table in using_column_tables[name]
                    ]

                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded:
                    alias_ = renamed.get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, alias_, copy=False) if alias_ != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
415
416
417def _add_except_columns(
418    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
419) -> None:
420    except_ = expression.args.get("except")
421
422    if not except_:
423        return
424
425    columns = {e.name for e in except_}
426
427    for table in tables:
428        except_columns[id(table)] = columns
429
430
431def _add_replace_columns(
432    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
433) -> None:
434    replace = expression.args.get("replace")
435
436    if not replace:
437        return
438
439    columns = {e.this.name: e.alias for e in replace}
440
441    for table in tables:
442        replace_columns[id(table)] = columns
443
444
445def _qualify_outputs(scope: Scope):
446    """Ensure all output columns are aliased"""
447    new_selections = []
448
449    for i, (selection, aliased_column) in enumerate(
450        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
451    ):
452        if isinstance(selection, exp.Subquery):
453            if not selection.output_name:
454                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
455        elif not isinstance(selection, exp.Alias) and not selection.is_star:
456            selection = alias(
457                selection,
458                alias=selection.output_name or f"_col_{i}",
459            )
460        if aliased_column:
461            selection.set("alias", exp.to_identifier(aliased_column))
462
463        new_selections.append(selection)
464
465    scope.expression.set("expressions", new_selections)
466
467
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Args:
        expression: Expression to transform; it is transformed without copying.
        dialect: Dialect whose `quote_identifier` is applied to the expression.
        identify: Forwarded to the dialect's `quote_identifier`.

    Returns:
        The result of transforming `expression`.
    """
    quote = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quote, identify=identify, copy=False)
473
474
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        """
        Args:
            scope: Scope whose sources are used for resolution.
            schema: Schema used to look up the columns of tables.
            infer_schema: Whether a column that cannot be resolved may be attributed to
                the single source whose columns are unknown.
        """
        self.scope = scope
        self.schema = schema
        # Lazily computed caches, filled on first use.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            # If exactly one source has unknown columns, assume the column comes from it.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias matching the source name.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in this scope.
            only_visible: Forwarded to the schema lookup when the source is a table.

        Returns:
            The column names of the source.

        Raises:
            OptimizeError: If `name` is not a source of this scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once across all sources.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}

        # Track every column name seen so far, including names duplicated within the
        # first source, so that a later source reusing such a name is flagged as
        # ambiguous (the same way duplicates in later sources already are).
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)

            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
16def qualify_columns(
17    expression: exp.Expression,
18    schema: t.Dict | Schema,
19    expand_alias_refs: bool = True,
20    infer_schema: t.Optional[bool] = None,
21) -> exp.Expression:
22    """
23    Rewrite sqlglot AST to have fully qualified columns.
24
25    Example:
26        >>> import sqlglot
27        >>> schema = {"tbl": {"col": "INT"}}
28        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
29        >>> qualify_columns(expression, schema).sql()
30        'SELECT tbl.col AS col FROM tbl'
31
32    Args:
33        expression: Expression to qualify.
34        schema: Database schema.
35        expand_alias_refs: Whether or not to expand references to aliases.
36        infer_schema: Whether or not to infer the schema if missing.
37
38    Returns:
39        The qualified expression.
40    """
41    schema = ensure_schema(schema)
42    infer_schema = schema.empty if infer_schema is None else infer_schema
43    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
44
45    for scope in traverse_scope(expression):
46        resolver = Resolver(scope, schema, infer_schema=infer_schema)
47        _pop_table_column_aliases(scope.ctes)
48        _pop_table_column_aliases(scope.derived_tables)
49        using_column_tables = _expand_using(scope, resolver)
50
51        if schema.empty and expand_alias_refs:
52            _expand_alias_refs(scope, resolver)
53
54        _qualify_columns(scope, resolver)
55
56        if not schema.empty and expand_alias_refs:
57            _expand_alias_refs(scope, resolver)
58
59        if not isinstance(scope.expression, exp.UDTF):
60            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
61            _qualify_outputs(scope)
62        _expand_group_by(scope)
63        _expand_order_by(scope, resolver)
64
65    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:
  • The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
68def validate_qualify_columns(expression: E) -> E:
69    """Raise an `OptimizeError` if any columns aren't qualified"""
70    unqualified_columns = []
71    for scope in traverse_scope(expression):
72        if isinstance(scope.expression, exp.Select):
73            unqualified_columns.extend(scope.unqualified_columns)
74            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
75                column = scope.external_columns[0]
76                raise OptimizeError(
77                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
78                )
79
80    if unqualified_columns:
81        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
82    return expression

Raise an OptimizeError if any columns aren't qualified

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
469def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
470    """Makes sure all identifiers that need to be quoted are quoted."""
471    return expression.transform(
472        Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False
473    )

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
476class Resolver:
477    """
478    Helper for resolving columns.
479
480    This is a class so we can lazily load some things and easily share them across functions.
481    """
482
483    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
484        self.scope = scope
485        self.schema = schema
486        self._source_columns = None
487        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
488        self._all_columns = None
489        self._infer_schema = infer_schema
490
491    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
492        """
493        Get the table for a column name.
494
495        Args:
496            column_name: The column name to find the table for.
497        Returns:
498            The table name if it can be found/inferred.
499        """
500        if self._unambiguous_columns is None:
501            self._unambiguous_columns = self._get_unambiguous_columns(
502                self._get_all_source_columns()
503            )
504
505        table_name = self._unambiguous_columns.get(column_name)
506
507        if not table_name and self._infer_schema:
508            sources_without_schema = tuple(
509                source
510                for source, columns in self._get_all_source_columns().items()
511                if not columns or "*" in columns
512            )
513            if len(sources_without_schema) == 1:
514                table_name = sources_without_schema[0]
515
516        if table_name not in self.scope.selected_sources:
517            return exp.to_identifier(table_name)
518
519        node, _ = self.scope.selected_sources.get(table_name)
520
521        if isinstance(node, exp.Subqueryable):
522            while node and node.alias != table_name:
523                node = node.parent
524
525        node_alias = node.args.get("alias")
526        if node_alias:
527            return exp.to_identifier(node_alias.this)
528
529        return exp.to_identifier(table_name)
530
531    @property
532    def all_columns(self):
533        """All available columns of all sources in this scope"""
534        if self._all_columns is None:
535            self._all_columns = {
536                column for columns in self._get_all_source_columns().values() for column in columns
537            }
538        return self._all_columns
539
540    def get_source_columns(self, name, only_visible=False):
541        """Resolve the source columns for a given source `name`"""
542        if name not in self.scope.sources:
543            raise OptimizeError(f"Unknown table: {name}")
544
545        source = self.scope.sources[name]
546
547        # If referencing a table, return the columns from the schema
548        if isinstance(source, exp.Table):
549            return self.schema.column_names(source, only_visible)
550
551        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
552            return source.expression.alias_column_names
553
554        # Otherwise, if referencing another scope, return that scope's named selects
555        return source.expression.named_selects
556
557    def _get_all_source_columns(self):
558        if self._source_columns is None:
559            self._source_columns = {
560                k: self.get_source_columns(k)
561                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
562            }
563        return self._source_columns
564
565    def _get_unambiguous_columns(self, source_columns):
566        """
567        Find all the unambiguous columns in sources.
568
569        Args:
570            source_columns (dict): Mapping of names to source columns
571        Returns:
572            dict: Mapping of column name to source name
573        """
574        if not source_columns:
575            return {}
576
577        source_columns = list(source_columns.items())
578
579        first_table, first_columns = source_columns[0]
580        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
581        all_columns = set(unambiguous_columns)
582
583        for table, columns in source_columns[1:]:
584            unique = self._find_unique_columns(columns)
585            ambiguous = set(all_columns).intersection(unique)
586            all_columns.update(columns)
587            for column in ambiguous:
588                unambiguous_columns.pop(column, None)
589            for column in unique.difference(ambiguous):
590                unambiguous_columns[column] = table
591
592        return unambiguous_columns
593
594    @staticmethod
595    def _find_unique_columns(columns):
596        """
597        Find the unique columns in a list of columns.
598
599        Example:
600            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
601            ['a', 'c']
602
603        This is necessary because duplicate column names are ambiguous.
604        """
605        counts = {}
606        for column in columns:
607            counts[column] = counts.get(column, 0) + 1
608        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver( scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
483    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
484        self.scope = scope
485        self.schema = schema
486        self._source_columns = None
487        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
488        self._all_columns = None
489        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
491    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
492        """
493        Get the table for a column name.
494
495        Args:
496            column_name: The column name to find the table for.
497        Returns:
498            The table name if it can be found/inferred.
499        """
500        if self._unambiguous_columns is None:
501            self._unambiguous_columns = self._get_unambiguous_columns(
502                self._get_all_source_columns()
503            )
504
505        table_name = self._unambiguous_columns.get(column_name)
506
507        if not table_name and self._infer_schema:
508            sources_without_schema = tuple(
509                source
510                for source, columns in self._get_all_source_columns().items()
511                if not columns or "*" in columns
512            )
513            if len(sources_without_schema) == 1:
514                table_name = sources_without_schema[0]
515
516        if table_name not in self.scope.selected_sources:
517            return exp.to_identifier(table_name)
518
519        node, _ = self.scope.selected_sources.get(table_name)
520
521        if isinstance(node, exp.Subqueryable):
522            while node and node.alias != table_name:
523                node = node.parent
524
525        node_alias = node.args.get("alias")
526        if node_alias:
527            return exp.to_identifier(node_alias.this)
528
529        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
540    def get_source_columns(self, name, only_visible=False):
541        """Resolve the source columns for a given source `name`"""
542        if name not in self.scope.sources:
543            raise OptimizeError(f"Unknown table: {name}")
544
545        source = self.scope.sources[name]
546
547        # If referencing a table, return the columns from the schema
548        if isinstance(source, exp.Table):
549            return self.schema.column_names(source, only_visible)
550
551        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
552            return source.expression.alias_column_names
553
554        # Otherwise, if referencing another scope, return that scope's named selects
555        return source.expression.named_selects

Resolve the source columns for a given source name