Edit on GitHub

sqlglot.optimizer.scope

  1import itertools
  2from collections import defaultdict
  3from enum import Enum, auto
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.helper import find_new_name
  8
  9
 10class ScopeType(Enum):
 11    ROOT = auto()
 12    SUBQUERY = auto()
 13    DERIVED_TABLE = auto()
 14    CTE = auto()
 15    UNION = auto()
 16    UDTF = auto()
 17
 18
 19class Scope:
 20    """
 21    Selection scope.
 22
 23    Attributes:
 24        expression (exp.Select|exp.Union): Root expression of this scope
 25        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 26            a Table expression or another Scope instance. For example:
 27                SELECT * FROM x                     {"x": Table(this="x")}
 28                SELECT * FROM x AS y                {"y": Table(this="x")}
 29                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 30        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 31            For example:
 32                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 33            The LATERAL VIEW EXPLODE gets x as a source.
 34        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 35            defines a column list of it's alias of this scope, this is that list of columns.
 36            For example:
 37                SELECT * FROM (SELECT ...) AS y(col1, col2)
 38            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 39        parent (Scope): Parent scope
 40        scope_type (ScopeType): Type of this scope, relative to it's parent
 41        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 42        cte_scopes (list[Scope]): List of all child scopes for CTEs
 43        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 44        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 45        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 46        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 47            a list of the left and right child scopes.
 48    """
 49
 50    def __init__(
 51        self,
 52        expression,
 53        sources=None,
 54        outer_column_list=None,
 55        parent=None,
 56        scope_type=ScopeType.ROOT,
 57        lateral_sources=None,
 58    ):
 59        self.expression = expression
 60        self.sources = sources or {}
 61        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 62        self.sources.update(self.lateral_sources)
 63        self.outer_column_list = outer_column_list or []
 64        self.parent = parent
 65        self.scope_type = scope_type
 66        self.subquery_scopes = []
 67        self.derived_table_scopes = []
 68        self.table_scopes = []
 69        self.cte_scopes = []
 70        self.union_scopes = []
 71        self.udtf_scopes = []
 72        self.clear_cache()
 73
 74    def clear_cache(self):
 75        self._collected = False
 76        self._raw_columns = None
 77        self._derived_tables = None
 78        self._udtfs = None
 79        self._tables = None
 80        self._ctes = None
 81        self._subqueries = None
 82        self._selected_sources = None
 83        self._columns = None
 84        self._external_columns = None
 85        self._join_hints = None
 86
 87    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 88        """Branch from the current scope to a new, inner scope"""
 89        return Scope(
 90            expression=expression.unnest(),
 91            sources={**self.cte_sources, **(chain_sources or {})},
 92            parent=self,
 93            scope_type=scope_type,
 94            **kwargs,
 95        )
 96
 97    def _collect(self):
 98        self._tables = []
 99        self._ctes = []
100        self._subqueries = []
101        self._derived_tables = []
102        self._udtfs = []
103        self._raw_columns = []
104        self._join_hints = []
105
106        for node, parent, _ in self.walk(bfs=False):
107            if node is self.expression:
108                continue
109            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
110                self._raw_columns.append(node)
111            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
112                self._tables.append(node)
113            elif isinstance(node, exp.JoinHint):
114                self._join_hints.append(node)
115            elif isinstance(node, exp.UDTF):
116                self._udtfs.append(node)
117            elif isinstance(node, exp.CTE):
118                self._ctes.append(node)
119            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
120                self._derived_tables.append(node)
121            elif isinstance(node, exp.Subqueryable):
122                self._subqueries.append(node)
123
124        self._collected = True
125
126    def _ensure_collected(self):
127        if not self._collected:
128            self._collect()
129
130    def walk(self, bfs=True):
131        return walk_in_scope(self.expression, bfs=bfs)
132
133    def find(self, *expression_types, bfs=True):
134        """
135        Returns the first node in this scope which matches at least one of the specified types.
136
137        This does NOT traverse into subscopes.
138
139        Args:
140            expression_types (type): the expression type(s) to match.
141            bfs (bool): True to use breadth-first search, False to use depth-first.
142
143        Returns:
144            exp.Expression: the node which matches the criteria or None if no node matching
145            the criteria was found.
146        """
147        return next(self.find_all(*expression_types, bfs=bfs), None)
148
149    def find_all(self, *expression_types, bfs=True):
150        """
151        Returns a generator object which visits all nodes in this scope and only yields those that
152        match at least one of the specified expression types.
153
154        This does NOT traverse into subscopes.
155
156        Args:
157            expression_types (type): the expression type(s) to match.
158            bfs (bool): True to use breadth-first search, False to use depth-first.
159
160        Yields:
161            exp.Expression: nodes
162        """
163        for expression, *_ in self.walk(bfs=bfs):
164            if isinstance(expression, expression_types):
165                yield expression
166
167    def replace(self, old, new):
168        """
169        Replace `old` with `new`.
170
171        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
172
173        Args:
174            old (exp.Expression): old node
175            new (exp.Expression): new node
176        """
177        old.replace(new)
178        self.clear_cache()
179
180    @property
181    def tables(self):
182        """
183        List of tables in this scope.
184
185        Returns:
186            list[exp.Table]: tables
187        """
188        self._ensure_collected()
189        return self._tables
190
191    @property
192    def ctes(self):
193        """
194        List of CTEs in this scope.
195
196        Returns:
197            list[exp.CTE]: ctes
198        """
199        self._ensure_collected()
200        return self._ctes
201
202    @property
203    def derived_tables(self):
204        """
205        List of derived tables in this scope.
206
207        For example:
208            SELECT * FROM (SELECT ...) <- that's a derived table
209
210        Returns:
211            list[exp.Subquery]: derived tables
212        """
213        self._ensure_collected()
214        return self._derived_tables
215
216    @property
217    def udtfs(self):
218        """
219        List of "User Defined Tabular Functions" in this scope.
220
221        Returns:
222            list[exp.UDTF]: UDTFs
223        """
224        self._ensure_collected()
225        return self._udtfs
226
227    @property
228    def subqueries(self):
229        """
230        List of subqueries in this scope.
231
232        For example:
233            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
234
235        Returns:
236            list[exp.Subqueryable]: subqueries
237        """
238        self._ensure_collected()
239        return self._subqueries
240
241    @property
242    def columns(self):
243        """
244        List of columns in this scope.
245
246        Returns:
247            list[exp.Column]: Column instances in this scope, plus any
248                Columns that reference this scope from correlated subqueries.
249        """
250        if self._columns is None:
251            self._ensure_collected()
252            columns = self._raw_columns
253
254            external_columns = [
255                column
256                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
257                for column in scope.external_columns
258            ]
259
260            named_selects = set(self.expression.named_selects)
261
262            self._columns = []
263            for column in columns + external_columns:
264                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
265                if (
266                    not ancestor
267                    # Window functions can have an ORDER BY clause
268                    or not isinstance(ancestor.parent, exp.Select)
269                    or column.table
270                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
271                ):
272                    self._columns.append(column)
273
274        return self._columns
275
276    @property
277    def selected_sources(self):
278        """
279        Mapping of nodes and sources that are actually selected from in this scope.
280
281        That is, all tables in a schema are selectable at any point. But a
282        table only becomes a selected source if it's included in a FROM or JOIN clause.
283
284        Returns:
285            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
286        """
287        if self._selected_sources is None:
288            referenced_names = []
289
290            for table in self.tables:
291                referenced_names.append((table.alias_or_name, table))
292            for expression in itertools.chain(self.derived_tables, self.udtfs):
293                referenced_names.append((expression.alias, expression.unnest()))
294            result = {}
295
296            for name, node in referenced_names:
297                if name in result:
298                    raise OptimizeError(f"Alias already used: {name}")
299                if name in self.sources:
300                    result[name] = (node, self.sources[name])
301
302            self._selected_sources = result
303        return self._selected_sources
304
305    @property
306    def cte_sources(self):
307        """
308        Sources that are CTEs.
309
310        Returns:
311            dict[str, Scope]: Mapping of source alias to Scope
312        """
313        return {
314            alias: scope
315            for alias, scope in self.sources.items()
316            if isinstance(scope, Scope) and scope.is_cte
317        }
318
319    @property
320    def selects(self):
321        """
322        Select expressions of this scope.
323
324        For example, for the following expression:
325            SELECT 1 as a, 2 as b FROM x
326
327        The outputs are the "1 as a" and "2 as b" expressions.
328
329        Returns:
330            list[exp.Expression]: expressions
331        """
332        if isinstance(self.expression, exp.Union):
333            return self.expression.unnest().selects
334        return self.expression.selects
335
336    @property
337    def external_columns(self):
338        """
339        Columns that appear to reference sources in outer scopes.
340
341        Returns:
342            list[exp.Column]: Column instances that don't reference
343                sources in the current scope.
344        """
345        if self._external_columns is None:
346            self._external_columns = [
347                c for c in self.columns if c.table not in self.selected_sources
348            ]
349        return self._external_columns
350
351    @property
352    def unqualified_columns(self):
353        """
354        Unqualified columns in the current scope.
355
356        Returns:
357             list[exp.Column]: Unqualified columns
358        """
359        return [c for c in self.columns if not c.table]
360
361    @property
362    def join_hints(self):
363        """
364        Hints that exist in the scope that reference tables
365
366        Returns:
367            list[exp.JoinHint]: Join hints that are referenced within the scope
368        """
369        if self._join_hints is None:
370            return []
371        return self._join_hints
372
373    def source_columns(self, source_name):
374        """
375        Get all columns in the current scope for a particular source.
376
377        Args:
378            source_name (str): Name of the source
379        Returns:
380            list[exp.Column]: Column instances that reference `source_name`
381        """
382        return [column for column in self.columns if column.table == source_name]
383
384    @property
385    def is_subquery(self):
386        """Determine if this scope is a subquery"""
387        return self.scope_type == ScopeType.SUBQUERY
388
389    @property
390    def is_derived_table(self):
391        """Determine if this scope is a derived table"""
392        return self.scope_type == ScopeType.DERIVED_TABLE
393
394    @property
395    def is_union(self):
396        """Determine if this scope is a union"""
397        return self.scope_type == ScopeType.UNION
398
399    @property
400    def is_cte(self):
401        """Determine if this scope is a common table expression"""
402        return self.scope_type == ScopeType.CTE
403
404    @property
405    def is_root(self):
406        """Determine if this is the root scope"""
407        return self.scope_type == ScopeType.ROOT
408
409    @property
410    def is_udtf(self):
411        """Determine if this scope is a UDTF (User Defined Table Function)"""
412        return self.scope_type == ScopeType.UDTF
413
414    @property
415    def is_correlated_subquery(self):
416        """Determine if this scope is a correlated subquery"""
417        return bool(self.is_subquery and self.external_columns)
418
419    def rename_source(self, old_name, new_name):
420        """Rename a source in this scope"""
421        columns = self.sources.pop(old_name or "", [])
422        self.sources[new_name] = columns
423
424    def add_source(self, name, source):
425        """Add a source to this scope"""
426        self.sources[name] = source
427        self.clear_cache()
428
429    def remove_source(self, name):
430        """Remove a source from this scope"""
431        self.sources.pop(name, None)
432        self.clear_cache()
433
434    def __repr__(self):
435        return f"Scope<{self.expression.sql()}>"
436
437    def traverse(self):
438        """
439        Traverse the scope tree from this node.
440
441        Yields:
442            Scope: scope instances in depth-first-search post-order
443        """
444        for child_scope in itertools.chain(
445            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
446        ):
447            yield from child_scope.traverse()
448        yield self
449
450    def ref_count(self):
451        """
452        Count the number of times each scope in this tree is referenced.
453
454        Returns:
455            dict[int, int]: Mapping of Scope instance ID to reference count
456        """
457        scope_ref_count = defaultdict(lambda: 0)
458
459        for scope in self.traverse():
460            for _, source in scope.selected_sources.values():
461                scope_ref_count[id(source)] += 1
462
463        return scope_ref_count
464
465
def traverse_scope(expression):
    """
    Collect all the scopes of an expression.

    A "scope" is the context of a single Select statement. Optimizations often
    need more than the bare expression tree provides -- for example, the names
    of the sources visible inside a subquery -- and scopes carry that extra
    information.

    The result is a fully built list rather than a generator: a scope's
    properties are only complete once the traversal has finished, so handing
    out scopes lazily would be confusing. The root scope is the last element.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    root = Scope(expression)
    return list(_traverse_scope(root))
492
493
def build_scope(expression):
    """
    Build the scope tree of an expression and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # The traversal lists the root scope last.
    scopes = traverse_scope(expression)
    return scopes[-1]
504
505
def _traverse_scope(scope):
    """
    Yield every scope nested inside `scope`, then `scope` itself.

    Raises:
        OptimizeError: if the scope's expression is not a Select, Union,
            Subquery or UDTF.
    """
    expression = scope.expression

    # Pick the traversal that matches the kind of expression this scope wraps.
    if isinstance(expression, exp.Select):
        nested = _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        nested = _traverse_union(scope)
    elif isinstance(expression, exp.Subquery):
        nested = _traverse_subqueries(scope)
    elif isinstance(expression, exp.UDTF):
        # Nothing is traversed inside a UDTF.
        nested = ()
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield from nested
    yield scope
518
519
def _traverse_select(scope):
    """Yield the child scopes of a Select: CTEs first, then tables, then subqueries."""
    for traverse_part in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_part(scope)
524
525
def _traverse_union(scope):
    """
    Yield the child scopes of a Union: its CTEs, then the left side, then the right side.

    Once both sides have been traversed, `scope.union_scopes` is set to the
    top-most scope of each side, left first.
    """
    yield from _traverse_ctes(scope)

    side_scopes = []
    for side in ("left", "right"):
        branch = scope.branch(getattr(scope.expression, side), scope_type=ScopeType.UNION)

        # The last scope yielded for a side is that side's top-most scope.
        top = None
        for top in _traverse_scope(branch):
            yield top
        side_scopes.append(top)

    scope.union_scopes = side_scopes
539
540
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined in `scope`, in definition order.

    Side effects on `scope`: the top-most scope of each CTE is appended to
    `scope.cte_scopes`, and once all CTEs are traversed the alias -> scope
    mapping is merged into `scope.sources`.
    """
    # CTE alias -> scope, accumulated so that later CTEs can see earlier ones.
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            # Re-assigned on every iteration; after the loop the alias maps to
            # the last scope yielded, which is the CTE's own top-most scope.
            alias = cte.alias
            sources[alias] = child_scope

            # Runs for every yielded scope, so scopes nested inside the CTE
            # also get the recursive scope registered under the CTE's alias.
            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
576
577
def _traverse_tables(scope):
    """
    Register the sources of a Select's FROM, JOIN and LATERAL clauses on `scope`.

    Plain tables are only recorded as sources. Every other expression (derived
    table or UDTF) is traversed as a child scope, and all scopes found that way
    are yielded.

    Side effects on `scope`: the top-most child scope of each derived table /
    UDTF is appended to `derived_table_scopes` or `udtf_scopes` and to
    `table_scopes`; once everything is traversed the collected sources are
    merged into `scope.sources`.
    """
    # Source name -> Table expression or Scope, in definition order.
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.extend(from_.expressions)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources:
                # This is a reference to a parent source (e.g. a CTE), not an actual table.
                sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # The name is already taken in this scope: store the table
                # under a freshly generated name instead of overwriting.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if isinstance(expression, exp.UDTF):
            # A UDTF receives the sources gathered so far as its lateral sources.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            alias = expression.alias
            sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
637
638
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope`.

    The top-most scope of each subquery is appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        branch = scope.branch(subquery, scope_type=ScopeType.SUBQUERY)

        # The last scope yielded is the subquery's own, top-most scope.
        top = None
        for top in _traverse_scope(branch):
            yield top
        scope.subquery_scopes.append(top)
646
647
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope are still yielded themselves; only their
    subtrees are excluded.

    Args:
        expression (exp.Expression): the root expression to walk.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    # The lambda reads the current value of `prune` each time it is called.
    # NOTE(review): this relies on `expression.walk` consulting `prune` only
    # after control returns from the yield below -- confirm against `walk`.
    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        # The root is never pruned, even if it is itself a scope-starting node.
        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Kind of a scope, relative to the scope that contains it."""

    # Values are written out explicitly; they are the same integers that
    # `auto()` hands out in declaration order (1, 2, 3, ...).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 20class Scope:
 21    """
 22    Selection scope.
 23
 24    Attributes:
 25        expression (exp.Select|exp.Union): Root expression of this scope
 26        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 27            a Table expression or another Scope instance. For example:
 28                SELECT * FROM x                     {"x": Table(this="x")}
 29                SELECT * FROM x AS y                {"y": Table(this="x")}
 30                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 31        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 32            For example:
 33                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 34            The LATERAL VIEW EXPLODE gets x as a source.
 35        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 36            defines a column list of it's alias of this scope, this is that list of columns.
 37            For example:
 38                SELECT * FROM (SELECT ...) AS y(col1, col2)
 39            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 40        parent (Scope): Parent scope
 41        scope_type (ScopeType): Type of this scope, relative to it's parent
 42        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 43        cte_scopes (list[Scope]): List of all child scopes for CTEs
 44        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 45        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 46        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 47        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 48            a list of the left and right child scopes.
 49    """
 50
 51    def __init__(
 52        self,
 53        expression,
 54        sources=None,
 55        outer_column_list=None,
 56        parent=None,
 57        scope_type=ScopeType.ROOT,
 58        lateral_sources=None,
 59    ):
 60        self.expression = expression
 61        self.sources = sources or {}
 62        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 63        self.sources.update(self.lateral_sources)
 64        self.outer_column_list = outer_column_list or []
 65        self.parent = parent
 66        self.scope_type = scope_type
 67        self.subquery_scopes = []
 68        self.derived_table_scopes = []
 69        self.table_scopes = []
 70        self.cte_scopes = []
 71        self.union_scopes = []
 72        self.udtf_scopes = []
 73        self.clear_cache()
 74
 75    def clear_cache(self):
 76        self._collected = False
 77        self._raw_columns = None
 78        self._derived_tables = None
 79        self._udtfs = None
 80        self._tables = None
 81        self._ctes = None
 82        self._subqueries = None
 83        self._selected_sources = None
 84        self._columns = None
 85        self._external_columns = None
 86        self._join_hints = None
 87
 88    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 89        """Branch from the current scope to a new, inner scope"""
 90        return Scope(
 91            expression=expression.unnest(),
 92            sources={**self.cte_sources, **(chain_sources or {})},
 93            parent=self,
 94            scope_type=scope_type,
 95            **kwargs,
 96        )
 97
 98    def _collect(self):
 99        self._tables = []
100        self._ctes = []
101        self._subqueries = []
102        self._derived_tables = []
103        self._udtfs = []
104        self._raw_columns = []
105        self._join_hints = []
106
107        for node, parent, _ in self.walk(bfs=False):
108            if node is self.expression:
109                continue
110            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
111                self._raw_columns.append(node)
112            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
113                self._tables.append(node)
114            elif isinstance(node, exp.JoinHint):
115                self._join_hints.append(node)
116            elif isinstance(node, exp.UDTF):
117                self._udtfs.append(node)
118            elif isinstance(node, exp.CTE):
119                self._ctes.append(node)
120            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
121                self._derived_tables.append(node)
122            elif isinstance(node, exp.Subqueryable):
123                self._subqueries.append(node)
124
125        self._collected = True
126
127    def _ensure_collected(self):
128        if not self._collected:
129            self._collect()
130
131    def walk(self, bfs=True):
132        return walk_in_scope(self.expression, bfs=bfs)
133
134    def find(self, *expression_types, bfs=True):
135        """
136        Returns the first node in this scope which matches at least one of the specified types.
137
138        This does NOT traverse into subscopes.
139
140        Args:
141            expression_types (type): the expression type(s) to match.
142            bfs (bool): True to use breadth-first search, False to use depth-first.
143
144        Returns:
145            exp.Expression: the node which matches the criteria or None if no node matching
146            the criteria was found.
147        """
148        return next(self.find_all(*expression_types, bfs=bfs), None)
149
150    def find_all(self, *expression_types, bfs=True):
151        """
152        Returns a generator object which visits all nodes in this scope and only yields those that
153        match at least one of the specified expression types.
154
155        This does NOT traverse into subscopes.
156
157        Args:
158            expression_types (type): the expression type(s) to match.
159            bfs (bool): True to use breadth-first search, False to use depth-first.
160
161        Yields:
162            exp.Expression: nodes
163        """
164        for expression, *_ in self.walk(bfs=bfs):
165            if isinstance(expression, expression_types):
166                yield expression
167
168    def replace(self, old, new):
169        """
170        Replace `old` with `new`.
171
172        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
173
174        Args:
175            old (exp.Expression): old node
176            new (exp.Expression): new node
177        """
178        old.replace(new)
179        self.clear_cache()
180
181    @property
182    def tables(self):
183        """
184        List of tables in this scope.
185
186        Returns:
187            list[exp.Table]: tables
188        """
189        self._ensure_collected()
190        return self._tables
191
192    @property
193    def ctes(self):
194        """
195        List of CTEs in this scope.
196
197        Returns:
198            list[exp.CTE]: ctes
199        """
200        self._ensure_collected()
201        return self._ctes
202
203    @property
204    def derived_tables(self):
205        """
206        List of derived tables in this scope.
207
208        For example:
209            SELECT * FROM (SELECT ...) <- that's a derived table
210
211        Returns:
212            list[exp.Subquery]: derived tables
213        """
214        self._ensure_collected()
215        return self._derived_tables
216
217    @property
218    def udtfs(self):
219        """
220        List of "User Defined Tabular Functions" in this scope.
221
222        Returns:
223            list[exp.UDTF]: UDTFs
224        """
225        self._ensure_collected()
226        return self._udtfs
227
228    @property
229    def subqueries(self):
230        """
231        List of subqueries in this scope.
232
233        For example:
234            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
235
236        Returns:
237            list[exp.Subqueryable]: subqueries
238        """
239        self._ensure_collected()
240        return self._subqueries
241
242    @property
243    def columns(self):
244        """
245        List of columns in this scope.
246
247        Returns:
248            list[exp.Column]: Column instances in this scope, plus any
249                Columns that reference this scope from correlated subqueries.
250        """
251        if self._columns is None:
252            self._ensure_collected()
253            columns = self._raw_columns
254
255            external_columns = [
256                column
257                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
258                for column in scope.external_columns
259            ]
260
261            named_selects = set(self.expression.named_selects)
262
263            self._columns = []
264            for column in columns + external_columns:
265                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
266                if (
267                    not ancestor
268                    # Window functions can have an ORDER BY clause
269                    or not isinstance(ancestor.parent, exp.Select)
270                    or column.table
271                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
272                ):
273                    self._columns.append(column)
274
275        return self._columns
276
277    @property
278    def selected_sources(self):
279        """
280        Mapping of nodes and sources that are actually selected from in this scope.
281
282        That is, all tables in a schema are selectable at any point. But a
283        table only becomes a selected source if it's included in a FROM or JOIN clause.
284
285        Returns:
286            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
287        """
288        if self._selected_sources is None:
289            referenced_names = []
290
291            for table in self.tables:
292                referenced_names.append((table.alias_or_name, table))
293            for expression in itertools.chain(self.derived_tables, self.udtfs):
294                referenced_names.append((expression.alias, expression.unnest()))
295            result = {}
296
297            for name, node in referenced_names:
298                if name in result:
299                    raise OptimizeError(f"Alias already used: {name}")
300                if name in self.sources:
301                    result[name] = (node, self.sources[name])
302
303            self._selected_sources = result
304        return self._selected_sources
305
306    @property
307    def cte_sources(self):
308        """
309        Sources that are CTEs.
310
311        Returns:
312            dict[str, Scope]: Mapping of source alias to Scope
313        """
314        return {
315            alias: scope
316            for alias, scope in self.sources.items()
317            if isinstance(scope, Scope) and scope.is_cte
318        }
319
320    @property
321    def selects(self):
322        """
323        Select expressions of this scope.
324
325        For example, for the following expression:
326            SELECT 1 as a, 2 as b FROM x
327
328        The outputs are the "1 as a" and "2 as b" expressions.
329
330        Returns:
331            list[exp.Expression]: expressions
332        """
333        if isinstance(self.expression, exp.Union):
334            return self.expression.unnest().selects
335        return self.expression.selects
336
337    @property
338    def external_columns(self):
339        """
340        Columns that appear to reference sources in outer scopes.
341
342        Returns:
343            list[exp.Column]: Column instances that don't reference
344                sources in the current scope.
345        """
346        if self._external_columns is None:
347            self._external_columns = [
348                c for c in self.columns if c.table not in self.selected_sources
349            ]
350        return self._external_columns
351
352    @property
353    def unqualified_columns(self):
354        """
355        Unqualified columns in the current scope.
356
357        Returns:
358             list[exp.Column]: Unqualified columns
359        """
360        return [c for c in self.columns if not c.table]
361
362    @property
363    def join_hints(self):
364        """
365        Hints that exist in the scope that reference tables
366
367        Returns:
368            list[exp.JoinHint]: Join hints that are referenced within the scope
369        """
370        if self._join_hints is None:
371            return []
372        return self._join_hints
373
374    def source_columns(self, source_name):
375        """
376        Get all columns in the current scope for a particular source.
377
378        Args:
379            source_name (str): Name of the source
380        Returns:
381            list[exp.Column]: Column instances that reference `source_name`
382        """
383        return [column for column in self.columns if column.table == source_name]
384
385    @property
386    def is_subquery(self):
387        """Determine if this scope is a subquery"""
388        return self.scope_type == ScopeType.SUBQUERY
389
390    @property
391    def is_derived_table(self):
392        """Determine if this scope is a derived table"""
393        return self.scope_type == ScopeType.DERIVED_TABLE
394
395    @property
396    def is_union(self):
397        """Determine if this scope is a union"""
398        return self.scope_type == ScopeType.UNION
399
400    @property
401    def is_cte(self):
402        """Determine if this scope is a common table expression"""
403        return self.scope_type == ScopeType.CTE
404
405    @property
406    def is_root(self):
407        """Determine if this is the root scope"""
408        return self.scope_type == ScopeType.ROOT
409
410    @property
411    def is_udtf(self):
412        """Determine if this scope is a UDTF (User Defined Table Function)"""
413        return self.scope_type == ScopeType.UDTF
414
415    @property
416    def is_correlated_subquery(self):
417        """Determine if this scope is a correlated subquery"""
418        return bool(self.is_subquery and self.external_columns)
419
420    def rename_source(self, old_name, new_name):
421        """Rename a source in this scope"""
422        columns = self.sources.pop(old_name or "", [])
423        self.sources[new_name] = columns
424
425    def add_source(self, name, source):
426        """Add a source to this scope"""
427        self.sources[name] = source
428        self.clear_cache()
429
430    def remove_source(self, name):
431        """Remove a source from this scope"""
432        self.sources.pop(name, None)
433        self.clear_cache()
434
435    def __repr__(self):
436        return f"Scope<{self.expression.sql()}>"
437
438    def traverse(self):
439        """
440        Traverse the scope tree from this node.
441
442        Yields:
443            Scope: scope instances in depth-first-search post-order
444        """
445        for child_scope in itertools.chain(
446            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
447        ):
448            yield from child_scope.traverse()
449        yield self
450
451    def ref_count(self):
452        """
453        Count the number of times each scope in this tree is referenced.
454
455        Returns:
456            dict[int, int]: Mapping of Scope instance ID to reference count
457        """
458        scope_ref_count = defaultdict(lambda: 0)
459
460        for scope in self.traverse():
461            for _, source in scope.selected_sources.values():
462                scope_ref_count[id(source)] += 1
463
464        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for the alias of this scope, this is that list of columns. For example: SELECT * FROM (SELECT ...) AS y(col1, col2) The inner query would have ["col1", "col2"] for its outer_column_list
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
51    def __init__(
52        self,
53        expression,
54        sources=None,
55        outer_column_list=None,
56        parent=None,
57        scope_type=ScopeType.ROOT,
58        lateral_sources=None,
59    ):
60        self.expression = expression
61        self.sources = sources or {}
62        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
63        self.sources.update(self.lateral_sources)
64        self.outer_column_list = outer_column_list or []
65        self.parent = parent
66        self.scope_type = scope_type
67        self.subquery_scopes = []
68        self.derived_table_scopes = []
69        self.table_scopes = []
70        self.cte_scopes = []
71        self.union_scopes = []
72        self.udtf_scopes = []
73        self.clear_cache()
def clear_cache(self):
75    def clear_cache(self):
76        self._collected = False
77        self._raw_columns = None
78        self._derived_tables = None
79        self._udtfs = None
80        self._tables = None
81        self._ctes = None
82        self._subqueries = None
83        self._selected_sources = None
84        self._columns = None
85        self._external_columns = None
86        self._join_hints = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
88    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
89        """Branch from the current scope to a new, inner scope"""
90        return Scope(
91            expression=expression.unnest(),
92            sources={**self.cte_sources, **(chain_sources or {})},
93            parent=self,
94            scope_type=scope_type,
95            **kwargs,
96        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True):
131    def walk(self, bfs=True):
132        return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
134    def find(self, *expression_types, bfs=True):
135        """
136        Returns the first node in this scope which matches at least one of the specified types.
137
138        This does NOT traverse into subscopes.
139
140        Args:
141            expression_types (type): the expression type(s) to match.
142            bfs (bool): True to use breadth-first search, False to use depth-first.
143
144        Returns:
145            exp.Expression: the node which matches the criteria or None if no node matching
146            the criteria was found.
147        """
148        return next(self.find_all(*expression_types, bfs=bfs), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.

def find_all(self, *expression_types, bfs=True):
150    def find_all(self, *expression_types, bfs=True):
151        """
152        Returns a generator object which visits all nodes in this scope and only yields those that
153        match at least one of the specified expression types.
154
155        This does NOT traverse into subscopes.
156
157        Args:
158            expression_types (type): the expression type(s) to match.
159            bfs (bool): True to use breadth-first search, False to use depth-first.
160
161        Yields:
162            exp.Expression: nodes
163        """
164        for expression, *_ in self.walk(bfs=bfs):
165            if isinstance(expression, expression_types):
166                yield expression

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def replace(self, old, new):
168    def replace(self, old, new):
169        """
170        Replace `old` with `new`.
171
172        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
173
174        Args:
175            old (exp.Expression): old node
176            new (exp.Expression): new node
177        """
178        old.replace(new)
179        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

selects

Select expressions of this scope.

For example, for the following expression: SELECT 1 as a, 2 as b FROM x

The outputs are the "1 as a" and "2 as b" expressions.

Returns:

list[exp.Expression]: expressions

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

def source_columns(self, source_name):
374    def source_columns(self, source_name):
375        """
376        Get all columns in the current scope for a particular source.
377
378        Args:
379            source_name (str): Name of the source
380        Returns:
381            list[exp.Column]: Column instances that reference `source_name`
382        """
383        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
420    def rename_source(self, old_name, new_name):
421        """Rename a source in this scope"""
422        columns = self.sources.pop(old_name or "", [])
423        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
425    def add_source(self, name, source):
426        """Add a source to this scope"""
427        self.sources[name] = source
428        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
430    def remove_source(self, name):
431        """Remove a source from this scope"""
432        self.sources.pop(name, None)
433        self.clear_cache()

Remove a source from this scope

def traverse(self):
438    def traverse(self):
439        """
440        Traverse the scope tree from this node.
441
442        Yields:
443            Scope: scope instances in depth-first-search post-order
444        """
445        for child_scope in itertools.chain(
446            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
447        ):
448            yield from child_scope.traverse()
449        yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
451    def ref_count(self):
452        """
453        Count the number of times each scope in this tree is referenced.
454
455        Returns:
456            dict[int, int]: Mapping of Scope instance ID to reference count
457        """
458        scope_ref_count = defaultdict(lambda: 0)
459
460        for scope in self.traverse():
461            for _, source in scope.selected_sources.values():
462                scope_ref_count[id(source)] += 1
463
464        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope(expression):
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    A "Scope" represents the current context of a Select statement. Scopes carry
    information that the expression tree alone does not, such as the source
    names visible within a subquery, which is helpful when optimizing queries.

    The scopes are returned as a fully built list rather than a generator,
    because a generator could result in incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope(expression):
def build_scope(expression):
    """
    Build a scope tree and return its root.

    The root is taken to be the last scope returned by `traverse_scope`.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    scopes = traverse_scope(expression)
    return scopes[-1]

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True):
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope are still yielded themselves; only the
    subtree below them is excluded. The root `expression` is never pruned.

    Args:
        expression (exp.Expression): root node to start walking from
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the `prune` callback given to `expression.walk`: while it
    # is True, the walker is told to exclude the subtree of the node just yielded.
    skip_subtree = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: skip_subtree):
        # Reset first, so that a previous pruning decision does not leak onto this node.
        skip_subtree = False

        yield node, parent, key

        if node is expression:
            continue

        # A Subquery only opens a child scope here when it is used as a source.
        is_source_subquery = isinstance(node, exp.Subquery) and isinstance(
            parent, (exp.From, exp.Join)
        )
        if is_source_subquery or isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)):
            skip_subtree = True

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key