Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import seq_get
  9from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 10from sqlglot.schema import Schema, ensure_schema
 11
 12
 13def qualify_columns(
 14    expression: exp.Expression,
 15    schema: dict | Schema,
 16    expand_alias_refs: bool = True,
 17    infer_schema: t.Optional[bool] = None,
 18) -> exp.Expression:
 19    """
 20    Rewrite sqlglot AST to have fully qualified columns.
 21
 22    Example:
 23        >>> import sqlglot
 24        >>> schema = {"tbl": {"col": "INT"}}
 25        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 26        >>> qualify_columns(expression, schema).sql()
 27        'SELECT tbl.col AS col FROM tbl'
 28
 29    Args:
 30        expression: expression to qualify
 31        schema: Database schema
 32        expand_alias_refs: whether or not to expand references to aliases
 33        infer_schema: whether or not to infer the schema if missing
 34    Returns:
 35        sqlglot.Expression: qualified expression
 36    """
 37    schema = ensure_schema(schema)
 38    infer_schema = schema.empty if infer_schema is None else infer_schema
 39
 40    for scope in traverse_scope(expression):
 41        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 42        _pop_table_column_aliases(scope.ctes)
 43        _pop_table_column_aliases(scope.derived_tables)
 44        using_column_tables = _expand_using(scope, resolver)
 45
 46        if schema.empty and expand_alias_refs:
 47            _expand_alias_refs(scope, resolver)
 48
 49        _qualify_columns(scope, resolver)
 50
 51        if not schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        if not isinstance(scope.expression, exp.UDTF):
 55            _expand_stars(scope, resolver, using_column_tables)
 56            _qualify_outputs(scope)
 57        _expand_group_by(scope, resolver)
 58        _expand_order_by(scope)
 59
 60    return expression
 61
 62
 63def validate_qualify_columns(expression):
 64    """Raise an `OptimizeError` if any columns aren't qualified"""
 65    unqualified_columns = []
 66    for scope in traverse_scope(expression):
 67        if isinstance(scope.expression, exp.Select):
 68            unqualified_columns.extend(scope.unqualified_columns)
 69            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 70                column = scope.external_columns[0]
 71                raise OptimizeError(
 72                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 73                )
 74
 75    if unqualified_columns:
 76        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 77    return expression
 78
 79
 80def _pop_table_column_aliases(derived_tables):
 81    """
 82    Remove table column aliases.
 83
 84    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 85    """
 86    for derived_table in derived_tables:
 87        table_alias = derived_table.args.get("alias")
 88        if table_alias:
 89            table_alias.args.pop("columns", None)
 90
 91
 92def _expand_using(scope, resolver):
 93    joins = list(scope.find_all(exp.Join))
 94    names = {join.this.alias for join in joins}
 95    ordered = [key for key in scope.selected_sources if key not in names]
 96
 97    # Mapping of automatically joined column names to an ordered set of source names (dict).
 98    column_tables = {}
 99
100    for join in joins:
101        using = join.args.get("using")
102
103        if not using:
104            continue
105
106        join_table = join.this.alias_or_name
107
108        columns = {}
109
110        for k in scope.selected_sources:
111            if k in ordered:
112                for column in resolver.get_source_columns(k):
113                    if column not in columns:
114                        columns[column] = k
115
116        source_table = ordered[-1]
117        ordered.append(join_table)
118        join_columns = resolver.get_source_columns(join_table)
119        conditions = []
120
121        for identifier in using:
122            identifier = identifier.name
123            table = columns.get(identifier)
124
125            if not table or identifier not in join_columns:
126                if columns and join_columns:
127                    raise OptimizeError(f"Cannot automatically join: {identifier}")
128
129            table = table or source_table
130            conditions.append(
131                exp.condition(
132                    exp.EQ(
133                        this=exp.column(identifier, table=table),
134                        expression=exp.column(identifier, table=join_table),
135                    )
136                )
137            )
138
139            # Set all values in the dict to None, because we only care about the key ordering
140            tables = column_tables.setdefault(identifier, {})
141            if table not in tables:
142                tables[table] = None
143            if join_table not in tables:
144                tables[join_table] = None
145
146        join.args.pop("using")
147        join.set("on", exp.and_(*conditions, copy=False))
148
149    if column_tables:
150        for column in scope.columns:
151            if not column.table and column.name in column_tables:
152                tables = column_tables[column.name]
153                coalesce = [exp.column(column.name, table=table) for table in tables]
154                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
155
156                # Ensure selects keep their output name
157                if isinstance(column.parent, exp.Select):
158                    replacement = alias(replacement, alias=column.name, copy=False)
159
160                scope.replace(column, replacement)
161
162    return column_tables
163
164
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Replace references to projection aliases with the aliased expressions.

    Only SELECT scopes are handled. Each projection may refer to aliases defined by
    earlier projections. WHERE, GROUP BY, HAVING and QUALIFY references are expanded too;
    ORDER BY references are not. In HAVING, QUALIFY and ORDER BY, unqualified columns
    inside an aggregate function are qualified with the table the resolver finds for
    them. The scope cache is cleared at the end because the tree has been modified.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    alias_to_expression: t.Dict[str, exp.Expression] = {}

    def _rewrite(node: t.Optional[exp.Expression], expand: bool = True, resolve_agg: bool = False):
        if not node:
            return

        for candidate, *_ in walk_in_scope(node):
            if not isinstance(candidate, exp.Column):
                continue

            unqualified = not candidate.table
            table = None
            if resolve_agg and unqualified:
                table = resolver.get_table(candidate.name)

            if table and candidate.find_ancestor(exp.AggFunc):
                candidate.set("table", table)
            elif expand and unqualified and candidate.name in alias_to_expression:
                candidate.replace(alias_to_expression[candidate.name].copy())

    for projection in scope.selects:
        _rewrite(projection)
        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = projection.this

    # (clause name, expand alias references, qualify aggregated columns)
    for clause, expand, resolve_agg in (
        ("where", True, False),
        ("group", True, False),
        ("having", True, True),
        ("qualify", True, True),
        ("order", False, True),
    ):
        _rewrite(expression.args.get(clause), expand=expand, resolve_agg=resolve_agg)

    scope.clear_cache()
200
201
def _expand_group_by(scope, resolver):
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projections
    they point to. `resolver` is accepted for signature parity but is not used.
    """
    group = scope.expression.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        scope.expression.set("group", group)
209
210
def _expand_order_by(scope):
    """
    Replace positional ORDER BY references (e.g. `ORDER BY 2`) with the projections
    they point to. Each `Ordered` node keeps its direction; only its `this` is replaced.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    expanded = _expand_positional_references(scope, [o.this for o in ordereds])
    for ordered, new_expression in zip(ordereds, expanded):
        ordered.set("this", new_expression)
222
223
224def _expand_positional_references(scope, expressions):
225    new_nodes = []
226    for node in expressions:
227        if node.is_int:
228            try:
229                select = scope.selects[int(node.name) - 1]
230            except IndexError:
231                raise OptimizeError(f"Unknown output column: {node.name}")
232            if isinstance(select, exp.Alias):
233                select = select.this
234            new_nodes.append(select.copy())
235            scope.clear_cache()
236        else:
237            new_nodes.append(node)
238
239    return new_nodes
240
241
242def _qualify_columns(scope, resolver):
243    """Disambiguate columns, ensuring each column specifies a source"""
244    for column in scope.columns:
245        column_table = column.table
246        column_name = column.name
247
248        if column_table and column_table in scope.sources:
249            source_columns = resolver.get_source_columns(column_table)
250            if source_columns and column_name not in source_columns and "*" not in source_columns:
251                raise OptimizeError(f"Unknown column: {column_name}")
252
253        if not column_table:
254            if scope.pivots and not column.find_ancestor(exp.Pivot):
255                # If the column is under the Pivot expression, we need to qualify it
256                # using the name of the pivoted source instead of the pivot's alias
257                column.set("table", exp.to_identifier(scope.pivots[0].alias))
258                continue
259
260            column_table = resolver.get_table(column_name)
261
262            # column_table can be a '' because bigquery unnest has no table alias
263            if column_table:
264                column.set("table", column_table)
265        elif column_table not in scope.sources and (
266            not scope.parent or column_table not in scope.parent.sources
267        ):
268            # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
269            # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
270
271            root, *parts = column.parts
272
273            if root.name in scope.sources:
274                # struct is already qualified, but we still need to change the AST representation
275                column_table = root
276                root, *parts = parts
277            else:
278                column_table = resolver.get_table(root.name)
279
280            if column_table:
281                column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts]))
282
283    for pivot in scope.pivots:
284        for column in pivot.find_all(exp.Column):
285            if not column.table and column.name in resolver.all_columns:
286                column_table = resolver.get_table(column.name)
287                if column_table:
288                    column.set("table", column_table)
289
290
def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    - `*` expands over every selected source, and `t.*` over source `t`. EXCEPT and
      REPLACE modifiers are honoured per table.
    - A column that takes part in a USING join is emitted once, as a COALESCE over all
      sources sharing it.
    - With a (non-UNPIVOT) pivot, a star expands to the pivot's implicit columns
      followed by its output columns, qualified with the pivot's alias.
    - If any starred source has no known columns (or contains "*"), the projection list
      is left untouched.

    Raises:
        OptimizeError: if a starred table is not a source of the scope.
    """
    expanded = []
    except_by_table = {}
    replace_by_table = {}
    coalesced = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot = seq_get(scope.pivots, 0)
    pivoted = pivot and not pivot.args.get("unpivot")
    pivot_columns = None
    pivot_output_columns = None
    if pivoted:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}
        pivot_output_columns = [
            col.output_name for col in pivot.args.get("columns", [])
        ] or [col.alias_or_name for col in pivot.expressions]

    for selection in scope.selects:
        if isinstance(selection, exp.Star):
            star, star_tables = selection, list(scope.selected_sources)
        elif selection.is_star:
            star, star_tables = selection.this, [selection.table]
        else:
            expanded.append(selection)
            continue

        # EXCEPT/REPLACE modifiers are keyed by id() of the very objects in `star_tables`.
        _add_except_columns(star, star_tables, except_by_table)
        _add_replace_columns(star, star_tables, replace_by_table)

        for table in star_tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)
            if not columns or "*" in columns:
                # The column list is unknown, so the stars cannot be expanded at all.
                return

            if pivoted:
                implicit = [name for name in columns if name not in pivot_columns]
                expanded.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit + pivot_output_columns
                )
                continue

            excluded = except_by_table.get(id(table), set())
            renamed = replace_by_table.get(id(table), {})

            for name in columns:
                if table in using_column_tables.get(name, ()):
                    if name in coalesced:
                        continue
                    coalesced.add(name)
                    operands = [
                        exp.column(name, table=source) for source in using_column_tables[name]
                    ]
                    expanded.append(
                        alias(
                            exp.Coalesce(this=operands[0], expressions=operands[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded:
                    column = exp.column(name, table=table)
                    new_name = renamed.get(name, name)
                    expanded.append(
                        column if new_name == name else alias(column, new_name, copy=False)
                    )

    scope.expression.set("expressions", expanded)
367
368
369def _add_except_columns(expression, tables, except_columns):
370    except_ = expression.args.get("except")
371
372    if not except_:
373        return
374
375    columns = {e.name for e in except_}
376
377    for table in tables:
378        except_columns[id(table)] = columns
379
380
381def _add_replace_columns(expression, tables, replace_columns):
382    replace = expression.args.get("replace")
383
384    if not replace:
385        return
386
387    columns = {e.this.name: e.alias for e in replace}
388
389    for table in tables:
390        replace_columns[id(table)] = columns
391
392
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Unaliased, non-star projections are aliased with their output name, or `_col_<i>`
    when they have none. Quoting of a column's name is preserved. Unnamed subqueries get
    a `_col_<i>` table alias. A name supplied by the scope's outer column list always
    overrides the alias.
    """
    pairs = itertools.zip_longest(scope.selects, scope.outer_column_list)
    outputs = []

    for index, (selection, outer_name) in enumerate(pairs):
        fallback = f"_col_{index}"

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback)))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            quote = isinstance(selection, exp.Column) and selection.this.quoted
            selection = alias(
                selection,
                alias=selection.output_name or fallback,
                quoted=True if quote else None,
            )

        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        outputs.append(selection)

    scope.expression.set("expressions", outputs)
417
418
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches, filled on first use by _get_all_source_columns, get_table and all_columns.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        # When True, get_table may attribute unknown columns to the only schema-less source.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the unambiguous columns of the scope's sources.
        If that fails and schema inference is enabled, the column is attributed to the
        single source with no known columns (or with a "*" column), if exactly one exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the node that carries the alias the source was registered under.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Tables are resolved through the schema. VALUES scopes use their alias column
        names, and other scopes use their named selects.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it and that source lists
        it exactly once. A column provided by several sources, or duplicated within any
        single source, is left out.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column name seen so far, *including* duplicated ones. Seeding this with
        # only unique names would let a later source claim a column that an earlier
        # source also has.
        seen = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            present = set(columns)

            # Any column shared with an earlier source is ambiguous, even if this source
            # lists it more than once.
            for column in seen & present:
                unambiguous_columns.pop(column, None)
            for column in unique - seen:
                unambiguous_columns[column] = table

            seen |= present

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: dict | sqlglot.schema.Schema, expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: expression to qualify
        schema: Database schema
        expand_alias_refs: whether or not to expand references to aliases
        infer_schema: whether or not to infer the schema if missing; when None, inference
            is enabled exactly when `schema` is empty
    Returns:
        sqlglot.Expression: qualified expression (the input is modified in place)
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Column lists on table aliases, e.g. `AS foo(a, b)`, are dropped up front.
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)

        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before qualification when no schema is known,
        # and after qualification otherwise.
        if expand_alias_refs and schema.empty:
            _expand_alias_refs(scope, resolver)

        _qualify_columns(scope, resolver)

        if expand_alias_refs and not schema.empty:
            _expand_alias_refs(scope, resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables)
            _qualify_outputs(scope)

        _expand_group_by(scope, resolver)
        _expand_order_by(scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: expression to qualify
  • schema: Database schema
  • expand_alias_refs: whether or not to expand references to aliases
  • infer_schema: whether or not to infer the schema if missing
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are checked. An error is raised immediately for the first external
    column of a scope that is neither a correlated subquery nor pivoted. Otherwise, all
    unqualified columns are collected and reported together as ambiguous.

    Returns:
        The unchanged `expression` when validation passes.
    """
    unqualified_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified_columns.extend(scope.unqualified_columns)

        unresolved = scope.external_columns
        if unresolved and not scope.is_correlated_subquery and not scope.pivots:
            column = unresolved[0]
            table_hint = f" for table: '{column.table}'" if column.table else ""
            raise OptimizeError(f"Column '{column}' could not be resolved{table_hint}")

    if unqualified_columns:
        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")

    return expression

Raise an OptimizeError if any columns aren't qualified, or if a column cannot be resolved against any source.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Caches, filled on first use by _get_all_source_columns, get_table and all_columns.
        self._source_columns = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns = None
        # When True, get_table may attribute unknown columns to the only schema-less source.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the unambiguous columns of the scope's sources.
        If that fails and schema inference is enabled, the column is attributed to the
        single source with no known columns (or with a "*" column), if exactly one exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table identifier if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the node that carries the alias the source was registered under.
            while node and node.alias != table_name:
                node = node.parent

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Tables are resolved through the schema. VALUES scopes use their alias column
        names, and other scopes use their named selects.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Map every selected and lateral source name to its columns (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when exactly one source provides it and that source lists
        it exactly once. A column provided by several sources, or duplicated within any
        single source, is left out.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        unambiguous_columns = {}
        # Every column name seen so far, *including* duplicated ones. Seeding this with
        # only unique names would let a later source claim a column that an earlier
        # source also has.
        seen = set()

        for table, columns in source_columns.items():
            unique = self._find_unique_columns(columns)
            present = set(columns)

            # Any column shared with an earlier source is ambiguous, even if this source
            # lists it more than once.
            for column in seen & present:
                unambiguous_columns.pop(column, None)
            for column in unique - seen:
                unambiguous_columns[column] = table

            seen |= present

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope, schema, infer_schema: bool = True)
427    def __init__(self, scope, schema, infer_schema: bool = True):
428        self.scope = scope
429        self.schema = schema
430        self._source_columns = None
431        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
432        self._all_columns = None
433        self._infer_schema = infer_schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
435    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
436        """
437        Get the table for a column name.
438
439        Args:
440            column_name: The column name to find the table for.
441        Returns:
442            The table name if it can be found/inferred.
443        """
444        if self._unambiguous_columns is None:
445            self._unambiguous_columns = self._get_unambiguous_columns(
446                self._get_all_source_columns()
447            )
448
449        table_name = self._unambiguous_columns.get(column_name)
450
451        if not table_name and self._infer_schema:
452            sources_without_schema = tuple(
453                source
454                for source, columns in self._get_all_source_columns().items()
455                if not columns or "*" in columns
456            )
457            if len(sources_without_schema) == 1:
458                table_name = sources_without_schema[0]
459
460        if table_name not in self.scope.selected_sources:
461            return exp.to_identifier(table_name)
462
463        node, _ = self.scope.selected_sources.get(table_name)
464
465        if isinstance(node, exp.Subqueryable):
466            while node and node.alias != table_name:
467                node = node.parent
468
469        node_alias = node.args.get("alias")
470        if node_alias:
471            return exp.to_identifier(node_alias.this)
472
473        return exp.to_identifier(
474            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
475        )

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table identifier if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
486    def get_source_columns(self, name, only_visible=False):
487        """Resolve the source columns for a given source `name`"""
488        if name not in self.scope.sources:
489            raise OptimizeError(f"Unknown table: {name}")
490
491        source = self.scope.sources[name]
492
493        # If referencing a table, return the columns from the schema
494        if isinstance(source, exp.Table):
495            return self.schema.column_names(source, only_visible)
496
497        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
498            return source.expression.alias_column_names
499
500        # Otherwise, if referencing another scope, return that scope's named selects
501        return source.expression.named_selects

Resolve the source columns for a given source name; raises OptimizeError if the source is unknown.