Edit on GitHub

sqlglot.optimizer.qualify_columns

  1import itertools
  2import typing as t
  3
  4from sqlglot import alias, exp
  5from sqlglot.errors import OptimizeError
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9
 10def qualify_columns(expression, schema):
 11    """
 12    Rewrite sqlglot AST to have fully qualified columns.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> schema = {"tbl": {"col": "INT"}}
 17        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 18        >>> qualify_columns(expression, schema).sql()
 19        'SELECT tbl.col AS col FROM tbl'
 20
 21    Args:
 22        expression (sqlglot.Expression): expression to qualify
 23        schema (dict|sqlglot.optimizer.Schema): Database schema
 24    Returns:
 25        sqlglot.Expression: qualified expression
 26    """
 27    schema = ensure_schema(schema)
 28
 29    for scope in traverse_scope(expression):
 30        resolver = _Resolver(scope, schema)
 31        _pop_table_column_aliases(scope.ctes)
 32        _pop_table_column_aliases(scope.derived_tables)
 33        _expand_using(scope, resolver)
 34        _expand_group_by(scope, resolver)
 35        _qualify_columns(scope, resolver)
 36        _expand_order_by(scope)
 37        if not isinstance(scope.expression, exp.UDTF):
 38            _expand_stars(scope, resolver)
 39            _qualify_outputs(scope)
 40
 41    return expression
 42
 43
 44def validate_qualify_columns(expression):
 45    """Raise an `OptimizeError` if any columns aren't qualified"""
 46    unqualified_columns = []
 47    for scope in traverse_scope(expression):
 48        if isinstance(scope.expression, exp.Select):
 49            unqualified_columns.extend(scope.unqualified_columns)
 50            if scope.external_columns and not scope.is_correlated_subquery:
 51                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
 52
 53    if unqualified_columns:
 54        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 55    return expression
 56
 57
 58def _pop_table_column_aliases(derived_tables):
 59    """
 60    Remove table column aliases.
 61
 62    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 63    """
 64    for derived_table in derived_tables:
 65        if isinstance(derived_table.unnest(), exp.UDTF):
 66            continue
 67        table_alias = derived_table.args.get("alias")
 68        if table_alias:
 69            table_alias.args.pop("columns", None)
 70
 71
 72def _expand_using(scope, resolver):
 73    joins = list(scope.expression.find_all(exp.Join))
 74    names = {join.this.alias for join in joins}
 75    ordered = [key for key in scope.selected_sources if key not in names]
 76
 77    # Mapping of automatically joined column names to source names
 78    column_tables = {}
 79
 80    for join in joins:
 81        using = join.args.get("using")
 82
 83        if not using:
 84            continue
 85
 86        join_table = join.this.alias_or_name
 87
 88        columns = {}
 89
 90        for k in scope.selected_sources:
 91            if k in ordered:
 92                for column in resolver.get_source_columns(k):
 93                    if column not in columns:
 94                        columns[column] = k
 95
 96        ordered.append(join_table)
 97        join_columns = resolver.get_source_columns(join_table)
 98        conditions = []
 99
100        for identifier in using:
101            identifier = identifier.name
102            table = columns.get(identifier)
103
104            if not table or identifier not in join_columns:
105                raise OptimizeError(f"Cannot automatically join: {identifier}")
106
107            conditions.append(
108                exp.condition(
109                    exp.EQ(
110                        this=exp.column(identifier, table=table),
111                        expression=exp.column(identifier, table=join_table),
112                    )
113                )
114            )
115
116            tables = column_tables.setdefault(identifier, [])
117            if table not in tables:
118                tables.append(table)
119            if join_table not in tables:
120                tables.append(join_table)
121
122        join.args.pop("using")
123        join.set("on", exp.and_(*conditions))
124
125    if column_tables:
126        for column in scope.columns:
127            if not column.table and column.name in column_tables:
128                tables = column_tables[column.name]
129                coalesce = [exp.column(column.name, table=table) for table in tables]
130                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
131
132                # Ensure selects keep their output name
133                if isinstance(column.parent, exp.Select):
134                    replacement = exp.alias_(replacement, alias=column.name)
135
136                scope.replace(column, replacement)
137
138
139def _expand_group_by(scope, resolver):
140    group = scope.expression.args.get("group")
141    if not group:
142        return
143
144    # Replace references to select aliases
145    def transform(node, *_):
146        if isinstance(node, exp.Column) and not node.table:
147            table = resolver.get_table(node.name)
148
149            # Source columns get priority over select aliases
150            if table:
151                node.set("table", exp.to_identifier(table))
152                return node
153
154            selects = {s.alias_or_name: s for s in scope.selects}
155
156            select = selects.get(node.name)
157            if select:
158                scope.clear_cache()
159                if isinstance(select, exp.Alias):
160                    select = select.this
161                return select.copy()
162
163        return node
164
165    group.transform(transform, copy=False)
166    group.set("expressions", _expand_positional_references(scope, group.expressions))
167    scope.expression.set("group", group)
168
169
170def _expand_order_by(scope):
171    order = scope.expression.args.get("order")
172    if not order:
173        return
174
175    ordereds = order.expressions
176    for ordered, new_expression in zip(
177        ordereds,
178        _expand_positional_references(scope, (o.this for o in ordereds)),
179    ):
180        ordered.set("this", new_expression)
181
182
183def _expand_positional_references(scope, expressions):
184    new_nodes = []
185    for node in expressions:
186        if node.is_int:
187            try:
188                select = scope.selects[int(node.name) - 1]
189            except IndexError:
190                raise OptimizeError(f"Unknown output column: {node.name}")
191            if isinstance(select, exp.Alias):
192                select = select.this
193            new_nodes.append(select.copy())
194            scope.clear_cache()
195        else:
196            new_nodes.append(node)
197
198    return new_nodes
199
200
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    Columns already qualified with a known source are validated against that
    source's column list (when one is available); unqualified columns get the
    table the resolver infers for them. A second pass handles unqualified
    references inside ORDER BY and HAVING that name a real source column.

    Raises:
        OptimizeError: if a qualified column does not exist in its source.
    """
    for column in scope.columns:
        table_name = column.table
        name = column.name

        if table_name and table_name in scope.sources:
            known = resolver.get_source_columns(table_name)
            if known and name not in known:
                raise OptimizeError(f"Unknown column: {name}")

        if table_name:
            continue

        inferred = resolver.get_table(name)

        # the inferred table can be a '' because bigquery unnest has no table alias
        if inferred:
            column.set("table", exp.to_identifier(inferred))

    pending = []

    # Determine whether each reference in the order by clause is to a column or an alias.
    for ordered in scope.find_all(exp.Ordered):
        pending.extend(
            column
            for column in ordered.find_all(exp.Column)
            if not column.table
            and column.parent is not ordered
            and column.name in resolver.all_columns
        )

    # Determine whether each reference in the having clause is to a column or an alias.
    for having in scope.find_all(exp.Having):
        pending.extend(
            column
            for column in having.find_all(exp.Column)
            if not column.table
            and column.find_ancestor(exp.AggFunc)
            and column.name in resolver.all_columns
        )

    for column in pending:
        inferred = resolver.get_table(column.name)
        if inferred:
            column.set("table", exp.to_identifier(inferred))
245
246
def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections.

    Handles both a bare `*` (all selected sources) and `tbl.*` (a single
    source), honouring any EXCEPT / REPLACE modifiers recorded for the star.

    Raises:
        OptimizeError: if a star refers to an unknown table, or to a table
            whose columns cannot be determined.
    """
    expanded = []
    excluded_by_table = {}
    renamed_by_table = {}

    for selection in scope.selects:
        if isinstance(selection, exp.Star):
            star = selection
            tables = list(scope.selected_sources)
        elif isinstance(selection, exp.Column) and isinstance(selection.this, exp.Star):
            star = selection.this
            tables = [selection.table]
        else:
            # Not a star: keep the projection as is.
            expanded.append(selection)
            continue

        _add_except_columns(star, tables, excluded_by_table)
        _add_replace_columns(star, tables, renamed_by_table)

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")

            names = resolver.get_source_columns(table, only_visible=True)
            if not names:
                raise OptimizeError(
                    f"Table has no schema/columns. Cannot expand star for table: {table}."
                )

            # Modifiers are registered under the identity of the table entry.
            key = id(table)
            for name in names:
                if name in excluded_by_table.get(key, set()):
                    continue
                output_name = renamed_by_table.get(key, {}).get(name, name)
                column = exp.column(name, table)
                expanded.append(column if output_name == name else alias(column, output_name))

    scope.expression.set("expressions", expanded)
283
284
285def _add_except_columns(expression, tables, except_columns):
286    except_ = expression.args.get("except")
287
288    if not except_:
289        return
290
291    columns = {e.name for e in except_}
292
293    for table in tables:
294        except_columns[id(table)] = columns
295
296
297def _add_replace_columns(expression, tables, replace_columns):
298    replace = expression.args.get("replace")
299
300    if not replace:
301        return
302
303    columns = {e.this.name: e.alias for e in replace}
304
305    for table in tables:
306        replace_columns[id(table)] = columns
307
308
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased.

    Projections without an alias get one derived from their output name, or
    `_col_<index>` when they have none. When the scope exposes outer column
    names (`scope.outer_column_list`), those override the projection aliases
    position by position.

    Args:
        scope: the scope whose select expressions are rewritten in place.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        # zip_longest pads with None once the projections run out, i.e. when more
        # outer column names than projections were supplied. There is nothing left
        # to alias, and continuing would fail on `None.output_name`.
        if selection is None:
            break

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias):
            # Build the alias around a placeholder column, then swap in the real
            # projection node.
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
330
331
332class _Resolver:
333    """
334    Helper for resolving columns.
335
336    This is a class so we can lazily load some things and easily share them across functions.
337    """
338
339    def __init__(self, scope, schema):
340        self.scope = scope
341        self.schema = schema
342        self._source_columns = None
343        self._unambiguous_columns = None
344        self._all_columns = None
345
346    def get_table(self, column_name: str) -> t.Optional[str]:
347        """
348        Get the table for a column name.
349
350        Args:
351            column_name: The column name to find the table for.
352        Returns:
353            The table name if it can be found/inferred.
354        """
355        if self._unambiguous_columns is None:
356            self._unambiguous_columns = self._get_unambiguous_columns(
357                self._get_all_source_columns()
358            )
359
360        table = self._unambiguous_columns.get(column_name)
361
362        if not table:
363            sources_without_schema = tuple(
364                source for source, columns in self._get_all_source_columns().items() if not columns
365            )
366            if len(sources_without_schema) == 1:
367                return sources_without_schema[0]
368
369        return table
370
371    @property
372    def all_columns(self):
373        """All available columns of all sources in this scope"""
374        if self._all_columns is None:
375            self._all_columns = {
376                column for columns in self._get_all_source_columns().values() for column in columns
377            }
378        return self._all_columns
379
380    def get_source_columns(self, name, only_visible=False):
381        """Resolve the source columns for a given source `name`"""
382        if name not in self.scope.sources:
383            raise OptimizeError(f"Unknown table: {name}")
384
385        source = self.scope.sources[name]
386
387        # If referencing a table, return the columns from the schema
388        if isinstance(source, exp.Table):
389            return self.schema.column_names(source, only_visible)
390
391        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
392            return source.expression.alias_column_names
393
394        # Otherwise, if referencing another scope, return that scope's named selects
395        return source.expression.named_selects
396
397    def _get_all_source_columns(self):
398        if self._source_columns is None:
399            self._source_columns = {
400                k: self.get_source_columns(k) for k in self.scope.selected_sources
401            }
402        return self._source_columns
403
404    def _get_unambiguous_columns(self, source_columns):
405        """
406        Find all the unambiguous columns in sources.
407
408        Args:
409            source_columns (dict): Mapping of names to source columns
410        Returns:
411            dict: Mapping of column name to source name
412        """
413        if not source_columns:
414            return {}
415
416        source_columns = list(source_columns.items())
417
418        first_table, first_columns = source_columns[0]
419        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
420        all_columns = set(unambiguous_columns)
421
422        for table, columns in source_columns[1:]:
423            unique = self._find_unique_columns(columns)
424            ambiguous = set(all_columns).intersection(unique)
425            all_columns.update(columns)
426            for column in ambiguous:
427                unambiguous_columns.pop(column, None)
428            for column in unique.difference(ambiguous):
429                unambiguous_columns[column] = table
430
431        return unambiguous_columns
432
433    @staticmethod
434    def _find_unique_columns(columns):
435        """
436        Find the unique columns in a list of columns.
437
438        Example:
439            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
440            ['a', 'c']
441
442        This is necessary because duplicate column names are ambiguous.
443        """
444        counts = {}
445        for column in columns:
446            counts[column] = counts.get(column, 0) + 1
447        return {column for column, count in counts.items() if count == 1}
def qualify_columns(expression, schema):
def qualify_columns(expression, schema):
    """
    Qualify every column of a sqlglot AST with the source it belongs to.

    The expression is rewritten in place, scope by scope, and the same object
    is returned.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = _Resolver(current_scope, normalized_schema)

        # Table-level column alias lists are dropped before anything is resolved.
        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        _expand_using(current_scope, column_resolver)
        _expand_group_by(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)
        _expand_order_by(current_scope)

        # UDTF scopes get neither star expansion nor output aliasing.
        if isinstance(current_scope.expression, exp.UDTF):
            continue

        _expand_stars(current_scope, column_resolver)
        _qualify_outputs(current_scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • schema (dict|sqlglot.optimizer.Schema): Database schema
Returns:

sqlglot.Expression: qualified expression

def validate_qualify_columns(expression):
def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only scopes whose expression is a SELECT are inspected. A scope that has
    external columns without being a correlated subquery fails immediately with
    "Unknown table"; leftover unqualified columns are collected across all
    scopes and reported together as "Ambiguous columns". The expression itself
    is returned unchanged when validation passes.
    """
    still_unqualified = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        still_unqualified.extend(scope.unqualified_columns)

        external = scope.external_columns
        if external and not scope.is_correlated_subquery:
            raise OptimizeError(f"Unknown table: {external[0].table}")

    if still_unqualified:
        raise OptimizeError(f"Ambiguous columns: {still_unqualified}")

    return expression

Raise an OptimizeError if any columns aren't qualified