Edit on GitHub

sqlglot.optimizer.qualify_columns

  1from __future__ import annotations
  2
  3import itertools
  4import typing as t
  5
  6from sqlglot import alias, exp
  7from sqlglot._typing import E
  8from sqlglot.dialects.dialect import Dialect, DialectType
  9from sqlglot.errors import OptimizeError
 10from sqlglot.helper import seq_get
 11from sqlglot.optimizer.scope import Scope, traverse_scope, walk_in_scope
 12from sqlglot.optimizer.simplify import simplify_parens
 13from sqlglot.schema import Schema, ensure_schema
 14
 15
 16def qualify_columns(
 17    expression: exp.Expression,
 18    schema: t.Dict | Schema,
 19    expand_alias_refs: bool = True,
 20    infer_schema: t.Optional[bool] = None,
 21) -> exp.Expression:
 22    """
 23    Rewrite sqlglot AST to have fully qualified columns.
 24
 25    Example:
 26        >>> import sqlglot
 27        >>> schema = {"tbl": {"col": "INT"}}
 28        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 29        >>> qualify_columns(expression, schema).sql()
 30        'SELECT tbl.col AS col FROM tbl'
 31
 32    Args:
 33        expression: Expression to qualify.
 34        schema: Database schema.
 35        expand_alias_refs: Whether or not to expand references to aliases.
 36        infer_schema: Whether or not to infer the schema if missing.
 37
 38    Returns:
 39        The qualified expression.
 40    """
 41    schema = ensure_schema(schema)
 42    infer_schema = schema.empty if infer_schema is None else infer_schema
 43    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS
 44
 45    for scope in traverse_scope(expression):
 46        resolver = Resolver(scope, schema, infer_schema=infer_schema)
 47        _pop_table_column_aliases(scope.ctes)
 48        _pop_table_column_aliases(scope.derived_tables)
 49        using_column_tables = _expand_using(scope, resolver)
 50
 51        if schema.empty and expand_alias_refs:
 52            _expand_alias_refs(scope, resolver)
 53
 54        _qualify_columns(scope, resolver)
 55
 56        if not schema.empty and expand_alias_refs:
 57            _expand_alias_refs(scope, resolver)
 58
 59        if not isinstance(scope.expression, exp.UDTF):
 60            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
 61            _qualify_outputs(scope)
 62
 63        _expand_group_by(scope)
 64        _expand_order_by(scope, resolver)
 65
 66    return expression
 67
 68
 69def validate_qualify_columns(expression: E) -> E:
 70    """Raise an `OptimizeError` if any columns aren't qualified"""
 71    unqualified_columns = []
 72    for scope in traverse_scope(expression):
 73        if isinstance(scope.expression, exp.Select):
 74            unqualified_columns.extend(scope.unqualified_columns)
 75            if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
 76                column = scope.external_columns[0]
 77                raise OptimizeError(
 78                    f"""Column '{column}' could not be resolved{f" for table: '{column.table}'" if column.table else ''}"""
 79                )
 80
 81    if unqualified_columns:
 82        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 83    return expression
 84
 85
 86def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None:
 87    """
 88    Remove table column aliases.
 89
 90    For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 91    """
 92    for derived_table in derived_tables:
 93        table_alias = derived_table.args.get("alias")
 94        if table_alias:
 95            table_alias.args.pop("columns", None)
 96
 97
 98def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]:
 99    joins = list(scope.find_all(exp.Join))
100    names = {join.alias_or_name for join in joins}
101    ordered = [key for key in scope.selected_sources if key not in names]
102
103    # Mapping of automatically joined column names to an ordered set of source names (dict).
104    column_tables: t.Dict[str, t.Dict[str, t.Any]] = {}
105
106    for join in joins:
107        using = join.args.get("using")
108
109        if not using:
110            continue
111
112        join_table = join.alias_or_name
113
114        columns = {}
115
116        for source_name in scope.selected_sources:
117            if source_name in ordered:
118                for column_name in resolver.get_source_columns(source_name):
119                    if column_name not in columns:
120                        columns[column_name] = source_name
121
122        source_table = ordered[-1]
123        ordered.append(join_table)
124        join_columns = resolver.get_source_columns(join_table)
125        conditions = []
126
127        for identifier in using:
128            identifier = identifier.name
129            table = columns.get(identifier)
130
131            if not table or identifier not in join_columns:
132                if (columns and "*" not in columns) and join_columns:
133                    raise OptimizeError(f"Cannot automatically join: {identifier}")
134
135            table = table or source_table
136            conditions.append(
137                exp.condition(
138                    exp.EQ(
139                        this=exp.column(identifier, table=table),
140                        expression=exp.column(identifier, table=join_table),
141                    )
142                )
143            )
144
145            # Set all values in the dict to None, because we only care about the key ordering
146            tables = column_tables.setdefault(identifier, {})
147            if table not in tables:
148                tables[table] = None
149            if join_table not in tables:
150                tables[join_table] = None
151
152        join.args.pop("using")
153        join.set("on", exp.and_(*conditions, copy=False))
154
155    if column_tables:
156        for column in scope.columns:
157            if not column.table and column.name in column_tables:
158                tables = column_tables[column.name]
159                coalesce = [exp.column(column.name, table=table) for table in tables]
160                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
161
162                # Ensure selects keep their output name
163                if isinstance(column.parent, exp.Select):
164                    replacement = alias(replacement, alias=column.name, copy=False)
165
166                scope.replace(column, replacement)
167
168    return column_tables
169
170
def _expand_alias_refs(scope: Scope, resolver: Resolver) -> None:
    """
    Inline references to projection aliases within a SELECT.

    Processing order:
    - The projections, in order. Each one only sees aliases defined by the projections
      before it, because an alias is registered after its own projection is processed.
    - Then the WHERE, GROUP BY, HAVING and QUALIFY clauses.

    An unqualified column whose name matches a registered alias is replaced by the aliased
    expression, with these exceptions:
    - Inside GROUP BY, a reference to a literal alias becomes that projection's 1-based
      position.
    - Inside HAVING/QUALIFY, a reference to a literal alias is left untouched.
    - A reference nested in an aggregate is not expanded when the aliased expression itself
      contains an aggregate.

    In HAVING and QUALIFY, an unqualified column that doesn't reference an alias (or whose
    expansion was skipped for the aggregate case) is qualified with the resolver's table.
    Non-SELECT scopes are ignored.
    """
    expression = scope.expression

    if not isinstance(expression, exp.Select):
        return

    # alias name -> (aliased expression, 1-based projection position)
    alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {}

    def replace_columns(
        node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False
    ) -> None:
        if not node:
            return

        for candidate, *_ in walk_in_scope(node):
            if not isinstance(candidate, exp.Column):
                continue

            table = (
                resolver.get_table(candidate.name)
                if resolve_table and not candidate.table
                else None
            )
            alias_expr, position = alias_to_expression.get(candidate.name, (None, 1))
            double_agg = (
                (alias_expr.find(exp.AggFunc) and candidate.find_ancestor(exp.AggFunc))
                if alias_expr
                else False
            )

            if table and (not alias_expr or double_agg):
                candidate.set("table", table)
                continue

            # Nothing to inline: already qualified, no matching alias, or nested aggregate.
            if candidate.table or not alias_expr or double_agg:
                continue

            if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table):
                if literal_index:
                    candidate.replace(exp.Literal.number(position))
                continue

            wrapped = candidate.replace(exp.paren(alias_expr))
            simplified = simplify_parens(wrapped)
            if simplified is not wrapped:
                wrapped.replace(simplified)

    for position, projection in enumerate(expression.selects, start=1):
        replace_columns(projection)

        if isinstance(projection, exp.Alias):
            alias_to_expression[projection.alias] = (projection.this, position)

    replace_columns(expression.args.get("where"))
    replace_columns(expression.args.get("group"), literal_index=True)
    replace_columns(expression.args.get("having"), resolve_table=True)
    replace_columns(expression.args.get("qualify"), resolve_table=True)
    scope.clear_cache()
220
221
def _expand_group_by(scope: Scope) -> None:
    """
    Replace positional GROUP BY references (e.g. `GROUP BY 1`) with the projections they
    point at.

    References to projections that are literals are kept as positions. Scopes without a
    GROUP BY are left untouched.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    expanded = _expand_positional_references(scope, group.expressions)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
230
231
def _expand_order_by(scope: Scope, resolver: Resolver) -> None:
    """
    Normalize the ORDER BY clause of the scope's expression.

    First pass:
    - Positional terms (e.g. `ORDER BY 1`) are replaced by the projection they point at,
      unless that projection is a literal.
    - Unqualified columns inside aggregates are qualified via the resolver.

    Second pass, only when the query also has a GROUP BY:
    - A term still written as a position is replaced by the referenced projection's alias.
    - Any other term equal to a projection's expression is replaced by a column reference
      to that projection's output name.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    # Evaluated eagerly, so every term's `.this` is read before any term is modified.
    replacements = _expand_positional_references(scope, (o.this for o in ordereds))

    for ordered, replacement in zip(ordereds, replacements):
        for agg in ordered.find_all(exp.AggFunc):
            for col in agg.find_all(exp.Column):
                if not col.table:
                    col.set("table", resolver.get_table(col.name))

        ordered.set("this", replacement)

    if not scope.expression.args.get("group"):
        return

    projection_refs = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects}

    for ordered in ordereds:
        term = ordered.this
        if term.is_int:
            term.replace(exp.to_identifier(_select_by_pos(scope, term).alias))
        else:
            term.replace(projection_refs.get(term, term))
260
261
def _expand_positional_references(scope: Scope, expressions: t.Iterable[E]) -> t.List[E]:
    """
    Resolve integer positions against the scope's projections.

    Each integer literal is swapped for a copy of the projection it references (1-based),
    unless that projection is itself a literal, in which case the position is kept.
    Non-integer nodes pass through unchanged. The scope cache is cleared after every
    substitution.

    Returns:
        A new list with one entry per input node, in input order.
    """
    expanded: t.List[E] = []

    for node in expressions:
        if not node.is_int:
            expanded.append(node)
            continue

        target = _select_by_pos(scope, t.cast(exp.Literal, node)).this

        if isinstance(target, exp.Literal):
            expanded.append(node)
        else:
            expanded.append(target.copy())
            scope.clear_cache()

    return expanded
277
278
def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias:
    """
    Return the projection referenced by the 1-based positional literal `node`.

    The projection is checked with `assert_is(exp.Alias)`, so it must already be aliased.

    Raises:
        OptimizeError: if the position doesn't refer to an existing output column. This
            includes 0 and negative positions, which previously wrapped around to the end
            of the projection list via negative indexing.
    """
    position = int(node.this)
    selects = scope.expression.selects

    # Positions are 1-based. Reject anything out of range explicitly, because Python's
    # negative indexing would otherwise map 0 to the last projection.
    if not 1 <= position <= len(selects):
        raise OptimizeError(f"Unknown output column: {node.name}")

    return selects[position - 1].assert_is(exp.Alias)
284
285
def _qualify_columns(scope: Scope, resolver: Resolver) -> None:
    """
    Disambiguate columns, ensuring each column specifies a source.

    For every column in the scope:
    - Unqualified, when the scope has pivots and the column is outside them: qualified with
      the first pivot's alias.
    - Unqualified otherwise: qualified with the table the resolver attributes it to, if any.
    - Qualified with a source of this scope: validated against that source's known columns.
    - Qualified with a name that is a source of neither this scope nor its parent: treated
      as a struct field access and rebuilt as `Dot(Column(root, table=...), field, ...)`.

    Finally, unqualified columns inside each pivot are qualified when their name belongs to
    a source of the scope.

    Raises:
        OptimizeError: if a column is qualified with a source that is known not to have it.
    """
    for column in scope.columns:
        table_name = column.table
        column_name = column.name

        if not table_name:
            if scope.pivots and not column.find_ancestor(exp.Pivot):
                # Columns outside the PIVOT clause are qualified with the pivot's alias.
                # Columns inside it are qualified by the pivot loop at the end of this
                # function, using the pivoted source's table instead.
                column.set("table", exp.to_identifier(scope.pivots[0].alias))
                continue

            resolved = resolver.get_table(column_name)
            # the resolved table can be a '' because bigquery unnest has no table alias
            if resolved:
                column.set("table", resolved)
            continue

        if table_name in scope.sources:
            source_columns = resolver.get_source_columns(table_name)
            if source_columns and column_name not in source_columns and "*" not in source_columns:
                raise OptimizeError(f"Unknown column: {column_name}")
            continue

        if scope.parent and table_name in scope.parent.sources:
            # Qualified with a source of the enclosing scope; left as is.
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *fields = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            qualifier = root
            root, *fields = fields
        else:
            qualifier = resolver.get_table(root.name)

        if qualifier:
            column.replace(exp.Dot.build([exp.column(root, table=qualifier), *fields]))

    for pivot in scope.pivots:
        for column in pivot.find_all(exp.Column):
            if column.table or column.name not in resolver.all_columns:
                continue

            resolved = resolver.get_table(column.name)
            if resolved:
                column.set("table", resolved)
333
334
def _expand_stars(
    scope: Scope,
    resolver: Resolver,
    using_column_tables: t.Dict[str, t.Any],
    pseudocolumns: t.Set[str],
) -> None:
    """
    Expand stars to lists of column selections.

    What each star expands to:
    - `*` expands over every selected source; `tbl.*` expands over `tbl` only.
    - EXCEPT and REPLACE modifiers on a star drop or rename the matching columns.
    - Columns shared through `JOIN ... USING` are emitted once, as an aliased COALESCE over
      the sources that share them.
    - When the scope's first pivot is not an UNPIVOT and both its referenced and output
      columns are known, a star over a source yields the source's columns not referenced
      by the pivot, followed by the pivot's output columns. All of them are qualified with
      the pivot alias.
    - Dialect pseudocolumns are never emitted.

    If any starred source has unknown columns (none, or `*`), the projections are left
    unchanged. Non-star projections keep their position.

    Raises:
        OptimizeError: if a starred table is not a source of the scope.
    """
    new_selections = []
    # Keyed by id() of the table-name objects in the star's table list.
    except_columns: t.Dict[int, t.Set[str]] = {}
    replace_columns: t.Dict[int, t.Dict[str, str]] = {}
    coalesced_columns = set()

    # TODO: handle optimization of multiple PIVOTs (and possibly UNPIVOTs) in the future
    pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0))
    has_pivoted_source = pivot and not pivot.args.get("unpivot")

    pivot_columns = None
    pivot_output_columns = None
    if pivot and has_pivoted_source:
        pivot_columns = {col.output_name for col in pivot.find_all(exp.Column)}

        pivot_output_columns = [col.output_name for col in pivot.args.get("columns", [])]
        if not pivot_output_columns:
            pivot_output_columns = [col.alias_or_name for col in pivot.expressions]

    for projection in scope.expression.selects:
        if isinstance(projection, exp.Star):
            star_tables = list(scope.selected_sources)
            star = projection
        elif projection.is_star:
            star_tables = [projection.table]
            star = projection.this
        else:
            new_selections.append(projection)
            continue

        _add_except_columns(star, star_tables, except_columns)
        _add_replace_columns(star, star_tables, replace_columns)

        for table in star_tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            columns = resolver.get_source_columns(table, only_visible=True)

            if pseudocolumns:
                columns = [name for name in columns if name.upper() not in pseudocolumns]

            # Unknown columns for this source: give up and keep the original projections.
            if not columns or "*" in columns:
                return

            if pivot and has_pivoted_source and pivot_columns and pivot_output_columns:
                implicit_columns = [col for col in columns if col not in pivot_columns]
                new_selections.extend(
                    exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
                    for name in implicit_columns + pivot_output_columns
                )
                continue

            table_id = id(table)
            excluded = except_columns.get(table_id, set())
            renamed = replace_columns.get(table_id, {})

            for name in columns:
                if name in using_column_tables and table in using_column_tables[name]:
                    if name in coalesced_columns:
                        continue

                    coalesced_columns.add(name)
                    coalesce = [
                        exp.column(name, table=sharing_source)
                        for sharing_source in using_column_tables[name]
                    ]
                    new_selections.append(
                        alias(
                            exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]),
                            alias=name,
                            copy=False,
                        )
                    )
                elif name not in excluded:
                    output_name = renamed.get(name, name)
                    column = exp.column(name, table=table)
                    new_selections.append(
                        alias(column, output_name, copy=False) if output_name != name else column
                    )

    # Ensures we don't overwrite the initial selections with an empty list
    if new_selections:
        scope.expression.set("expressions", new_selections)
421
422
423def _add_except_columns(
424    expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]]
425) -> None:
426    except_ = expression.args.get("except")
427
428    if not except_:
429        return
430
431    columns = {e.name for e in except_}
432
433    for table in tables:
434        except_columns[id(table)] = columns
435
436
437def _add_replace_columns(
438    expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, str]]
439) -> None:
440    replace = expression.args.get("replace")
441
442    if not replace:
443        return
444
445    columns = {e.this.name: e.alias for e in replace}
446
447    for table in tables:
448        replace_columns[id(table)] = columns
449
450
def _qualify_outputs(scope: Scope) -> None:
    """
    Ensure all output columns are aliased.

    Per projection:
    - A subquery projection without an output name gets a `_col_{i}` table alias.
    - Any other projection that is neither an alias nor a star is aliased to its output
      name, or to `_col_{i}` when it has none.
    - If the scope has an outer column list, its names then override the projections'
      aliases positionally.

    The scope expression's projection list is replaced with the resulting selections.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None when the outer column list names more columns than
        # the query projects. There is nothing to alias for those extra names, and
        # continuing would dereference None below.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias) and not selection.is_star:
            selection = alias(
                selection,
                alias=selection.output_name or f"_col_{i}",
            )
        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
472
473
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Each node is transformed in place with the dialect's `quote_identifier`, and `identify`
    is forwarded to it.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)
479
480
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    The lazily computed values (source columns, unambiguous columns, all columns) are cached
    on the instance and never invalidated.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        self.scope = scope
        self.schema = schema
        # Lazily populated caches; None means "not computed yet".
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None
        # When True, a column that can't be attributed to a single source falls back to
        # the one source with unknown columns, if exactly one such source exists.
        self._infer_schema = infer_schema

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred. If the resolved source's node
            carries an alias, that alias is returned instead of the source name.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries this source's alias.
            while node and node.alias != table_name:
                node = node.parent

        # The walk above can run off the root; fall back to the source name in that case.
        node_alias = node.args.get("alias") if node is not None else None
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: Name of a source in the scope.
            only_visible: Forwarded to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names. Where the source's alias declares column names,
            those names replace the corresponding (positional) column names.

        Raises:
            OptimizeError: if `name` is not a source of the scope.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            column_aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            column_aliases = node.alias_column_names
        else:
            column_aliases = []

        # If the source's columns are aliased, their aliases shadow the corresponding column names
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, column_aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Columns of every selected and lateral source, keyed by source name (cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name, source in itertools.chain(
                    self.scope.selected_sources.items(), self.scope.lateral_sources.items()
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column name is unambiguous only if it occurs exactly once across all sources. A
        name repeated within one source, or present in several sources, cannot be
        attributed to a single source.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        if not source_columns:
            return {}

        # Count occurrences over the concatenation of every source's columns. This makes a
        # duplicate inside any single source ambiguous, as well as a name shared by sources.
        unique = self._find_unique_columns(
            [column for columns in source_columns.values() for column in columns]
        )

        return {
            column: table
            for table, columns in source_columns.items()
            for column in columns
            if column in unique
        }

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns( expression: sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema], expand_alias_refs: bool = True, infer_schema: Optional[bool] = None) -> sqlglot.expressions.Expression:
def qualify_columns(
    expression: exp.Expression,
    schema: t.Dict | Schema,
    expand_alias_refs: bool = True,
    infer_schema: t.Optional[bool] = None,
) -> exp.Expression:
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression: Expression to qualify. Its nodes are modified in place and the same
            expression is returned.
        schema: Database schema.
        expand_alias_refs: Whether or not to expand references to aliases.
        infer_schema: Whether or not to infer the schema if missing. When None, inference
            is enabled exactly when the schema is empty.

    Returns:
        The qualified expression.
    """
    schema = ensure_schema(schema)
    if infer_schema is None:
        infer_schema = schema.empty
    pseudocolumns = Dialect.get_or_raise(schema.dialect).PSEUDOCOLUMNS

    for scope in traverse_scope(expression):
        resolver = Resolver(scope, schema, infer_schema=infer_schema)

        # Drop column lists attached to CTE / derived-table aliases, then turn USING joins
        # into ON joins (remembering which sources share each USING column).
        _pop_table_column_aliases(scope.ctes)
        _pop_table_column_aliases(scope.derived_tables)
        using_column_tables = _expand_using(scope, resolver)

        # Alias references are expanded before column qualification when there is no
        # schema, and after it when there is one.
        if schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)
        _qualify_columns(scope, resolver)
        if not schema.empty and expand_alias_refs:
            _expand_alias_refs(scope, resolver)

        # UDTF scopes skip star expansion and output aliasing.
        if not isinstance(scope.expression, exp.UDTF):
            _expand_stars(scope, resolver, using_column_tables, pseudocolumns)
            _qualify_outputs(scope)

        _expand_group_by(scope)
        _expand_order_by(scope, resolver)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression: Expression to qualify.
  • schema: Database schema.
  • expand_alias_refs: Whether or not to expand references to aliases.
  • infer_schema: Whether or not to infer the schema if missing.
Returns:

The qualified expression.

def validate_qualify_columns(expression: ~E) -> ~E:
def validate_qualify_columns(expression: E) -> E:
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are checked. A scope with external columns fails immediately, unless
    it is a correlated subquery or has pivots. Otherwise all unqualified columns are
    collected and reported together at the end.

    Returns:
        The unchanged expression, when validation passes.
    """
    ambiguous_columns = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        ambiguous_columns.extend(scope.unqualified_columns)

        if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots:
            unresolved = scope.external_columns[0]
            table_hint = f" for table: '{unresolved.table}'" if unresolved.table else ""
            raise OptimizeError(f"Column '{unresolved}' could not be resolved{table_hint}")

    if ambiguous_columns:
        raise OptimizeError(f"Ambiguous columns: {ambiguous_columns}")
    return expression

Raise an OptimizeError if any columns aren't qualified

def quote_identifiers( expression: ~E, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, identify: bool = True) -> ~E:
def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E:
    """
    Makes sure all identifiers that need to be quoted are quoted.

    Each node is transformed in place with the dialect's `quote_identifier`, and `identify`
    is forwarded to it.
    """
    quoter = Dialect.get_or_raise(dialect).quote_identifier
    return expression.transform(quoter, identify=identify, copy=False)

Makes sure all identifiers that need to be quoted are quoted.

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
        # Caller-supplied inputs.
        self.scope = scope
        self.schema = schema
        self._infer_schema = infer_schema

        # Lazily-populated caches; each is filled in on first use.
        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
        self._all_columns: t.Optional[t.Set[str]] = None

    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
        """
        Get the table for a column name.

        The column is first looked up among the unambiguous columns of this scope's sources.
        If that fails and schema inference is enabled, it is attributed to the single source
        whose column list is empty or contains "*", provided exactly one such source exists.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred (the return type is Optional).
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name and self._infer_schema:
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Climb to the ancestor whose alias is the name this source is selected under.
            while node and node.alias != table_name:
                node = node.parent

        # The climb above can run past the root without finding a match (and an entry may
        # carry no node at all); fall back to the bare source name instead of dereferencing None.
        if node is None:
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return exp.to_identifier(node_alias.this)

        return exp.to_identifier(table_name)

    @property
    def all_columns(self) -> t.Set[str]:
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
        """
        Resolve the source columns for a given source `name`.

        Args:
            name: The name of a source in this scope.
            only_visible: Passed to `Schema.column_names` when the source is a table.

        Returns:
            The source's column names, where any column aliases declared on the selected
            source replace the names they line up with positionally.

        Raises:
            OptimizeError: If `name` is not one of this scope's sources.
        """
        sources = self.scope.sources
        if name not in sources:
            raise OptimizeError(f"Unknown table: {name}")
        source = sources[name]

        if isinstance(source, exp.Table):
            columns = self.schema.column_names(source, only_visible)
        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            columns = source.expression.alias_column_names
        else:
            columns = source.expression.named_selects

        node, _ = self.scope.selected_sources.get(name) or (None, None)
        if isinstance(node, Scope):
            aliases = node.expression.alias_column_names
        elif isinstance(node, exp.Expression):
            aliases = node.alias_column_names
        else:
            aliases = []

        # Aliases shadow the column names they line up with; zip_longest keeps any surplus
        # entries from either side (the missing partner is None).
        return [
            column_alias or column_name
            for column_name, column_alias in itertools.zip_longest(columns, aliases)
        ]

    def _get_all_source_columns(self) -> t.Dict[str, t.List[str]]:
        """Map every selected and lateral source name to its columns (computed once, then cached)."""
        if self._source_columns is None:
            self._source_columns = {
                source_name: self.get_source_columns(source_name)
                for source_name in itertools.chain(
                    self.scope.selected_sources, self.scope.lateral_sources
                )
            }
        return self._source_columns

    def _get_unambiguous_columns(
        self, source_columns: t.Dict[str, t.List[str]]
    ) -> t.Dict[str, str]:
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous only if it occurs exactly once across *all* sources: a name
        repeated inside a single source (see `_find_unique_columns`) or shared between two
        sources is ambiguous.

        Args:
            source_columns: Mapping of names to source columns.

        Returns:
            Mapping of column name to source name.
        """
        # Count occurrences globally rather than comparing per-source unique sets: the latter
        # misses names duplicated inside one source that also appear once in another source.
        occurrences: t.Dict[str, int] = {}
        owner: t.Dict[str, str] = {}
        for source_name, columns in source_columns.items():
            for column in columns:
                occurrences[column] = occurrences.get(column, 0) + 1
                owner[column] = source_name

        return {column: owner[column] for column, count in occurrences.items() if count == 1}

    @staticmethod
    def _find_unique_columns(columns: t.Collection[str]) -> t.Set[str]:
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts: t.Dict[str, int] = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope: sqlglot.optimizer.scope.Scope, schema: sqlglot.schema.Schema, infer_schema: bool = True)
489    def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True):
490        self.scope = scope
491        self.schema = schema
492        self._source_columns: t.Optional[t.Dict[str, t.List[str]]] = None
493        self._unambiguous_columns: t.Optional[t.Dict[str, str]] = None
494        self._all_columns: t.Optional[t.Set[str]] = None
495        self._infer_schema = infer_schema
scope
schema
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
497    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
498        """
499        Get the table for a column name.
500
501        Args:
502            column_name: The column name to find the table for.
503        Returns:
504            The table name if it can be found/inferred.
505        """
506        if self._unambiguous_columns is None:
507            self._unambiguous_columns = self._get_unambiguous_columns(
508                self._get_all_source_columns()
509            )
510
511        table_name = self._unambiguous_columns.get(column_name)
512
513        if not table_name and self._infer_schema:
514            sources_without_schema = tuple(
515                source
516                for source, columns in self._get_all_source_columns().items()
517                if not columns or "*" in columns
518            )
519            if len(sources_without_schema) == 1:
520                table_name = sources_without_schema[0]
521
522        if table_name not in self.scope.selected_sources:
523            return exp.to_identifier(table_name)
524
525        node, _ = self.scope.selected_sources.get(table_name)
526
527        if isinstance(node, exp.Subqueryable):
528            while node and node.alias != table_name:
529                node = node.parent
530
531        node_alias = node.args.get("alias")
532        if node_alias:
533            return exp.to_identifier(node_alias.this)
534
535        return exp.to_identifier(table_name)

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns: Set[str]

All available columns of all sources in this scope

def get_source_columns(self, name: str, only_visible: bool = False) -> List[str]:
546    def get_source_columns(self, name: str, only_visible: bool = False) -> t.List[str]:
547        """Resolve the source columns for a given source `name`."""
548        if name not in self.scope.sources:
549            raise OptimizeError(f"Unknown table: {name}")
550
551        source = self.scope.sources[name]
552
553        if isinstance(source, exp.Table):
554            columns = self.schema.column_names(source, only_visible)
555        elif isinstance(source, Scope) and isinstance(source.expression, exp.Values):
556            columns = source.expression.alias_column_names
557        else:
558            columns = source.expression.named_selects
559
560        node, _ = self.scope.selected_sources.get(name) or (None, None)
561        if isinstance(node, Scope):
562            column_aliases = node.expression.alias_column_names
563        elif isinstance(node, exp.Expression):
564            column_aliases = node.alias_column_names
565        else:
566            column_aliases = []
567
568        # If the source's columns are aliased, their aliases shadow the corresponding column names
569        return [alias or name for (name, alias) in itertools.zip_longest(columns, column_aliases)]

Resolve the source columns for a given source name.