Edit on GitHub

sqlglot.optimizer.scope

  1import itertools
  2import logging
  3import typing as t
  4from collections import defaultdict
  5from enum import Enum, auto
  6
  7from sqlglot import exp
  8from sqlglot.errors import OptimizeError
  9from sqlglot.helper import ensure_collection, find_new_name
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14class ScopeType(Enum):
 15    ROOT = auto()
 16    SUBQUERY = auto()
 17    DERIVED_TABLE = auto()
 18    CTE = auto()
 19    UNION = auto()
 20    UDTF = auto()
 21
 22
 23class Scope:
 24    """
 25    Selection scope.
 26
 27    Attributes:
 28        expression (exp.Select|exp.Union): Root expression of this scope
 29        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 30            a Table expression or another Scope instance. For example:
 31                SELECT * FROM x                     {"x": Table(this="x")}
 32                SELECT * FROM x AS y                {"y": Table(this="x")}
 33                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 34        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 35            For example:
 36                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 37            The LATERAL VIEW EXPLODE gets x as a source.
 38        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 39            defines a column list of it's alias of this scope, this is that list of columns.
 40            For example:
 41                SELECT * FROM (SELECT ...) AS y(col1, col2)
 42            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 43        parent (Scope): Parent scope
 44        scope_type (ScopeType): Type of this scope, relative to it's parent
 45        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 46        cte_scopes (list[Scope]): List of all child scopes for CTEs
 47        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 48        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 49        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 50        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 51            a list of the left and right child scopes.
 52    """
 53
 54    def __init__(
 55        self,
 56        expression,
 57        sources=None,
 58        outer_column_list=None,
 59        parent=None,
 60        scope_type=ScopeType.ROOT,
 61        lateral_sources=None,
 62    ):
 63        self.expression = expression
 64        self.sources = sources or {}
 65        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 66        self.sources.update(self.lateral_sources)
 67        self.outer_column_list = outer_column_list or []
 68        self.parent = parent
 69        self.scope_type = scope_type
 70        self.subquery_scopes = []
 71        self.derived_table_scopes = []
 72        self.table_scopes = []
 73        self.cte_scopes = []
 74        self.union_scopes = []
 75        self.udtf_scopes = []
 76        self.clear_cache()
 77
 78    def clear_cache(self):
 79        self._collected = False
 80        self._raw_columns = None
 81        self._derived_tables = None
 82        self._udtfs = None
 83        self._tables = None
 84        self._ctes = None
 85        self._subqueries = None
 86        self._selected_sources = None
 87        self._columns = None
 88        self._external_columns = None
 89        self._join_hints = None
 90        self._pivots = None
 91        self._references = None
 92
 93    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 94        """Branch from the current scope to a new, inner scope"""
 95        return Scope(
 96            expression=expression.unnest(),
 97            sources={**self.cte_sources, **(chain_sources or {})},
 98            parent=self,
 99            scope_type=scope_type,
100            **kwargs,
101        )
102
103    def _collect(self):
104        self._tables = []
105        self._ctes = []
106        self._subqueries = []
107        self._derived_tables = []
108        self._udtfs = []
109        self._raw_columns = []
110        self._join_hints = []
111
112        for node, parent, _ in self.walk(bfs=False):
113            if node is self.expression:
114                continue
115            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
116                self._raw_columns.append(node)
117            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
118                self._tables.append(node)
119            elif isinstance(node, exp.JoinHint):
120                self._join_hints.append(node)
121            elif isinstance(node, exp.UDTF):
122                self._udtfs.append(node)
123            elif isinstance(node, exp.CTE):
124                self._ctes.append(node)
125            elif (
126                isinstance(node, exp.Subquery)
127                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
128                and _is_derived_table(node)
129            ):
130                self._derived_tables.append(node)
131            elif isinstance(node, exp.Subqueryable):
132                self._subqueries.append(node)
133
134        self._collected = True
135
136    def _ensure_collected(self):
137        if not self._collected:
138            self._collect()
139
140    def walk(self, bfs=True):
141        return walk_in_scope(self.expression, bfs=bfs)
142
143    def find(self, *expression_types, bfs=True):
144        return find_in_scope(self.expression, expression_types, bfs=bfs)
145
146    def find_all(self, *expression_types, bfs=True):
147        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
148
149    def replace(self, old, new):
150        """
151        Replace `old` with `new`.
152
153        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
154
155        Args:
156            old (exp.Expression): old node
157            new (exp.Expression): new node
158        """
159        old.replace(new)
160        self.clear_cache()
161
162    @property
163    def tables(self):
164        """
165        List of tables in this scope.
166
167        Returns:
168            list[exp.Table]: tables
169        """
170        self._ensure_collected()
171        return self._tables
172
173    @property
174    def ctes(self):
175        """
176        List of CTEs in this scope.
177
178        Returns:
179            list[exp.CTE]: ctes
180        """
181        self._ensure_collected()
182        return self._ctes
183
184    @property
185    def derived_tables(self):
186        """
187        List of derived tables in this scope.
188
189        For example:
190            SELECT * FROM (SELECT ...) <- that's a derived table
191
192        Returns:
193            list[exp.Subquery]: derived tables
194        """
195        self._ensure_collected()
196        return self._derived_tables
197
198    @property
199    def udtfs(self):
200        """
201        List of "User Defined Tabular Functions" in this scope.
202
203        Returns:
204            list[exp.UDTF]: UDTFs
205        """
206        self._ensure_collected()
207        return self._udtfs
208
209    @property
210    def subqueries(self):
211        """
212        List of subqueries in this scope.
213
214        For example:
215            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
216
217        Returns:
218            list[exp.Subqueryable]: subqueries
219        """
220        self._ensure_collected()
221        return self._subqueries
222
223    @property
224    def columns(self):
225        """
226        List of columns in this scope.
227
228        Returns:
229            list[exp.Column]: Column instances in this scope, plus any
230                Columns that reference this scope from correlated subqueries.
231        """
232        if self._columns is None:
233            self._ensure_collected()
234            columns = self._raw_columns
235
236            external_columns = [
237                column
238                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
239                for column in scope.external_columns
240            ]
241
242            named_selects = set(self.expression.named_selects)
243
244            self._columns = []
245            for column in columns + external_columns:
246                ancestor = column.find_ancestor(
247                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
248                )
249                if (
250                    not ancestor
251                    or column.table
252                    or isinstance(ancestor, exp.Select)
253                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
254                    or (
255                        isinstance(ancestor, exp.Order)
256                        and (
257                            isinstance(ancestor.parent, exp.Window)
258                            or column.name not in named_selects
259                        )
260                    )
261                ):
262                    self._columns.append(column)
263
264        return self._columns
265
266    @property
267    def selected_sources(self):
268        """
269        Mapping of nodes and sources that are actually selected from in this scope.
270
271        That is, all tables in a schema are selectable at any point. But a
272        table only becomes a selected source if it's included in a FROM or JOIN clause.
273
274        Returns:
275            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
276        """
277        if self._selected_sources is None:
278            result = {}
279
280            for name, node in self.references:
281                if name in result:
282                    raise OptimizeError(f"Alias already used: {name}")
283                if name in self.sources:
284                    result[name] = (node, self.sources[name])
285
286            self._selected_sources = result
287        return self._selected_sources
288
289    @property
290    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
291        if self._references is None:
292            self._references = []
293
294            for table in self.tables:
295                self._references.append((table.alias_or_name, table))
296            for expression in itertools.chain(self.derived_tables, self.udtfs):
297                self._references.append(
298                    (
299                        expression.alias,
300                        expression if expression.args.get("pivots") else expression.unnest(),
301                    )
302                )
303
304        return self._references
305
306    @property
307    def cte_sources(self):
308        """
309        Sources that are CTEs.
310
311        Returns:
312            dict[str, Scope]: Mapping of source alias to Scope
313        """
314        return {
315            alias: scope
316            for alias, scope in self.sources.items()
317            if isinstance(scope, Scope) and scope.is_cte
318        }
319
320    @property
321    def external_columns(self):
322        """
323        Columns that appear to reference sources in outer scopes.
324
325        Returns:
326            list[exp.Column]: Column instances that don't reference
327                sources in the current scope.
328        """
329        if self._external_columns is None:
330            self._external_columns = [
331                c for c in self.columns if c.table not in self.selected_sources
332            ]
333        return self._external_columns
334
335    @property
336    def unqualified_columns(self):
337        """
338        Unqualified columns in the current scope.
339
340        Returns:
341             list[exp.Column]: Unqualified columns
342        """
343        return [c for c in self.columns if not c.table]
344
345    @property
346    def join_hints(self):
347        """
348        Hints that exist in the scope that reference tables
349
350        Returns:
351            list[exp.JoinHint]: Join hints that are referenced within the scope
352        """
353        if self._join_hints is None:
354            return []
355        return self._join_hints
356
357    @property
358    def pivots(self):
359        if not self._pivots:
360            self._pivots = [
361                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
362            ]
363
364        return self._pivots
365
366    def source_columns(self, source_name):
367        """
368        Get all columns in the current scope for a particular source.
369
370        Args:
371            source_name (str): Name of the source
372        Returns:
373            list[exp.Column]: Column instances that reference `source_name`
374        """
375        return [column for column in self.columns if column.table == source_name]
376
377    @property
378    def is_subquery(self):
379        """Determine if this scope is a subquery"""
380        return self.scope_type == ScopeType.SUBQUERY
381
382    @property
383    def is_derived_table(self):
384        """Determine if this scope is a derived table"""
385        return self.scope_type == ScopeType.DERIVED_TABLE
386
387    @property
388    def is_union(self):
389        """Determine if this scope is a union"""
390        return self.scope_type == ScopeType.UNION
391
392    @property
393    def is_cte(self):
394        """Determine if this scope is a common table expression"""
395        return self.scope_type == ScopeType.CTE
396
397    @property
398    def is_root(self):
399        """Determine if this is the root scope"""
400        return self.scope_type == ScopeType.ROOT
401
402    @property
403    def is_udtf(self):
404        """Determine if this scope is a UDTF (User Defined Table Function)"""
405        return self.scope_type == ScopeType.UDTF
406
407    @property
408    def is_correlated_subquery(self):
409        """Determine if this scope is a correlated subquery"""
410        return bool(
411            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
412            and self.external_columns
413        )
414
415    def rename_source(self, old_name, new_name):
416        """Rename a source in this scope"""
417        columns = self.sources.pop(old_name or "", [])
418        self.sources[new_name] = columns
419
420    def add_source(self, name, source):
421        """Add a source to this scope"""
422        self.sources[name] = source
423        self.clear_cache()
424
425    def remove_source(self, name):
426        """Remove a source from this scope"""
427        self.sources.pop(name, None)
428        self.clear_cache()
429
430    def __repr__(self):
431        return f"Scope<{self.expression.sql()}>"
432
433    def traverse(self):
434        """
435        Traverse the scope tree from this node.
436
437        Yields:
438            Scope: scope instances in depth-first-search post-order
439        """
440        for child_scope in itertools.chain(
441            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
442        ):
443            yield from child_scope.traverse()
444        yield self
445
446    def ref_count(self):
447        """
448        Count the number of times each scope in this tree is referenced.
449
450        Returns:
451            dict[int, int]: Mapping of Scope instance ID to reference count
452        """
453        scope_ref_count = defaultdict(lambda: 0)
454
455        for scope in self.traverse():
456            for _, source in scope.selected_sources.values():
457                scope_ref_count[id(source)] += 1
458
459        return scope_ref_count
460
461
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances, innermost first; empty if `expression` is neither a
            query nor a DDL statement wrapping a query.
    """
    # Only queries, and DDL statements whose body is a query, have scopes.
    if not isinstance(expression, exp.Unionable) and not (
        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
    ):
        return []

    return list(_traverse_scope(Scope(expression)))
493
494
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope produced by `traverse_scope`), or None when
            `expression` has no scopes.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
508
509
def _traverse_scope(scope):
    """
    Yield every child scope of `scope` (innermost first), then `scope` itself.

    The traversal strategy is chosen by the type of `scope.expression`; checks are made in
    the order listed below. Unsupported expression types log a warning and yield nothing.
    """
    expression = scope.expression
    handlers = (
        (exp.Select, _traverse_select),
        (exp.Union, _traverse_union),
        (exp.Subquery, _traverse_subqueries),
        (exp.Table, _traverse_tables),
        (exp.UDTF, _traverse_udtfs),
        (exp.DDL, _traverse_ddl),
    )

    for expression_type, handler in handlers:
        if isinstance(expression, expression_type):
            yield from handler(scope)
            yield scope
            return

    logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
530
531
def _traverse_select(scope):
    """Yield the CTE, table and subquery child scopes of a SELECT, in that order."""
    for traversal in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traversal(scope)
536
537
def _traverse_union(scope):
    """
    Yield the CTE scopes of a union, then the scopes of its left and right operands.

    Afterwards `scope.union_scopes` holds the top-level scope of each operand.
    """
    yield from _traverse_ctes(scope)

    # Each operand's traversal yields its own scope last, so the final value seen for an
    # operand is that operand's top-most scope.
    operand_tops = []
    for side in ("left", "right"):
        top = None
        operand = getattr(scope.expression, side)
        for top in _traverse_scope(scope.branch(operand, scope_type=ScopeType.UNION)):
            yield top
        operand_tops.append(top)

    scope.union_scopes = operand_tops
551
552
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE of `scope`, registering each CTE's top scope as a source.

    Each CTE is branched with the CTEs defined before it available as sources. When the
    WITH clause is recursive and the CTE body is a union, the union's first operand gets
    its own scope, which is added to the CTE's scopes under the CTE alias.
    """
    cte_sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # A recursive CTE must be of the form `base_case UNION recursive`; the scope for
        # the first operand of that union is the one exposed under the CTE's alias.
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive and isinstance(cte.this, exp.Union):
            recursive_scope = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        last_scope = None
        cte_scope = scope.branch(
            cte.this,
            chain_sources=cte_sources,
            outer_column_list=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )

        for last_scope in _traverse_scope(cte_scope):
            yield last_scope

            cte_sources[cte.alias] = last_scope

            if recursive_scope:
                last_scope.add_source(cte.alias, recursive_scope)

        # The last scope yielded for this CTE is its top-level scope.
        if last_scope:
            scope.cte_scopes.append(last_scope)

    scope.sources.update(cte_sources)
591
592
def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    Tell whether a Subquery node introduces a new (derived table) scope.

    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule; the other is a Subquery
    whose body is itself a query.
    """
    if expression.alias:
        return True
    return isinstance(expression.this, exp.Subqueryable)
600
601
def _traverse_tables(scope):
    """
    Yield the scopes of derived tables and UDTFs found in FROM, JOINs and LATERALs of `scope`.

    Every source encountered (plain tables, references to existing sources such as CTEs,
    derived tables and UDTFs) is registered into `scope.sources` once traversal finishes.
    Derived-table and UDTF scopes are also appended to `scope.table_scopes` and to
    `scope.derived_table_scopes` / `scope.udtf_scopes` respectively.
    """
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` grows while it is being iterated, so nested joins are visited too.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                # Name already taken in this scope; register under a fresh name from
                # find_new_name so the earlier source isn't overwritten.
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per expression: otherwise a traversal that yields nothing would re-register
        # the previous iteration's scope (or raise NameError on the first iteration).
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            sources[expression.alias] = child_scope

        # append the final child_scope yielded, if any
        if child_scope is not None:
            scopes.append(child_scope)
            scope.table_scopes.append(child_scope)

    scope.sources.update(sources)
682
683
def _traverse_subqueries(scope):
    """
    Yield the scopes of every subquery in `scope` and record each subquery's top scope.

    The top (last-yielded) scope of each subquery is appended to `scope.subquery_scopes`.
    """
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        # Skip subqueries that produced no scope: a None entry would break Scope.traverse,
        # which recurses into every registered child scope (mirrors _traverse_ctes).
        if top is not None:
            scope.subquery_scopes.append(top)
691
692
def _traverse_udtfs(scope):
    """
    Yield the scopes of derived tables nested directly inside a UDTF scope.

    For UNNEST the candidates are its expressions; for LATERAL it is the lateral's `this`;
    any other UDTF has none. Each candidate that is a derived-table Subquery becomes a
    DERIVED_TABLE child scope, registered in `scope.sources` under its alias.
    """
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if not (isinstance(expression, exp.Subquery) and _is_derived_table(expression)):
            continue

        top = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                scope_type=ScopeType.DERIVED_TABLE,
                outer_column_list=expression.alias_column_names,
            )
        ):
            yield child_scope
            top = child_scope
            sources[expression.alias] = child_scope

        # Only register a scope that traversal actually produced; a None entry would
        # break Scope.traverse (mirrors the guard in _traverse_ctes).
        if top is not None:
            scope.derived_table_scopes.append(top)
            scope.table_scopes.append(top)

    scope.sources.update(sources)
720
721
def _traverse_ddl(scope):
    """
    Yield the CTE scopes of a DDL statement, then the scopes of the query it wraps.

    The wrapped query is branched as a DERIVED_TABLE scope that sees all of the DDL
    scope's sources, and the DDL's own CTEs are prepended to the query scope's CTE list.
    """
    yield from _traverse_ctes(scope)

    inner_query = scope.expression.expression
    query_scope = scope.branch(
        inner_query,
        chain_sources=scope.sources,
        scope_type=ScopeType.DERIVED_TABLE,
    )
    # Collect eagerly so the DDL-level CTEs can be put in front of the query's own CTEs.
    query_scope._collect()
    query_scope._ctes = [*scope.ctes, *query_scope._ctes]

    yield from _traverse_scope(query_scope)
732
733
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope (a CTE, a derived-table Subquery under FROM/JOIN/
    Subquery, a UDTF, or any Subqueryable) is itself yielded, but its subtree is skipped.
    The joins, laterals and pivots attached to such a Subquery/UDTF are still walked, as
    they belong to the current scope.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # The walker calls the `prune` lambda after each yielded node; flipping this flag to
    # True before asking for the next node makes it skip that node's subtree.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue

        starts_child_scope = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
            isinstance(node, exp.Subquery)
            and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
            and _is_derived_table(node)
        )
        if not starts_child_scope:
            continue

        prune = True

        if isinstance(node, (exp.Subquery, exp.UDTF)):
            # The following args are not actually in the inner scope, so we should visit them
            for arg_key in ("joins", "laterals", "pivots"):
                for arg in node.args.get(arg_key) or []:
                    yield from walk_in_scope(arg, bfs=bfs)
775
776
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize the requested types once rather than for every visited node.
    types = tuple(ensure_collection(expression_types))
    for node, *_ in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node
795
796
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
        the criteria was found.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
logger = <Logger sqlglot (WARNING)>
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """Type of a scope, relative to its parent scope."""

    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6

An enumeration.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 24class Scope:
 25    """
 26    Selection scope.
 27
 28    Attributes:
 29        expression (exp.Select|exp.Union): Root expression of this scope
 30        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 31            a Table expression or another Scope instance. For example:
 32                SELECT * FROM x                     {"x": Table(this="x")}
 33                SELECT * FROM x AS y                {"y": Table(this="x")}
 34                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 35        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 36            For example:
 37                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 38            The LATERAL VIEW EXPLODE gets x as a source.
 39        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 40            defines a column list of it's alias of this scope, this is that list of columns.
 41            For example:
 42                SELECT * FROM (SELECT ...) AS y(col1, col2)
 43            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 44        parent (Scope): Parent scope
 45        scope_type (ScopeType): Type of this scope, relative to it's parent
 46        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 47        cte_scopes (list[Scope]): List of all child scopes for CTEs
 48        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 49        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 50        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 51        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 52            a list of the left and right child scopes.
 53    """
 54
 55    def __init__(
 56        self,
 57        expression,
 58        sources=None,
 59        outer_column_list=None,
 60        parent=None,
 61        scope_type=ScopeType.ROOT,
 62        lateral_sources=None,
 63    ):
 64        self.expression = expression
 65        self.sources = sources or {}
 66        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 67        self.sources.update(self.lateral_sources)
 68        self.outer_column_list = outer_column_list or []
 69        self.parent = parent
 70        self.scope_type = scope_type
 71        self.subquery_scopes = []
 72        self.derived_table_scopes = []
 73        self.table_scopes = []
 74        self.cte_scopes = []
 75        self.union_scopes = []
 76        self.udtf_scopes = []
 77        self.clear_cache()
 78
 79    def clear_cache(self):
 80        self._collected = False
 81        self._raw_columns = None
 82        self._derived_tables = None
 83        self._udtfs = None
 84        self._tables = None
 85        self._ctes = None
 86        self._subqueries = None
 87        self._selected_sources = None
 88        self._columns = None
 89        self._external_columns = None
 90        self._join_hints = None
 91        self._pivots = None
 92        self._references = None
 93
 94    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 95        """Branch from the current scope to a new, inner scope"""
 96        return Scope(
 97            expression=expression.unnest(),
 98            sources={**self.cte_sources, **(chain_sources or {})},
 99            parent=self,
100            scope_type=scope_type,
101            **kwargs,
102        )
103
104    def _collect(self):
105        self._tables = []
106        self._ctes = []
107        self._subqueries = []
108        self._derived_tables = []
109        self._udtfs = []
110        self._raw_columns = []
111        self._join_hints = []
112
113        for node, parent, _ in self.walk(bfs=False):
114            if node is self.expression:
115                continue
116            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
117                self._raw_columns.append(node)
118            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
119                self._tables.append(node)
120            elif isinstance(node, exp.JoinHint):
121                self._join_hints.append(node)
122            elif isinstance(node, exp.UDTF):
123                self._udtfs.append(node)
124            elif isinstance(node, exp.CTE):
125                self._ctes.append(node)
126            elif (
127                isinstance(node, exp.Subquery)
128                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
129                and _is_derived_table(node)
130            ):
131                self._derived_tables.append(node)
132            elif isinstance(node, exp.Subqueryable):
133                self._subqueries.append(node)
134
135        self._collected = True
136
137    def _ensure_collected(self):
138        if not self._collected:
139            self._collect()
140
141    def walk(self, bfs=True):
142        return walk_in_scope(self.expression, bfs=bfs)
143
144    def find(self, *expression_types, bfs=True):
145        return find_in_scope(self.expression, expression_types, bfs=bfs)
146
147    def find_all(self, *expression_types, bfs=True):
148        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
149
150    def replace(self, old, new):
151        """
152        Replace `old` with `new`.
153
154        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
155
156        Args:
157            old (exp.Expression): old node
158            new (exp.Expression): new node
159        """
160        old.replace(new)
161        self.clear_cache()
162
163    @property
164    def tables(self):
165        """
166        List of tables in this scope.
167
168        Returns:
169            list[exp.Table]: tables
170        """
171        self._ensure_collected()
172        return self._tables
173
174    @property
175    def ctes(self):
176        """
177        List of CTEs in this scope.
178
179        Returns:
180            list[exp.CTE]: ctes
181        """
182        self._ensure_collected()
183        return self._ctes
184
185    @property
186    def derived_tables(self):
187        """
188        List of derived tables in this scope.
189
190        For example:
191            SELECT * FROM (SELECT ...) <- that's a derived table
192
193        Returns:
194            list[exp.Subquery]: derived tables
195        """
196        self._ensure_collected()
197        return self._derived_tables
198
199    @property
200    def udtfs(self):
201        """
202        List of "User Defined Tabular Functions" in this scope.
203
204        Returns:
205            list[exp.UDTF]: UDTFs
206        """
207        self._ensure_collected()
208        return self._udtfs
209
210    @property
211    def subqueries(self):
212        """
213        List of subqueries in this scope.
214
215        For example:
216            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
217
218        Returns:
219            list[exp.Subqueryable]: subqueries
220        """
221        self._ensure_collected()
222        return self._subqueries
223
224    @property
225    def columns(self):
226        """
227        List of columns in this scope.
228
229        Returns:
230            list[exp.Column]: Column instances in this scope, plus any
231                Columns that reference this scope from correlated subqueries.
232        """
233        if self._columns is None:
234            self._ensure_collected()
235            columns = self._raw_columns
236
237            external_columns = [
238                column
239                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
240                for column in scope.external_columns
241            ]
242
243            named_selects = set(self.expression.named_selects)
244
245            self._columns = []
246            for column in columns + external_columns:
247                ancestor = column.find_ancestor(
248                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
249                )
250                if (
251                    not ancestor
252                    or column.table
253                    or isinstance(ancestor, exp.Select)
254                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
255                    or (
256                        isinstance(ancestor, exp.Order)
257                        and (
258                            isinstance(ancestor.parent, exp.Window)
259                            or column.name not in named_selects
260                        )
261                    )
262                ):
263                    self._columns.append(column)
264
265        return self._columns
266
267    @property
268    def selected_sources(self):
269        """
270        Mapping of nodes and sources that are actually selected from in this scope.
271
272        That is, all tables in a schema are selectable at any point. But a
273        table only becomes a selected source if it's included in a FROM or JOIN clause.
274
275        Returns:
276            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
277        """
278        if self._selected_sources is None:
279            result = {}
280
281            for name, node in self.references:
282                if name in result:
283                    raise OptimizeError(f"Alias already used: {name}")
284                if name in self.sources:
285                    result[name] = (node, self.sources[name])
286
287            self._selected_sources = result
288        return self._selected_sources
289
290    @property
291    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
292        if self._references is None:
293            self._references = []
294
295            for table in self.tables:
296                self._references.append((table.alias_or_name, table))
297            for expression in itertools.chain(self.derived_tables, self.udtfs):
298                self._references.append(
299                    (
300                        expression.alias,
301                        expression if expression.args.get("pivots") else expression.unnest(),
302                    )
303                )
304
305        return self._references
306
307    @property
308    def cte_sources(self):
309        """
310        Sources that are CTEs.
311
312        Returns:
313            dict[str, Scope]: Mapping of source alias to Scope
314        """
315        return {
316            alias: scope
317            for alias, scope in self.sources.items()
318            if isinstance(scope, Scope) and scope.is_cte
319        }
320
321    @property
322    def external_columns(self):
323        """
324        Columns that appear to reference sources in outer scopes.
325
326        Returns:
327            list[exp.Column]: Column instances that don't reference
328                sources in the current scope.
329        """
330        if self._external_columns is None:
331            self._external_columns = [
332                c for c in self.columns if c.table not in self.selected_sources
333            ]
334        return self._external_columns
335
336    @property
337    def unqualified_columns(self):
338        """
339        Unqualified columns in the current scope.
340
341        Returns:
342             list[exp.Column]: Unqualified columns
343        """
344        return [c for c in self.columns if not c.table]
345
346    @property
347    def join_hints(self):
348        """
349        Hints that exist in the scope that reference tables
350
351        Returns:
352            list[exp.JoinHint]: Join hints that are referenced within the scope
353        """
354        if self._join_hints is None:
355            return []
356        return self._join_hints
357
358    @property
359    def pivots(self):
360        if not self._pivots:
361            self._pivots = [
362                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
363            ]
364
365        return self._pivots
366
367    def source_columns(self, source_name):
368        """
369        Get all columns in the current scope for a particular source.
370
371        Args:
372            source_name (str): Name of the source
373        Returns:
374            list[exp.Column]: Column instances that reference `source_name`
375        """
376        return [column for column in self.columns if column.table == source_name]
377
378    @property
379    def is_subquery(self):
380        """Determine if this scope is a subquery"""
381        return self.scope_type == ScopeType.SUBQUERY
382
383    @property
384    def is_derived_table(self):
385        """Determine if this scope is a derived table"""
386        return self.scope_type == ScopeType.DERIVED_TABLE
387
388    @property
389    def is_union(self):
390        """Determine if this scope is a union"""
391        return self.scope_type == ScopeType.UNION
392
393    @property
394    def is_cte(self):
395        """Determine if this scope is a common table expression"""
396        return self.scope_type == ScopeType.CTE
397
398    @property
399    def is_root(self):
400        """Determine if this is the root scope"""
401        return self.scope_type == ScopeType.ROOT
402
403    @property
404    def is_udtf(self):
405        """Determine if this scope is a UDTF (User Defined Table Function)"""
406        return self.scope_type == ScopeType.UDTF
407
408    @property
409    def is_correlated_subquery(self):
410        """Determine if this scope is a correlated subquery"""
411        return bool(
412            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
413            and self.external_columns
414        )
415
416    def rename_source(self, old_name, new_name):
417        """Rename a source in this scope"""
418        columns = self.sources.pop(old_name or "", [])
419        self.sources[new_name] = columns
420
421    def add_source(self, name, source):
422        """Add a source to this scope"""
423        self.sources[name] = source
424        self.clear_cache()
425
426    def remove_source(self, name):
427        """Remove a source from this scope"""
428        self.sources.pop(name, None)
429        self.clear_cache()
430
431    def __repr__(self):
432        return f"Scope<{self.expression.sql()}>"
433
434    def traverse(self):
435        """
436        Traverse the scope tree from this node.
437
438        Yields:
439            Scope: scope instances in depth-first-search post-order
440        """
441        for child_scope in itertools.chain(
442            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
443        ):
444            yield from child_scope.traverse()
445        yield self
446
447    def ref_count(self):
448        """
449        Count the number of times each scope in this tree is referenced.
450
451        Returns:
452            dict[int, int]: Mapping of Scope instance ID to reference count
453        """
454        scope_ref_count = defaultdict(lambda: 0)
455
456        for scope in self.traverse():
457            for _, source in scope.selected_sources.values():
458                scope_ref_count[id(source)] += 1
459
460        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x → {"x": Table(this="x")}; SELECT * FROM x AS y → {"y": Table(this="x")}; SELECT * FROM (SELECT ...) AS y → {"y": Scope(...)}.
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c, the LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for this scope's alias, this is that list of columns. For example, in SELECT * FROM (SELECT ...) AS y(col1, col2), the inner query would have ["col1", "col2"] for its outer_column_list.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
55    def __init__(
56        self,
57        expression,
58        sources=None,
59        outer_column_list=None,
60        parent=None,
61        scope_type=ScopeType.ROOT,
62        lateral_sources=None,
63    ):
64        self.expression = expression
65        self.sources = sources or {}
66        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
67        self.sources.update(self.lateral_sources)
68        self.outer_column_list = outer_column_list or []
69        self.parent = parent
70        self.scope_type = scope_type
71        self.subquery_scopes = []
72        self.derived_table_scopes = []
73        self.table_scopes = []
74        self.cte_scopes = []
75        self.union_scopes = []
76        self.udtf_scopes = []
77        self.clear_cache()
expression
sources
lateral_sources
outer_column_list
parent
scope_type
subquery_scopes
derived_table_scopes
table_scopes
cte_scopes
union_scopes
udtf_scopes
def clear_cache(self):
79    def clear_cache(self):
80        self._collected = False
81        self._raw_columns = None
82        self._derived_tables = None
83        self._udtfs = None
84        self._tables = None
85        self._ctes = None
86        self._subqueries = None
87        self._selected_sources = None
88        self._columns = None
89        self._external_columns = None
90        self._join_hints = None
91        self._pivots = None
92        self._references = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 94    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 95        """Branch from the current scope to a new, inner scope"""
 96        return Scope(
 97            expression=expression.unnest(),
 98            sources={**self.cte_sources, **(chain_sources or {})},
 99            parent=self,
100            scope_type=scope_type,
101            **kwargs,
102        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True):
141    def walk(self, bfs=True):
142        return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
144    def find(self, *expression_types, bfs=True):
145        return find_in_scope(self.expression, expression_types, bfs=bfs)
def find_all(self, *expression_types, bfs=True):
147    def find_all(self, *expression_types, bfs=True):
148        return find_all_in_scope(self.expression, expression_types, bfs=bfs)
def replace(self, old, new):
150    def replace(self, old, new):
151        """
152        Replace `old` with `new`.
153
154        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
155
156        Args:
157            old (exp.Expression): old node
158            new (exp.Expression): new node
159        """
160        old.replace(new)
161        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

references: List[Tuple[str, sqlglot.expressions.Expression]]
cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Join hints within this scope that reference tables.

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

pivots
def source_columns(self, source_name):
367    def source_columns(self, source_name):
368        """
369        Get all columns in the current scope for a particular source.
370
371        Args:
372            source_name (str): Name of the source
373        Returns:
374            list[exp.Column]: Column instances that reference `source_name`
375        """
376        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
416    def rename_source(self, old_name, new_name):
417        """Rename a source in this scope"""
418        columns = self.sources.pop(old_name or "", [])
419        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
421    def add_source(self, name, source):
422        """Add a source to this scope"""
423        self.sources[name] = source
424        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
426    def remove_source(self, name):
427        """Remove a source from this scope"""
428        self.sources.pop(name, None)
429        self.clear_cache()

Remove a source from this scope

def traverse(self):
434    def traverse(self):
435        """
436        Traverse the scope tree from this node.
437
438        Yields:
439            Scope: scope instances in depth-first-search post-order
440        """
441        for child_scope in itertools.chain(
442            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
443        ):
444            yield from child_scope.traverse()
445        yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
447    def ref_count(self):
448        """
449        Count the number of times each scope in this tree is referenced.
450
451        Returns:
452            dict[int, int]: Mapping of Scope instance ID to reference count
453        """
454        scope_ref_count = defaultdict(lambda: 0)
455
456        for scope in self.traverse():
457            for _, source in scope.selected_sources.values():
458                scope_ref_count[id(source)] += 1
459
460        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[sqlglot.optimizer.scope.Scope]:
463def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
464    """
465    Traverse an expression by its "scopes".
466
467    "Scope" represents the current context of a Select statement.
468
469    This is helpful for optimizing queries, where we need more information than
470    the expression tree itself. For example, we might care about the source
471    names within a subquery. Returns a list because a generator could result in
472    incomplete properties which is confusing.
473
474    Examples:
475        >>> import sqlglot
476        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
477        >>> scopes = traverse_scope(expression)
478        >>> scopes[0].expression.sql(), list(scopes[0].sources)
479        ('SELECT a FROM x', ['x'])
480        >>> scopes[1].expression.sql(), list(scopes[1].sources)
481        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
482
483    Args:
484        expression (exp.Expression): expression to traverse
485    Returns:
486        list[Scope]: scope instances
487    """
488    if isinstance(expression, exp.Unionable) or (
489        isinstance(expression, exp.DDL) and isinstance(expression.expression, exp.Subqueryable)
490    ):
491        return list(_traverse_scope(Scope(expression)))
492
493    return []

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[sqlglot.optimizer.scope.Scope]:
496def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
497    """
498    Build a scope tree.
499
500    Args:
501        expression (exp.Expression): expression to build the scope tree for
502    Returns:
503        Scope: root scope
504    """
505    scopes = traverse_scope(expression)
506    if scopes:
507        return scopes[-1]
508    return None

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True):
735def walk_in_scope(expression, bfs=True):
736    """
737    Returns a generator object which visits all nodes in the syntrax tree, stopping at
738    nodes that start child scopes.
739
740    Args:
741        expression (exp.Expression):
742        bfs (bool): if set to True the BFS traversal order will be applied,
743            otherwise the DFS traversal will be used instead.
744
745    Yields:
746        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
747    """
748    # We'll use this variable to pass state into the dfs generator.
749    # Whenever we set it to True, we exclude a subtree from traversal.
750    prune = False
751
752    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
753        prune = False
754
755        yield node, parent, key
756
757        if node is expression:
758            continue
759        if (
760            isinstance(node, exp.CTE)
761            or (
762                isinstance(node, exp.Subquery)
763                and isinstance(parent, (exp.From, exp.Join, exp.Subquery))
764                and _is_derived_table(node)
765            )
766            or isinstance(node, exp.UDTF)
767            or isinstance(node, exp.Subqueryable)
768        ):
769            prune = True
770
771            if isinstance(node, (exp.Subquery, exp.UDTF)):
772                # The following args are not actually in the inner scope, so we should visit them
773                for key in ("joins", "laterals", "pivots"):
774                    for arg in node.args.get(key) or []:
775                        yield from walk_in_scope(arg, bfs=bfs)

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
  • expression (exp.Expression):
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key

def find_all_in_scope(expression, expression_types, bfs=True):
778def find_all_in_scope(expression, expression_types, bfs=True):
779    """
780    Returns a generator object which visits all nodes in this scope and only yields those that
781    match at least one of the specified expression types.
782
783    This does NOT traverse into subscopes.
784
785    Args:
786        expression (exp.Expression):
787        expression_types (tuple[type]|type): the expression type(s) to match.
788        bfs (bool): True to use breadth-first search, False to use depth-first.
789
790    Yields:
791        exp.Expression: nodes
792    """
793    for expression, *_ in walk_in_scope(expression, bfs=bfs):
794        if isinstance(expression, tuple(ensure_collection(expression_types))):
795            yield expression

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def find_in_scope(expression, expression_types, bfs=True):
798def find_in_scope(expression, expression_types, bfs=True):
799    """
800    Returns the first node in this scope which matches at least one of the specified types.
801
802    This does NOT traverse into subscopes.
803
804    Args:
805        expression (exp.Expression):
806        expression_types (tuple[type]|type): the expression type(s) to match.
807        bfs (bool): True to use breadth-first search, False to use depth-first.
808
809    Returns:
810        exp.Expression: the node which matches the criteria or None if no node matching
811        the criteria was found.
812    """
813    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression (exp.Expression):
  • expression_types (tuple[type]|type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.