Edit on GitHub

sqlglot.optimizer.qualify_columns

  1import itertools
  2import typing as t
  3
  4from sqlglot import alias, exp
  5from sqlglot.errors import OptimizeError
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9
 10def qualify_columns(expression, schema):
 11    """
 12    Rewrite sqlglot AST to have fully qualified columns.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> schema = {"tbl": {"col": "INT"}}
 17        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 18        >>> qualify_columns(expression, schema).sql()
 19        'SELECT tbl.col AS col FROM tbl'
 20
 21    Args:
 22        expression (sqlglot.Expression): expression to qualify
 23        schema (dict|sqlglot.optimizer.Schema): Database schema
 24    Returns:
 25        sqlglot.Expression: qualified expression
 26    """
 27    schema = ensure_schema(schema)
 28
 29    for scope in traverse_scope(expression):
 30        resolver = Resolver(scope, schema)
 31        _pop_table_column_aliases(scope.ctes)
 32        _pop_table_column_aliases(scope.derived_tables)
 33        using_column_tables = _expand_using(scope, resolver)
 34        _qualify_columns(scope, resolver)
 35        if not isinstance(scope.expression, exp.UDTF):
 36            _expand_stars(scope, resolver, using_column_tables)
 37            _qualify_outputs(scope)
 38        _expand_alias_refs(scope, resolver)
 39        _expand_group_by(scope, resolver)
 40        _expand_order_by(scope)
 41
 42    return expression
 43
 44
 45def validate_qualify_columns(expression):
 46    """Raise an `OptimizeError` if any columns aren't qualified"""
 47    unqualified_columns = []
 48    for scope in traverse_scope(expression):
 49        if isinstance(scope.expression, exp.Select):
 50            unqualified_columns.extend(scope.unqualified_columns)
 51            if scope.external_columns and not scope.is_correlated_subquery:
 52                column = scope.external_columns[0]
 53                raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")
 54
 55    if unqualified_columns:
 56        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 57    return expression
 58
 59
 60def _pop_table_column_aliases(derived_tables):
 61    """
 62    Remove table column aliases.
 63
 64    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 65    """
 66    for derived_table in derived_tables:
 67        table_alias = derived_table.args.get("alias")
 68        if table_alias:
 69            table_alias.args.pop("columns", None)
 70
 71
 72def _expand_using(scope, resolver):
 73    joins = list(scope.find_all(exp.Join))
 74    names = {join.this.alias for join in joins}
 75    ordered = [key for key in scope.selected_sources if key not in names]
 76
 77    # Mapping of automatically joined column names to an ordered set of source names (dict).
 78    column_tables = {}
 79
 80    for join in joins:
 81        using = join.args.get("using")
 82
 83        if not using:
 84            continue
 85
 86        join_table = join.this.alias_or_name
 87
 88        columns = {}
 89
 90        for k in scope.selected_sources:
 91            if k in ordered:
 92                for column in resolver.get_source_columns(k):
 93                    if column not in columns:
 94                        columns[column] = k
 95
 96        ordered.append(join_table)
 97        join_columns = resolver.get_source_columns(join_table)
 98        conditions = []
 99
100        for identifier in using:
101            identifier = identifier.name
102            table = columns.get(identifier)
103
104            if not table or identifier not in join_columns:
105                raise OptimizeError(f"Cannot automatically join: {identifier}")
106
107            conditions.append(
108                exp.condition(
109                    exp.EQ(
110                        this=exp.column(identifier, table=table),
111                        expression=exp.column(identifier, table=join_table),
112                    )
113                )
114            )
115
116            # Set all values in the dict to None, because we only care about the key ordering
117            tables = column_tables.setdefault(identifier, {})
118            if table not in tables:
119                tables[table] = None
120            if join_table not in tables:
121                tables[join_table] = None
122
123        join.args.pop("using")
124        join.set("on", exp.and_(*conditions))
125
126    if column_tables:
127        for column in scope.columns:
128            if not column.table and column.name in column_tables:
129                tables = column_tables[column.name]
130                coalesce = [exp.column(column.name, table=table) for table in tables]
131                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
132
133                # Ensure selects keep their output name
134                if isinstance(column.parent, exp.Select):
135                    replacement = exp.alias_(replacement, alias=column.name)
136
137                scope.replace(column, replacement)
138
139    return column_tables
140
141
def _expand_alias_refs(scope, resolver):
    """
    Expand references to select aliases inside the WHERE and GROUP BY clauses.

    An unqualified column that the resolver can attribute to a source gets that
    table set on it; otherwise, if its name matches a select's alias or name, it
    is replaced by a copy of that select's expression.

    Args:
        scope: scope whose "where" and "group" args are transformed in place.
        resolver: `Resolver` used to look up the table of a column name.
    """
    # Lazily filled mapping of output name -> select expression.
    outputs = {}

    def transform(node, *_):
        if not isinstance(node, exp.Column) or node.table:
            return node

        # Source columns get priority over select aliases
        source_table = resolver.get_table(node.name)
        if source_table:
            node.set("table", source_table)
            return node

        if not outputs:
            outputs.update((s.alias_or_name, s) for s in scope.selects)

        target = outputs.get(node.name)
        if not target:
            return node

        scope.clear_cache()
        if isinstance(target, exp.Alias):
            target = target.this
        return target.copy()

    for clause_name in ("where", "group"):
        clause = scope.expression.args.get(clause_name)
        if clause:
            clause.transform(transform, copy=False)
175
176
def _expand_group_by(scope, resolver):
    """
    Replace positional references in the scope's GROUP BY with select expressions.

    Args:
        scope: scope whose "group" arg is rewritten in place; nothing happens if
            the expression has no "group" arg.
        resolver: not used by this function.
    """
    group = scope.expression.args.get("group")
    if group:
        expanded = _expand_positional_references(scope, group.expressions)
        group.set("expressions", expanded)
        scope.expression.set("group", group)
184
185
def _expand_order_by(scope):
    """
    Replace positional references in the scope's ORDER BY with select expressions.

    Args:
        scope: scope whose "order" arg is rewritten in place; nothing happens if
            the expression has no "order" arg.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    terms = order.expressions
    # The helper returns one replacement per ordered term, in the same order.
    replacements = _expand_positional_references(scope, (term.this for term in terms))

    for term, replacement in zip(terms, replacements):
        term.set("this", replacement)
197
198
def _expand_positional_references(scope, expressions):
    """
    Replace integer positional references with copies of the matching selects.

    Nodes whose `is_int` is false are passed through unchanged. An integer node is
    treated as a 1-based position into `scope.selects`; if the select at that
    position is an `exp.Alias`, the aliased expression is used.

    Args:
        scope: scope providing `selects`; its cache is cleared after each expansion.
        expressions: iterable of nodes to expand.
    Returns:
        list: the expanded nodes, in input order.
    Raises:
        OptimizeError: if a position is lower than 1 or greater than the number of
            selects.
    """
    new_nodes = []
    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        position = int(node.name)
        selects = scope.selects

        # Positions are 1-based. Reject out-of-range values explicitly: position 0
        # would otherwise become index -1 and silently pick the last select.
        if not 1 <= position <= len(selects):
            raise OptimizeError(f"Unknown output column: {node.name}")

        select = selects[position - 1]
        if isinstance(select, exp.Alias):
            select = select.this
        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
215
216
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Args:
        scope: scope whose columns are qualified in place.
        resolver: `Resolver` used to infer tables and look up source columns.
    Raises:
        OptimizeError: if a column names a known source that has a known column
            list which contains neither the column nor "*".
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if not table_name:
            inferred = resolver.get_table(name)

            # the inferred table can be a '' because bigquery unnest has no table alias
            if inferred:
                column.set("table", inferred)
            continue

        if table_name in scope.sources:
            known = resolver.get_source_columns(table_name)
            if known and name not in known and "*" not in known:
                raise OptimizeError(f"Unknown column: {name}")
            continue

        # structs are used like tables (e.g. "struct"."field"), so they need to be qualified
        # separately and represented as dot(dot(...(<table>.<column>, field1), field2, ...))
        root, *parts = column.parts

        if root.name in scope.sources:
            # struct is already qualified, but we still need to change the AST representation
            struct_table = root
            root, *parts = parts
        else:
            struct_table = resolver.get_table(root.name)

        if struct_table:
            column.replace(exp.Dot.build([exp.column(root, table=struct_table), *parts]))

    # Unqualified columns from ORDER BY / HAVING that still need a table.
    pending = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    order = scope.expression.args.get("order")
    if order:
        for ordered in order.expressions:
            pending.extend(
                column
                for column in ordered.find_all(exp.Column)
                if not column.table
                and column.parent is not ordered
                and column.name in resolver.all_columns
            )

    # Determine whether each reference in the having clause is to a column or an alias.
    having = scope.expression.args.get("having")
    if having:
        pending.extend(
            column
            for column in having.find_all(exp.Column)
            if not column.table
            and column.find_ancestor(exp.AggFunc)
            and column.name in resolver.all_columns
        )

    for column in pending:
        table = resolver.get_table(column.name)
        if table:
            column.set("table", table)
282
283
def _expand_stars(scope, resolver, using_column_tables):
    """
    Expand stars to lists of column selections.

    Args:
        scope: scope whose select expressions are rewritten in place.
        resolver: `Resolver` used to look up the visible columns of each source.
        using_column_tables: mapping of USING column name to the sources it was
            joined across, as returned by `_expand_using`.
    Raises:
        OptimizeError: if a star refers to a table that is not a source of the scope.

    If any starred source has no known columns (or lists "*"), the function
    returns early and the selections are left untouched.
    """
    expanded = []
    excluded_by_table = {}
    renamed_by_table = {}
    already_coalesced = set()

    for selection in scope.selects:
        if isinstance(selection, exp.Star):
            star, star_tables = selection, list(scope.selected_sources)
        elif selection.is_star:
            star, star_tables = selection.this, [selection.table]
        else:
            expanded.append(selection)
            continue

        _add_except_columns(star, star_tables, excluded_by_table)
        _add_replace_columns(star, star_tables, renamed_by_table)

        for table in star_tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            visible = resolver.get_source_columns(table, only_visible=True)
            if not visible or "*" in visible:
                return

            # The except/replace maps are keyed by the identity of the table entry.
            key = id(table)

            for name in visible:
                if name in using_column_tables and table in using_column_tables[name]:
                    # A USING column is emitted once, as a COALESCE over its tables.
                    if name not in already_coalesced:
                        already_coalesced.add(name)
                        members = [
                            exp.column(name, table=member)
                            for member in using_column_tables[name]
                        ]
                        expanded.append(
                            exp.alias_(
                                exp.Coalesce(this=members[0], expressions=members[1:]),
                                alias=name,
                            )
                        )
                    continue

                if name in excluded_by_table.get(key, set()):
                    continue

                output_name = renamed_by_table.get(key, {}).get(name, name)
                column = exp.column(name, table)
                expanded.append(column if output_name == name else alias(column, output_name))

    scope.expression.set("expressions", expanded)
333
334
def _add_except_columns(expression, tables, except_columns):
    """
    Record the names listed in a star's "except" arg for each of `tables`.

    Args:
        expression: node whose `args` may hold an "except" list of named nodes.
        tables: table entries the exclusion applies to.
        except_columns: dict updated in place; keyed by `id(table)`, the value is
            the set of excluded names (one set shared by all `tables`).
    """
    excluded = expression.args.get("except")
    if excluded:
        names = {item.name for item in excluded}
        except_columns.update(dict.fromkeys(map(id, tables), names))
345
346
def _add_replace_columns(expression, tables, replace_columns):
    """
    Record the entries of a star's "replace" arg for each of `tables`.

    Args:
        expression: node whose `args` may hold a "replace" list; each entry must
            expose `this.name` and `alias`.
        tables: table entries the replacement applies to.
        replace_columns: dict updated in place; keyed by `id(table)`, the value maps
            each entry's `this.name` to its `alias` (one dict shared by all `tables`).
    """
    replacements = expression.args.get("replace")
    if replacements:
        renames = {item.this.name: item.alias for item in replacements}
        replace_columns.update(dict.fromkeys(map(id, tables), renames))
357
358
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Selections are paired positionally with `scope.outer_column_list`; when an
    outer column name is present for a position it becomes the selection's alias.
    Selections without an output name fall back to `_col_<position>`.

    Args:
        scope: scope whose select expressions are replaced in place.
    """
    qualified = []
    paired = itertools.zip_longest(scope.selects, scope.outer_column_list)

    for position, (selection, outer_name) in enumerate(paired):
        fallback = f"_col_{position}"

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(fallback)))
        elif not (isinstance(selection, exp.Alias) or selection.is_star):
            # Build the alias around a placeholder first, then swap in the selection.
            wrapper = alias(exp.column(""), alias=selection.output_name or fallback)
            wrapper.set("this", selection)
            selection = wrapper

        if outer_name:
            selection.set("alias", exp.to_identifier(outer_name))

        qualified.append(selection)

    scope.expression.set("expressions", qualified)
380
381
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet".
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> "t.Optional[exp.Identifier]":
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # Fall back to the only source whose columns are unknown, if there is exactly one.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        if node is None:
            # No ancestor carried the alias; fall back to a plain identifier
            # instead of failing on `None.args`.
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once in exactly one source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}

        # Track every name seen so far, including names duplicated inside the first
        # source: those are ambiguous and must not be claimed by a later source.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}
def qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Each scope produced by `traverse_scope` is processed with its own `Resolver`;
    the same `expression` object that was passed in is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    normalized_schema = ensure_schema(schema)

    for current in traverse_scope(expression):
        column_resolver = Resolver(current, normalized_schema)

        # Column aliases on CTEs and derived tables are dropped before resolution.
        _pop_table_column_aliases(current.ctes)
        _pop_table_column_aliases(current.derived_tables)

        using_tables = _expand_using(current, column_resolver)
        _qualify_columns(current, column_resolver)

        # Star expansion and output aliasing are skipped for UDTF scopes.
        is_udtf = isinstance(current.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current, column_resolver, using_tables)
            _qualify_outputs(current)

        _expand_alias_refs(current, column_resolver)
        _expand_group_by(current, column_resolver)
        _expand_order_by(current)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is an `exp.Select` are inspected.

    Args:
        expression: the expression to check.
    Returns:
        The same `expression`, when no problem was found.
    Raises:
        OptimizeError: as soon as a non-correlated scope has external columns, or
            after the traversal if any unqualified columns were collected.
    """
    pending = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        pending.extend(scope.unqualified_columns)

        external = scope.external_columns
        if external and not scope.is_correlated_subquery:
            column = external[0]
            raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'")

    if pending:
        raise OptimizeError(f"Ambiguous columns: {pending}")

    return expression

Raise an OptimizeError if any columns aren't qualified

class Resolver:
class Resolver:
    """
    Helper for resolving columns.

    This is a class so we can lazily load some things and easily share them across functions.
    """

    def __init__(self, scope, schema):
        self.scope = scope
        self.schema = schema
        # Lazily computed caches; None means "not computed yet".
        self._source_columns = None
        self._unambiguous_columns = None
        self._all_columns = None

    def get_table(self, column_name: str) -> "t.Optional[exp.Identifier]":
        """
        Get the table for a column name.

        Args:
            column_name: The column name to find the table for.
        Returns:
            The table name if it can be found/inferred.
        """
        if self._unambiguous_columns is None:
            self._unambiguous_columns = self._get_unambiguous_columns(
                self._get_all_source_columns()
            )

        table_name = self._unambiguous_columns.get(column_name)

        if not table_name:
            # Fall back to the only source whose columns are unknown, if there is exactly one.
            sources_without_schema = tuple(
                source
                for source, columns in self._get_all_source_columns().items()
                if not columns or "*" in columns
            )
            if len(sources_without_schema) == 1:
                table_name = sources_without_schema[0]

        if table_name not in self.scope.selected_sources:
            return exp.to_identifier(table_name)

        node, _ = self.scope.selected_sources.get(table_name)

        if isinstance(node, exp.Subqueryable):
            # Walk up to the ancestor that carries the alias this source is known by.
            while node and node.alias != table_name:
                node = node.parent

        if node is None:
            # No ancestor carried the alias; fall back to a plain identifier
            # instead of failing on `None.args`.
            return exp.to_identifier(table_name)

        node_alias = node.args.get("alias")
        if node_alias:
            return node_alias.this

        return exp.to_identifier(
            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
        )

    @property
    def all_columns(self):
        """All available columns of all sources in this scope"""
        if self._all_columns is None:
            self._all_columns = {
                column for columns in self._get_all_source_columns().values() for column in columns
            }
        return self._all_columns

    def get_source_columns(self, name, only_visible=False):
        """
        Resolve the source columns for a given source `name`.

        Raises:
            OptimizeError: if `name` is not one of the scope's sources.
        """
        if name not in self.scope.sources:
            raise OptimizeError(f"Unknown table: {name}")

        source = self.scope.sources[name]

        # If referencing a table, return the columns from the schema
        if isinstance(source, exp.Table):
            return self.schema.column_names(source, only_visible)

        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
            return source.expression.alias_column_names

        # Otherwise, if referencing another scope, return that scope's named selects
        return source.expression.named_selects

    def _get_all_source_columns(self):
        """Return (and cache) the columns of every selected and lateral source."""
        if self._source_columns is None:
            self._source_columns = {
                k: self.get_source_columns(k)
                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
            }
        return self._source_columns

    def _get_unambiguous_columns(self, source_columns):
        """
        Find all the unambiguous columns in sources.

        A column is unambiguous when it appears exactly once in exactly one source.

        Args:
            source_columns (dict): Mapping of names to source columns
        Returns:
            dict: Mapping of column name to source name
        """
        if not source_columns:
            return {}

        source_columns = list(source_columns.items())

        first_table, first_columns = source_columns[0]
        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}

        # Track every name seen so far, including names duplicated inside the first
        # source: those are ambiguous and must not be claimed by a later source.
        all_columns = set(first_columns)

        for table, columns in source_columns[1:]:
            unique = self._find_unique_columns(columns)
            ambiguous = all_columns.intersection(unique)
            all_columns.update(columns)
            for column in ambiguous:
                unambiguous_columns.pop(column, None)
            for column in unique.difference(ambiguous):
                unambiguous_columns[column] = table

        return unambiguous_columns

    @staticmethod
    def _find_unique_columns(columns):
        """
        Find the unique columns in a list of columns.

        Example:
            >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"]))
            ['a', 'c']

        This is necessary because duplicate column names are ambiguous.
        """
        counts = {}
        for column in columns:
            counts[column] = counts.get(column, 0) + 1
        return {column for column, count in counts.items() if count == 1}

Helper for resolving columns.

This is a class so we can lazily load some things and easily share them across functions.

Resolver(scope, schema)
390    def __init__(self, scope, schema):
391        self.scope = scope
392        self.schema = schema
393        self._source_columns = None
394        self._unambiguous_columns = None
395        self._all_columns = None
def get_table(self, column_name: str) -> Optional[sqlglot.expressions.Identifier]:
397    def get_table(self, column_name: str) -> t.Optional[exp.Identifier]:
398        """
399        Get the table for a column name.
400
401        Args:
402            column_name: The column name to find the table for.
403        Returns:
404            The table name if it can be found/inferred.
405        """
406        if self._unambiguous_columns is None:
407            self._unambiguous_columns = self._get_unambiguous_columns(
408                self._get_all_source_columns()
409            )
410
411        table_name = self._unambiguous_columns.get(column_name)
412
413        if not table_name:
414            sources_without_schema = tuple(
415                source
416                for source, columns in self._get_all_source_columns().items()
417                if not columns or "*" in columns
418            )
419            if len(sources_without_schema) == 1:
420                table_name = sources_without_schema[0]
421
422        if table_name not in self.scope.selected_sources:
423            return exp.to_identifier(table_name)
424
425        node, _ = self.scope.selected_sources.get(table_name)
426
427        if isinstance(node, exp.Subqueryable):
428            while node and node.alias != table_name:
429                node = node.parent
430
431        node_alias = node.args.get("alias")
432        if node_alias:
433            return node_alias.this
434
435        return exp.to_identifier(
436            table_name, quoted=node.this.quoted if isinstance(node, exp.Table) else None
437        )

Get the table for a column name.

Arguments:
  • column_name: The column name to find the table for.
Returns:

The table name if it can be found/inferred.

all_columns

All available columns of all sources in this scope

def get_source_columns(self, name, only_visible=False):
448    def get_source_columns(self, name, only_visible=False):
449        """Resolve the source columns for a given source `name`"""
450        if name not in self.scope.sources:
451            raise OptimizeError(f"Unknown table: {name}")
452
453        source = self.scope.sources[name]
454
455        # If referencing a table, return the columns from the schema
456        if isinstance(source, exp.Table):
457            return self.schema.column_names(source, only_visible)
458
459        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
460            return source.expression.alias_column_names
461
462        # Otherwise, if referencing another scope, return that scope's named selects
463        return source.expression.named_selects

Resolve the source columns for a given source name