Edit on GitHub

sqlglot.optimizer.qualify_columns

  1import itertools
  2import typing as t
  3
  4from sqlglot import alias, exp
  5from sqlglot.errors import OptimizeError
  6from sqlglot.optimizer.scope import Scope, traverse_scope
  7from sqlglot.schema import ensure_schema
  8
  9
 10def qualify_columns(expression, schema):
 11    """
 12    Rewrite sqlglot AST to have fully qualified columns.
 13
 14    Example:
 15        >>> import sqlglot
 16        >>> schema = {"tbl": {"col": "INT"}}
 17        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
 18        >>> qualify_columns(expression, schema).sql()
 19        'SELECT tbl.col AS col FROM tbl'
 20
 21    Args:
 22        expression (sqlglot.Expression): expression to qualify
 23        schema (dict|sqlglot.optimizer.Schema): Database schema
 24    Returns:
 25        sqlglot.Expression: qualified expression
 26    """
 27    schema = ensure_schema(schema)
 28
 29    for scope in traverse_scope(expression):
 30        resolver = _Resolver(scope, schema)
 31        _pop_table_column_aliases(scope.ctes)
 32        _pop_table_column_aliases(scope.derived_tables)
 33        _expand_using(scope, resolver)
 34        _qualify_columns(scope, resolver)
 35        if not isinstance(scope.expression, exp.UDTF):
 36            _expand_stars(scope, resolver)
 37            _qualify_outputs(scope)
 38        _expand_group_by(scope, resolver)
 39        _expand_order_by(scope)
 40    return expression
 41
 42
 43def validate_qualify_columns(expression):
 44    """Raise an `OptimizeError` if any columns aren't qualified"""
 45    unqualified_columns = []
 46    for scope in traverse_scope(expression):
 47        if isinstance(scope.expression, exp.Select):
 48            unqualified_columns.extend(scope.unqualified_columns)
 49            if scope.external_columns and not scope.is_correlated_subquery:
 50                raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")
 51
 52    if unqualified_columns:
 53        raise OptimizeError(f"Ambiguous columns: {unqualified_columns}")
 54    return expression
 55
 56
 57def _pop_table_column_aliases(derived_tables):
 58    """
 59    Remove table column aliases.
 60
 61    (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2)
 62    """
 63    for derived_table in derived_tables:
 64        table_alias = derived_table.args.get("alias")
 65        if table_alias:
 66            table_alias.args.pop("columns", None)
 67
 68
 69def _expand_using(scope, resolver):
 70    joins = list(scope.expression.find_all(exp.Join))
 71    names = {join.this.alias for join in joins}
 72    ordered = [key for key in scope.selected_sources if key not in names]
 73
 74    # Mapping of automatically joined column names to source names
 75    column_tables = {}
 76
 77    for join in joins:
 78        using = join.args.get("using")
 79
 80        if not using:
 81            continue
 82
 83        join_table = join.this.alias_or_name
 84
 85        columns = {}
 86
 87        for k in scope.selected_sources:
 88            if k in ordered:
 89                for column in resolver.get_source_columns(k):
 90                    if column not in columns:
 91                        columns[column] = k
 92
 93        ordered.append(join_table)
 94        join_columns = resolver.get_source_columns(join_table)
 95        conditions = []
 96
 97        for identifier in using:
 98            identifier = identifier.name
 99            table = columns.get(identifier)
100
101            if not table or identifier not in join_columns:
102                raise OptimizeError(f"Cannot automatically join: {identifier}")
103
104            conditions.append(
105                exp.condition(
106                    exp.EQ(
107                        this=exp.column(identifier, table=table),
108                        expression=exp.column(identifier, table=join_table),
109                    )
110                )
111            )
112
113            tables = column_tables.setdefault(identifier, [])
114            if table not in tables:
115                tables.append(table)
116            if join_table not in tables:
117                tables.append(join_table)
118
119        join.args.pop("using")
120        join.set("on", exp.and_(*conditions))
121
122    if column_tables:
123        for column in scope.columns:
124            if not column.table and column.name in column_tables:
125                tables = column_tables[column.name]
126                coalesce = [exp.column(column.name, table=table) for table in tables]
127                replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
128
129                # Ensure selects keep their output name
130                if isinstance(column.parent, exp.Select):
131                    replacement = exp.alias_(replacement, alias=column.name)
132
133                scope.replace(column, replacement)
134
135
def _expand_group_by(scope, resolver):
    """
    Qualify or expand unqualified column references in GROUP BY, in place.

    Each unqualified column in the GROUP BY clause is, in order of preference:
      1. qualified with the table the resolver finds for it, or
      2. replaced by a copy of the SELECT projection it names (alias reference),
         unwrapped from its `Alias`.
    Positional references (e.g. `GROUP BY 1`) are then expanded as well.
    """
    group = scope.expression.args.get("group")
    if not group:
        return

    def _resolve(node, *_):
        # Only unqualified columns are candidates for rewriting.
        if not isinstance(node, exp.Column) or node.table:
            return node

        source_table = resolver.get_table(node.name)
        if source_table:
            # Source columns get priority over select aliases.
            node.set("table", exp.to_identifier(source_table))
            return node

        named = {projection.alias_or_name: projection for projection in scope.selects}
        target = named.get(node.name)
        if not target:
            return node

        scope.clear_cache()
        if isinstance(target, exp.Alias):
            target = target.this
        return target.copy()

    group.transform(_resolve, copy=False)
    expanded = _expand_positional_references(scope, group.expressions)
    group.set("expressions", expanded)
    scope.expression.set("group", group)
165
166
def _expand_order_by(scope):
    """
    Expand positional references in ORDER BY (e.g. `ORDER BY 1`), in place.

    Only the `this` arg of each `Ordered` node is replaced. Its other args are
    left untouched.
    """
    order = scope.expression.args.get("order")
    if not order:
        return

    ordereds = order.expressions
    sort_keys = (ordered.this for ordered in ordereds)
    expanded = _expand_positional_references(scope, sort_keys)

    for ordered, replacement in zip(ordereds, expanded):
        ordered.set("this", replacement)
178
179
def _expand_positional_references(scope, expressions):
    """
    Replace 1-based integer positions with copies of the matching projection.

    Args:
        scope: scope whose `selects` the positions refer to.
        expressions: iterable of expressions (e.g. GROUP BY items, ORDER BY keys).
    Returns:
        list: a new list in which each integer literal `n` is replaced by a copy
        of the n-th select (unwrapped from its `Alias`). Other expressions are
        passed through unchanged.
    Raises:
        OptimizeError: if a position is outside `1..len(scope.selects)`.
    """
    new_nodes = []
    for node in expressions:
        if not node.is_int:
            new_nodes.append(node)
            continue

        selects = scope.selects
        position = int(node.name)
        # Positions are 1-based. Check the lower bound explicitly: `0` would
        # otherwise become index -1 and silently pick the *last* select.
        if not 1 <= position <= len(selects):
            raise OptimizeError(f"Unknown output column: {node.name}")

        select = selects[position - 1]
        if isinstance(select, exp.Alias):
            select = select.this
        new_nodes.append(select.copy())
        scope.clear_cache()

    return new_nodes
196
197
def _qualify_columns(scope, resolver):
    """
    Disambiguate columns, ensuring each column specifies a source.

    1. Columns qualified with a source of this scope are checked against that
       source's column list, when the list is non-empty.
    2. Unqualified columns in `scope.columns` get the table the resolver finds.
    3. Unqualified columns nested inside ORDER BY expressions (not the bare sort
       key) or inside aggregates in HAVING are qualified too, provided some
       source exposes a column of that name. NOTE(review): presumably these are
       columns `scope.columns` leaves out because they might be select aliases;
       confirm against `Scope`.

    Raises:
        OptimizeError: if a qualified column does not exist on its source.
    """
    for column in scope.columns:
        name = column.name
        qualifier = column.table

        if qualifier:
            if qualifier in scope.sources:
                known_columns = resolver.get_source_columns(qualifier)
                if known_columns and name not in known_columns:
                    raise OptimizeError(f"Unknown column: {name}")
            continue

        inferred = resolver.get_table(name)
        # The inferred table can be '' because bigquery unnest has no table alias.
        if inferred:
            column.set("table", exp.to_identifier(inferred))

    # ORDER BY: columns nested in an expression, not the sort key itself.
    deferred = [
        column
        for ordered in scope.find_all(exp.Ordered)
        for column in ordered.find_all(exp.Column)
        if not column.table
        and column.parent is not ordered
        and column.name in resolver.all_columns
    ]
    # HAVING: columns inside an aggregate function.
    deferred.extend(
        column
        for having in scope.find_all(exp.Having)
        for column in having.find_all(exp.Column)
        if not column.table
        and column.find_ancestor(exp.AggFunc)
        and column.name in resolver.all_columns
    )

    for column in deferred:
        inferred = resolver.get_table(column.name)
        if inferred:
            column.set("table", exp.to_identifier(inferred))
242
243
def _expand_stars(scope, resolver):
    """
    Expand stars to lists of column selections, in place.

    `*` expands to the visible columns of every selected source, and `tbl.*`
    to the visible columns of `tbl`. Columns named in a star's EXCEPT are
    dropped. A column matching a REPLACE entry is emitted aliased to that
    entry's alias. Non-star projections keep their position.

    Raises:
        OptimizeError: if a star references an unknown table, or a table whose
            columns are unknown.
    """
    new_selections = []

    for expression in scope.selects:
        # EXCEPT / REPLACE modifiers belong to the star they are written on.
        # Start every star with empty maps so one star's modifiers do not leak
        # into a later star over the same table (e.g. `SELECT * EXCEPT (a), * FROM x`).
        except_columns = {}
        replace_columns = {}

        if isinstance(expression, exp.Star):
            tables = list(scope.selected_sources)
            _add_except_columns(expression, tables, except_columns)
            _add_replace_columns(expression, tables, replace_columns)
        elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star):
            tables = [expression.table]
            _add_except_columns(expression.this, tables, except_columns)
            _add_replace_columns(expression.this, tables, replace_columns)
        else:
            new_selections.append(expression)
            continue

        for table in tables:
            if table not in scope.sources:
                raise OptimizeError(f"Unknown table: {table}")
            columns = resolver.get_source_columns(table, only_visible=True)
            if not columns:
                raise OptimizeError(
                    f"Table has no schema/columns. Cannot expand star for table: {table}."
                )
            # The maps are keyed by `id(table)`. `table` is the very object passed
            # to the helpers above, so the lookup is consistent within this star.
            table_id = id(table)
            excluded = except_columns.get(table_id, set())
            renames = replace_columns.get(table_id, {})
            for name in columns:
                if name not in excluded:
                    alias_ = renames.get(name, name)
                    column = exp.column(name, table)
                    new_selections.append(alias(column, alias_) if alias_ != name else column)

    scope.expression.set("expressions", new_selections)
280
281
282def _add_except_columns(expression, tables, except_columns):
283    except_ = expression.args.get("except")
284
285    if not except_:
286        return
287
288    columns = {e.name for e in except_}
289
290    for table in tables:
291        except_columns[id(table)] = columns
292
293
294def _add_replace_columns(expression, tables, replace_columns):
295    replace = expression.args.get("replace")
296
297    if not replace:
298        return
299
300    columns = {e.this.name: e.alias for e in replace}
301
302    for table in tables:
303        replace_columns[id(table)] = columns
304
305
def _qualify_outputs(scope):
    """
    Ensure all output columns are aliased, in place.

    - A subquery projection without an output name gets a `_col_<i>` table alias.
    - Any other non-`Alias` projection is wrapped in an alias named after its
      output name, falling back to `_col_<i>`.
    - Names from `scope.outer_column_list` override the alias of the projection
      at the same position. NOTE(review): presumably these come from an
      enclosing `AS t(a, b)`; confirm against `Scope.outer_column_list`.

    Raises:
        OptimizeError: if there are more outer column names than projections.
    """
    new_selections = []

    for i, (selection, aliased_column) in enumerate(
        itertools.zip_longest(scope.selects, scope.outer_column_list)
    ):
        if selection is None:
            # zip_longest pads the shorter side with None, so there are more column
            # aliases than projections. Fail with a clear error rather than an
            # AttributeError below.
            raise OptimizeError(f"Too many column aliases, no output column for: {aliased_column}")

        if isinstance(selection, exp.Subquery):
            if not selection.output_name:
                selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}")))
        elif not isinstance(selection, exp.Alias):
            # Build the Alias around a placeholder and attach `selection` afterwards.
            # NOTE(review): presumably this avoids `alias()` copying/re-parsing `selection`.
            alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}")
            alias_.set("this", selection)
            selection = alias_

        if aliased_column:
            selection.set("alias", exp.to_identifier(aliased_column))

        new_selections.append(selection)

    scope.expression.set("expressions", new_selections)
327
328
329class _Resolver:
330    """
331    Helper for resolving columns.
332
333    This is a class so we can lazily load some things and easily share them across functions.
334    """
335
336    def __init__(self, scope, schema):
337        self.scope = scope
338        self.schema = schema
339        self._source_columns = None
340        self._unambiguous_columns = None
341        self._all_columns = None
342
343    def get_table(self, column_name: str) -> t.Optional[str]:
344        """
345        Get the table for a column name.
346
347        Args:
348            column_name: The column name to find the table for.
349        Returns:
350            The table name if it can be found/inferred.
351        """
352        if self._unambiguous_columns is None:
353            self._unambiguous_columns = self._get_unambiguous_columns(
354                self._get_all_source_columns()
355            )
356
357        table = self._unambiguous_columns.get(column_name)
358
359        if not table:
360            sources_without_schema = tuple(
361                source for source, columns in self._get_all_source_columns().items() if not columns
362            )
363            if len(sources_without_schema) == 1:
364                return sources_without_schema[0]
365
366        return table
367
368    @property
369    def all_columns(self):
370        """All available columns of all sources in this scope"""
371        if self._all_columns is None:
372            self._all_columns = {
373                column for columns in self._get_all_source_columns().values() for column in columns
374            }
375        return self._all_columns
376
377    def get_source_columns(self, name, only_visible=False):
378        """Resolve the source columns for a given source `name`"""
379        if name not in self.scope.sources:
380            raise OptimizeError(f"Unknown table: {name}")
381
382        source = self.scope.sources[name]
383
384        # If referencing a table, return the columns from the schema
385        if isinstance(source, exp.Table):
386            return self.schema.column_names(source, only_visible)
387
388        if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
389            return source.expression.alias_column_names
390
391        # Otherwise, if referencing another scope, return that scope's named selects
392        return source.expression.named_selects
393
394    def _get_all_source_columns(self):
395        if self._source_columns is None:
396            self._source_columns = {
397                k: self.get_source_columns(k)
398                for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources)
399            }
400        return self._source_columns
401
402    def _get_unambiguous_columns(self, source_columns):
403        """
404        Find all the unambiguous columns in sources.
405
406        Args:
407            source_columns (dict): Mapping of names to source columns
408        Returns:
409            dict: Mapping of column name to source name
410        """
411        if not source_columns:
412            return {}
413
414        source_columns = list(source_columns.items())
415
416        first_table, first_columns = source_columns[0]
417        unambiguous_columns = {col: first_table for col in self._find_unique_columns(first_columns)}
418        all_columns = set(unambiguous_columns)
419
420        for table, columns in source_columns[1:]:
421            unique = self._find_unique_columns(columns)
422            ambiguous = set(all_columns).intersection(unique)
423            all_columns.update(columns)
424            for column in ambiguous:
425                unambiguous_columns.pop(column, None)
426            for column in unique.difference(ambiguous):
427                unambiguous_columns[column] = table
428
429        return unambiguous_columns
430
431    @staticmethod
432    def _find_unique_columns(columns):
433        """
434        Find the unique columns in a list of columns.
435
436        Example:
437            >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
438            ['a', 'c']
439
440        This is necessary because duplicate column names are ambiguous.
441        """
442        counts = {}
443        for column in columns:
444            counts[column] = counts.get(column, 0) + 1
445        return {column for column, count in counts.items() if count == 1}
def qualify_columns(expression, schema):
    """
    Rewrite sqlglot AST to have fully qualified columns.

    Every scope yielded by `traverse_scope(expression)` is rewritten in place:
    table column aliases are dropped from CTEs/derived tables, USING joins are
    expanded, columns are qualified, stars are expanded and outputs aliased
    (except for UDTF scopes), and GROUP BY / ORDER BY references are expanded.

    Example:
        >>> import sqlglot
        >>> schema = {"tbl": {"col": "INT"}}
        >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
        >>> qualify_columns(expression, schema).sql()
        'SELECT tbl.col AS col FROM tbl'

    Args:
        expression (sqlglot.Expression): expression to qualify
        schema (dict|sqlglot.optimizer.Schema): Database schema
    Returns:
        sqlglot.Expression: qualified expression (the same object passed in)
    """
    normalized_schema = ensure_schema(schema)

    for current_scope in traverse_scope(expression):
        column_resolver = _Resolver(current_scope, normalized_schema)

        _pop_table_column_aliases(current_scope.ctes)
        _pop_table_column_aliases(current_scope.derived_tables)

        _expand_using(current_scope, column_resolver)
        _qualify_columns(current_scope, column_resolver)

        # UDTF scopes are not star-expanded and their outputs are not aliased.
        is_udtf = isinstance(current_scope.expression, exp.UDTF)
        if not is_udtf:
            _expand_stars(current_scope, column_resolver)
            _qualify_outputs(current_scope)

        _expand_group_by(current_scope, column_resolver)
        _expand_order_by(current_scope)

    return expression

Rewrite sqlglot AST to have fully qualified columns.

Example:
>>> import sqlglot
>>> schema = {"tbl": {"col": "INT"}}
>>> expression = sqlglot.parse_one("SELECT col FROM tbl")
>>> qualify_columns(expression, schema).sql()
'SELECT tbl.col AS col FROM tbl'
Arguments:
  • expression (sqlglot.Expression): expression to qualify
  • schema (dict|sqlglot.optimizer.Schema): database schema
Returns:
  • sqlglot.Expression: the qualified expression (the same object that was passed in)

def validate_qualify_columns(expression):
    """
    Raise an `OptimizeError` if any columns aren't qualified.

    Only SELECT scopes are inspected. A scope with external columns that is
    not a correlated subquery fails immediately, naming the first such
    column's table. Otherwise every unqualified column from all scopes is
    collected and reported in a single error.

    Args:
        expression (sqlglot.Expression): the (already qualified) expression.
    Returns:
        sqlglot.Expression: `expression`, unchanged.
    Raises:
        OptimizeError: on an unknown table or on unqualified (ambiguous) columns.
    """
    unqualified = []

    for scope in traverse_scope(expression):
        if not isinstance(scope.expression, exp.Select):
            continue

        unqualified += scope.unqualified_columns

        if scope.external_columns and not scope.is_correlated_subquery:
            raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}")

    if unqualified:
        raise OptimizeError(f"Ambiguous columns: {unqualified}")

    return expression

Raise an `OptimizeError` if any columns aren't qualified.