Edit on GitHub

sqlglot.optimizer.scope

  1import itertools
  2import typing as t
  3from collections import defaultdict
  4from enum import Enum, auto
  5
  6from sqlglot import exp
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import find_new_name
  9
 10
 11class ScopeType(Enum):
 12    ROOT = auto()
 13    SUBQUERY = auto()
 14    DERIVED_TABLE = auto()
 15    CTE = auto()
 16    UNION = auto()
 17    UDTF = auto()
 18
 19
 20class Scope:
 21    """
 22    Selection scope.
 23
 24    Attributes:
 25        expression (exp.Select|exp.Union): Root expression of this scope
 26        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 27            a Table expression or another Scope instance. For example:
 28                SELECT * FROM x                     {"x": Table(this="x")}
 29                SELECT * FROM x AS y                {"y": Table(this="x")}
 30                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 31        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 32            For example:
 33                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 34            The LATERAL VIEW EXPLODE gets x as a source.
 35        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 36            defines a column list of it's alias of this scope, this is that list of columns.
 37            For example:
 38                SELECT * FROM (SELECT ...) AS y(col1, col2)
 39            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 40        parent (Scope): Parent scope
 41        scope_type (ScopeType): Type of this scope, relative to it's parent
 42        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 43        cte_scopes (list[Scope]): List of all child scopes for CTEs
 44        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 45        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 46        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 47        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 48            a list of the left and right child scopes.
 49    """
 50
 51    def __init__(
 52        self,
 53        expression,
 54        sources=None,
 55        outer_column_list=None,
 56        parent=None,
 57        scope_type=ScopeType.ROOT,
 58        lateral_sources=None,
 59    ):
 60        self.expression = expression
 61        self.sources = sources or {}
 62        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 63        self.sources.update(self.lateral_sources)
 64        self.outer_column_list = outer_column_list or []
 65        self.parent = parent
 66        self.scope_type = scope_type
 67        self.subquery_scopes = []
 68        self.derived_table_scopes = []
 69        self.table_scopes = []
 70        self.cte_scopes = []
 71        self.union_scopes = []
 72        self.udtf_scopes = []
 73        self.clear_cache()
 74
 75    def clear_cache(self):
 76        self._collected = False
 77        self._raw_columns = None
 78        self._derived_tables = None
 79        self._udtfs = None
 80        self._tables = None
 81        self._ctes = None
 82        self._subqueries = None
 83        self._selected_sources = None
 84        self._columns = None
 85        self._external_columns = None
 86        self._join_hints = None
 87        self._pivots = None
 88
 89    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 90        """Branch from the current scope to a new, inner scope"""
 91        return Scope(
 92            expression=expression.unnest(),
 93            sources={**self.cte_sources, **(chain_sources or {})},
 94            parent=self,
 95            scope_type=scope_type,
 96            **kwargs,
 97        )
 98
 99    def _collect(self):
100        self._tables = []
101        self._ctes = []
102        self._subqueries = []
103        self._derived_tables = []
104        self._udtfs = []
105        self._raw_columns = []
106        self._join_hints = []
107
108        for node, parent, _ in self.walk(bfs=False):
109            if node is self.expression:
110                continue
111            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
112                self._raw_columns.append(node)
113            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
114                self._tables.append(node)
115            elif isinstance(node, exp.JoinHint):
116                self._join_hints.append(node)
117            elif isinstance(node, exp.UDTF):
118                self._udtfs.append(node)
119            elif isinstance(node, exp.CTE):
120                self._ctes.append(node)
121            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
122                self._derived_tables.append(node)
123            elif isinstance(node, exp.Subqueryable):
124                self._subqueries.append(node)
125
126        self._collected = True
127
128    def _ensure_collected(self):
129        if not self._collected:
130            self._collect()
131
132    def walk(self, bfs=True):
133        return walk_in_scope(self.expression, bfs=bfs)
134
135    def find(self, *expression_types, bfs=True):
136        """
137        Returns the first node in this scope which matches at least one of the specified types.
138
139        This does NOT traverse into subscopes.
140
141        Args:
142            expression_types (type): the expression type(s) to match.
143            bfs (bool): True to use breadth-first search, False to use depth-first.
144
145        Returns:
146            exp.Expression: the node which matches the criteria or None if no node matching
147            the criteria was found.
148        """
149        return next(self.find_all(*expression_types, bfs=bfs), None)
150
151    def find_all(self, *expression_types, bfs=True):
152        """
153        Returns a generator object which visits all nodes in this scope and only yields those that
154        match at least one of the specified expression types.
155
156        This does NOT traverse into subscopes.
157
158        Args:
159            expression_types (type): the expression type(s) to match.
160            bfs (bool): True to use breadth-first search, False to use depth-first.
161
162        Yields:
163            exp.Expression: nodes
164        """
165        for expression, *_ in self.walk(bfs=bfs):
166            if isinstance(expression, expression_types):
167                yield expression
168
169    def replace(self, old, new):
170        """
171        Replace `old` with `new`.
172
173        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
174
175        Args:
176            old (exp.Expression): old node
177            new (exp.Expression): new node
178        """
179        old.replace(new)
180        self.clear_cache()
181
182    @property
183    def tables(self):
184        """
185        List of tables in this scope.
186
187        Returns:
188            list[exp.Table]: tables
189        """
190        self._ensure_collected()
191        return self._tables
192
193    @property
194    def ctes(self):
195        """
196        List of CTEs in this scope.
197
198        Returns:
199            list[exp.CTE]: ctes
200        """
201        self._ensure_collected()
202        return self._ctes
203
204    @property
205    def derived_tables(self):
206        """
207        List of derived tables in this scope.
208
209        For example:
210            SELECT * FROM (SELECT ...) <- that's a derived table
211
212        Returns:
213            list[exp.Subquery]: derived tables
214        """
215        self._ensure_collected()
216        return self._derived_tables
217
218    @property
219    def udtfs(self):
220        """
221        List of "User Defined Tabular Functions" in this scope.
222
223        Returns:
224            list[exp.UDTF]: UDTFs
225        """
226        self._ensure_collected()
227        return self._udtfs
228
229    @property
230    def subqueries(self):
231        """
232        List of subqueries in this scope.
233
234        For example:
235            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
236
237        Returns:
238            list[exp.Subqueryable]: subqueries
239        """
240        self._ensure_collected()
241        return self._subqueries
242
243    @property
244    def columns(self):
245        """
246        List of columns in this scope.
247
248        Returns:
249            list[exp.Column]: Column instances in this scope, plus any
250                Columns that reference this scope from correlated subqueries.
251        """
252        if self._columns is None:
253            self._ensure_collected()
254            columns = self._raw_columns
255
256            external_columns = [
257                column
258                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
259                for column in scope.external_columns
260            ]
261
262            named_selects = set(self.expression.named_selects)
263
264            self._columns = []
265            for column in columns + external_columns:
266                ancestor = column.find_ancestor(
267                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
268                )
269                if (
270                    not ancestor
271                    or column.table
272                    or isinstance(ancestor, exp.Select)
273                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
274                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
275                ):
276                    self._columns.append(column)
277
278        return self._columns
279
280    @property
281    def selected_sources(self):
282        """
283        Mapping of nodes and sources that are actually selected from in this scope.
284
285        That is, all tables in a schema are selectable at any point. But a
286        table only becomes a selected source if it's included in a FROM or JOIN clause.
287
288        Returns:
289            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
290        """
291        if self._selected_sources is None:
292            referenced_names = []
293
294            for table in self.tables:
295                referenced_names.append((table.alias_or_name, table))
296            for expression in itertools.chain(self.derived_tables, self.udtfs):
297                referenced_names.append((expression.alias, expression.unnest()))
298            result = {}
299
300            for name, node in referenced_names:
301                if name in result:
302                    raise OptimizeError(f"Alias already used: {name}")
303                if name in self.sources:
304                    result[name] = (node, self.sources[name])
305
306            self._selected_sources = result
307        return self._selected_sources
308
309    @property
310    def cte_sources(self):
311        """
312        Sources that are CTEs.
313
314        Returns:
315            dict[str, Scope]: Mapping of source alias to Scope
316        """
317        return {
318            alias: scope
319            for alias, scope in self.sources.items()
320            if isinstance(scope, Scope) and scope.is_cte
321        }
322
323    @property
324    def selects(self):
325        """
326        Select expressions of this scope.
327
328        For example, for the following expression:
329            SELECT 1 as a, 2 as b FROM x
330
331        The outputs are the "1 as a" and "2 as b" expressions.
332
333        Returns:
334            list[exp.Expression]: expressions
335        """
336        if isinstance(self.expression, exp.Union):
337            return self.expression.unnest().selects
338        return self.expression.selects
339
340    @property
341    def external_columns(self):
342        """
343        Columns that appear to reference sources in outer scopes.
344
345        Returns:
346            list[exp.Column]: Column instances that don't reference
347                sources in the current scope.
348        """
349        if self._external_columns is None:
350            self._external_columns = [
351                c for c in self.columns if c.table not in self.selected_sources
352            ]
353        return self._external_columns
354
355    @property
356    def unqualified_columns(self):
357        """
358        Unqualified columns in the current scope.
359
360        Returns:
361             list[exp.Column]: Unqualified columns
362        """
363        return [c for c in self.columns if not c.table]
364
365    @property
366    def join_hints(self):
367        """
368        Hints that exist in the scope that reference tables
369
370        Returns:
371            list[exp.JoinHint]: Join hints that are referenced within the scope
372        """
373        if self._join_hints is None:
374            return []
375        return self._join_hints
376
377    @property
378    def pivots(self):
379        if not self._pivots:
380            self._pivots = [
381                pivot
382                for node in self.tables + self.derived_tables
383                for pivot in node.args.get("pivots") or []
384            ]
385
386        return self._pivots
387
388    def source_columns(self, source_name):
389        """
390        Get all columns in the current scope for a particular source.
391
392        Args:
393            source_name (str): Name of the source
394        Returns:
395            list[exp.Column]: Column instances that reference `source_name`
396        """
397        return [column for column in self.columns if column.table == source_name]
398
399    @property
400    def is_subquery(self):
401        """Determine if this scope is a subquery"""
402        return self.scope_type == ScopeType.SUBQUERY
403
404    @property
405    def is_derived_table(self):
406        """Determine if this scope is a derived table"""
407        return self.scope_type == ScopeType.DERIVED_TABLE
408
409    @property
410    def is_union(self):
411        """Determine if this scope is a union"""
412        return self.scope_type == ScopeType.UNION
413
414    @property
415    def is_cte(self):
416        """Determine if this scope is a common table expression"""
417        return self.scope_type == ScopeType.CTE
418
419    @property
420    def is_root(self):
421        """Determine if this is the root scope"""
422        return self.scope_type == ScopeType.ROOT
423
424    @property
425    def is_udtf(self):
426        """Determine if this scope is a UDTF (User Defined Table Function)"""
427        return self.scope_type == ScopeType.UDTF
428
429    @property
430    def is_correlated_subquery(self):
431        """Determine if this scope is a correlated subquery"""
432        return bool(self.is_subquery and self.external_columns)
433
434    def rename_source(self, old_name, new_name):
435        """Rename a source in this scope"""
436        columns = self.sources.pop(old_name or "", [])
437        self.sources[new_name] = columns
438
439    def add_source(self, name, source):
440        """Add a source to this scope"""
441        self.sources[name] = source
442        self.clear_cache()
443
444    def remove_source(self, name):
445        """Remove a source from this scope"""
446        self.sources.pop(name, None)
447        self.clear_cache()
448
449    def __repr__(self):
450        return f"Scope<{self.expression.sql()}>"
451
452    def traverse(self):
453        """
454        Traverse the scope tree from this node.
455
456        Yields:
457            Scope: scope instances in depth-first-search post-order
458        """
459        for child_scope in itertools.chain(
460            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
461        ):
462            yield from child_scope.traverse()
463        yield self
464
465    def ref_count(self):
466        """
467        Count the number of times each scope in this tree is referenced.
468
469        Returns:
470            dict[int, int]: Mapping of Scope instance ID to reference count
471        """
472        scope_ref_count = defaultdict(lambda: 0)
473
474        for scope in self.traverse():
475            for _, source in scope.selected_sources.values():
476                scope_ref_count[id(source)] += 1
477
478        return scope_ref_count
479
480
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    A "Scope" represents the current context of a Select statement. Scopes are
    useful when optimizing queries, because they carry more information than the
    bare expression tree, such as the source names visible inside a subquery.
    The result is a list rather than a generator, because a partially consumed
    generator would leave scopes with incomplete properties.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances in depth-first post-order (the root scope is last),
            or an empty list when `expression` is not Unionable.
    """
    if isinstance(expression, exp.Unionable):
        return list(_traverse_scope(Scope(expression)))
    return []
509
510
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, i.e. the last scope produced by `traverse_scope`, or None
            when no scopes were produced.
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1] if all_scopes else None
524
525
def _traverse_scope(scope):
    """
    Yield every scope nested in `scope` in depth-first post-order, then `scope` itself.

    Raises:
        OptimizeError: if the scope's expression is not a Select, Union, Subquery,
            Table or UDTF.
    """
    node = scope.expression

    if isinstance(node, exp.Select):
        nested = _traverse_select(scope)
    elif isinstance(node, exp.Union):
        nested = _traverse_union(scope)
    elif isinstance(node, exp.Subquery):
        nested = _traverse_subqueries(scope)
    elif isinstance(node, exp.Table):
        # A Table here is a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        nested = _traverse_tables(scope)
    elif isinstance(node, exp.UDTF):
        # Nothing nested to traverse for a UDTF scope itself.
        nested = ()
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")

    yield from nested
    yield scope
541
542
def _traverse_select(scope):
    """Yield nested scopes of a SELECT: CTEs, then FROM/JOIN/LATERAL sources, then subqueries."""
    for traverse_step in (_traverse_ctes, _traverse_tables, _traverse_subqueries):
        yield from traverse_step(scope)
547
548
def _traverse_union(scope):
    """
    Yield nested scopes of a UNION: its CTEs, then the left operand, then the right one.

    Each operand is traversed in post-order, so the last scope yielded for an operand is
    that operand's own top-level scope. Both are stored in `scope.union_scopes` once the
    traversal finishes.
    """
    yield from _traverse_ctes(scope)

    operand_scopes = []
    for side in ("left", "right"):
        # Resolve each operand only when its turn comes, i.e. after the previous
        # operand has been fully traversed.
        operand = getattr(scope.expression, side)
        operand_top = None
        for operand_top in _traverse_scope(scope.branch(operand, scope_type=ScopeType.UNION)):
            yield operand_top
        operand_scopes.append(operand_top)

    scope.union_scopes = operand_scopes
562
563
def _traverse_ctes(scope):
    """
    Yield the scopes of every CTE defined on `scope`, in definition order.

    Each CTE is branched with the CTEs defined before it as extra sources, so later CTEs
    can reference earlier ones. When traversal finishes, every CTE is registered in
    `scope.sources` under its alias.
    """
    defined_ctes = {}

    for cte in scope.ctes:
        # WITH RECURSIVE: the CTE body must have the form `base_case UNION recursive_term`.
        # A separate scope for the base case (the union's left side) is registered under
        # the CTE's own name in each scope yielded for this CTE, so the self-reference
        # resolves to it.
        base_case_scope = None
        if scope.expression.args["with"].recursive and isinstance(cte.this, exp.Union):
            base_case_scope = scope.branch(cte.this.this, scope_type=ScopeType.CTE)

        cte_root = scope.branch(
            cte.this,
            chain_sources=defined_ctes,
            outer_column_list=cte.alias_column_names,
            scope_type=ScopeType.CTE,
        )
        for nested in _traverse_scope(cte_root):
            yield nested

            cte_name = cte.alias
            defined_ctes[cte_name] = nested

            if base_case_scope:
                nested.add_source(cte_name, base_case_scope)

        # Post-order traversal: the last scope yielded is the CTE's own top-level scope.
        scope.cte_scopes.append(nested)

    scope.sources.update(defined_ctes)
599
600
def _traverse_tables(scope):
    """
    Yield scopes for the derived tables and UDTFs among `scope`'s FROM, JOIN and
    LATERAL sources, then register every source (plain tables included) in
    `scope.sources`.
    """
    found_sources = {}
    root = scope.expression

    # FROMs, JOINs, and LATERALs, kept in the order they are defined.
    candidates = []
    from_clause = root.args.get("from")
    if from_clause:
        candidates.append(from_clause.this)
    candidates.extend(join.this for join in root.args.get("joins") or [])
    if isinstance(root, exp.Table):
        candidates.append(root)
    candidates.extend(root.args.get("laterals") or [])

    for candidate in candidates:
        if isinstance(candidate, exp.Table):
            bare_name = candidate.name
            visible_name = candidate.alias_or_name

            if bare_name in scope.sources and not candidate.db:
                # An unqualified name already known to the scope (e.g. a CTE) refers to
                # that source rather than a physical table, unless it is pivoted: a pivot
                # yields a new table, and therefore a new source.
                pivots = candidate.args.get("pivots")
                if pivots:
                    found_sources[pivots[0].alias] = candidate
                else:
                    found_sources[visible_name] = scope.sources[bare_name]
            elif visible_name in found_sources:
                # Name collision: pick a fresh name derived from the table name.
                found_sources[find_new_name(found_sources, bare_name)] = candidate
            else:
                found_sources[visible_name] = candidate
            continue

        is_udtf = isinstance(candidate, exp.UDTF)
        owner_list = scope.udtf_scopes if is_udtf else scope.derived_table_scopes

        child_root = scope.branch(
            candidate,
            # UDTFs (e.g. laterals) can see the sources collected before them.
            lateral_sources=found_sources if is_udtf else None,
            outer_column_list=candidate.alias_column_names,
            scope_type=ScopeType.UDTF if is_udtf else ScopeType.DERIVED_TABLE,
        )
        for nested in _traverse_scope(child_root):
            yield nested

            # Unaliased tables are keyed by "" until qualify_columns adds aliases, so
            # only one unaliased derived table is tracked (the latest one wins).
            found_sources[candidate.alias] = nested

        # Post-order traversal: the last scope yielded is the source's own scope.
        owner_list.append(nested)
        scope.table_scopes.append(nested)

    scope.sources.update(found_sources)
668
669
def _traverse_subqueries(scope):
    """Yield the scopes of each subquery of `scope`, recording each subquery's own scope."""
    for subquery in scope.subqueries:
        subquery_root = None
        for subquery_root in _traverse_scope(
            scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
        ):
            yield subquery_root
        # Post-order traversal: the last scope yielded is the subquery's own scope.
        scope.subquery_scopes.append(subquery_root)
677
678
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope (CTEs, UDTFs, Subqueryables, and Subqueries that sit
    directly under a FROM or JOIN) are still yielded; only their descendants are skipped.
    `expression` itself is never pruned.

    Args:
        expression (exp.Expression): the root node to walk from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # One-element holder read by the prune callback handed to `expression.walk`:
    # setting it to True after yielding a node excludes that node's subtree.
    skip_subtree = [False]

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: skip_subtree[0]):
        skip_subtree[0] = False

        yield node, parent, key

        if node is not expression:
            skip_subtree[0] = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
                isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))
            )
class ScopeType(enum.Enum):
class ScopeType(Enum):
    """The kind of a scope, relative to its parent scope."""

    ROOT = auto()  # the top-level scope built for the expression given to traverse_scope
    SUBQUERY = auto()  # a subquery nested in an expression, e.g. WHERE a IN (SELECT ...)
    DERIVED_TABLE = auto()  # a subquery used as a source in FROM / JOIN
    CTE = auto()  # the body of a common table expression
    UNION = auto()  # one operand (left or right) of a UNION
    UDTF = auto()  # a user defined tabular function used as a source

The kind of a scope, relative to its parent scope.

ROOT = <ScopeType.ROOT: 1>
SUBQUERY = <ScopeType.SUBQUERY: 2>
DERIVED_TABLE = <ScopeType.DERIVED_TABLE: 3>
CTE = <ScopeType.CTE: 4>
UNION = <ScopeType.UNION: 5>
UDTF = <ScopeType.UDTF: 6>
Inherited Members
enum.Enum
name
value
class Scope:
 21class Scope:
 22    """
 23    Selection scope.
 24
 25    Attributes:
 26        expression (exp.Select|exp.Union): Root expression of this scope
 27        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
 28            a Table expression or another Scope instance. For example:
 29                SELECT * FROM x                     {"x": Table(this="x")}
 30                SELECT * FROM x AS y                {"y": Table(this="x")}
 31                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
 32        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
 33            For example:
 34                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
 35            The LATERAL VIEW EXPLODE gets x as a source.
 36        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
 37            defines a column list of it's alias of this scope, this is that list of columns.
 38            For example:
 39                SELECT * FROM (SELECT ...) AS y(col1, col2)
 40            The inner query would have `["col1", "col2"]` for its `outer_column_list`
 41        parent (Scope): Parent scope
 42        scope_type (ScopeType): Type of this scope, relative to it's parent
 43        subquery_scopes (list[Scope]): List of all child scopes for subqueries
 44        cte_scopes (list[Scope]): List of all child scopes for CTEs
 45        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
 46        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
 47        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
 48        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
 49            a list of the left and right child scopes.
 50    """
 51
 52    def __init__(
 53        self,
 54        expression,
 55        sources=None,
 56        outer_column_list=None,
 57        parent=None,
 58        scope_type=ScopeType.ROOT,
 59        lateral_sources=None,
 60    ):
 61        self.expression = expression
 62        self.sources = sources or {}
 63        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
 64        self.sources.update(self.lateral_sources)
 65        self.outer_column_list = outer_column_list or []
 66        self.parent = parent
 67        self.scope_type = scope_type
 68        self.subquery_scopes = []
 69        self.derived_table_scopes = []
 70        self.table_scopes = []
 71        self.cte_scopes = []
 72        self.union_scopes = []
 73        self.udtf_scopes = []
 74        self.clear_cache()
 75
 76    def clear_cache(self):
 77        self._collected = False
 78        self._raw_columns = None
 79        self._derived_tables = None
 80        self._udtfs = None
 81        self._tables = None
 82        self._ctes = None
 83        self._subqueries = None
 84        self._selected_sources = None
 85        self._columns = None
 86        self._external_columns = None
 87        self._join_hints = None
 88        self._pivots = None
 89
 90    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
 91        """Branch from the current scope to a new, inner scope"""
 92        return Scope(
 93            expression=expression.unnest(),
 94            sources={**self.cte_sources, **(chain_sources or {})},
 95            parent=self,
 96            scope_type=scope_type,
 97            **kwargs,
 98        )
 99
100    def _collect(self):
101        self._tables = []
102        self._ctes = []
103        self._subqueries = []
104        self._derived_tables = []
105        self._udtfs = []
106        self._raw_columns = []
107        self._join_hints = []
108
109        for node, parent, _ in self.walk(bfs=False):
110            if node is self.expression:
111                continue
112            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
113                self._raw_columns.append(node)
114            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
115                self._tables.append(node)
116            elif isinstance(node, exp.JoinHint):
117                self._join_hints.append(node)
118            elif isinstance(node, exp.UDTF):
119                self._udtfs.append(node)
120            elif isinstance(node, exp.CTE):
121                self._ctes.append(node)
122            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
123                self._derived_tables.append(node)
124            elif isinstance(node, exp.Subqueryable):
125                self._subqueries.append(node)
126
127        self._collected = True
128
129    def _ensure_collected(self):
130        if not self._collected:
131            self._collect()
132
133    def walk(self, bfs=True):
134        return walk_in_scope(self.expression, bfs=bfs)
135
136    def find(self, *expression_types, bfs=True):
137        """
138        Returns the first node in this scope which matches at least one of the specified types.
139
140        This does NOT traverse into subscopes.
141
142        Args:
143            expression_types (type): the expression type(s) to match.
144            bfs (bool): True to use breadth-first search, False to use depth-first.
145
146        Returns:
147            exp.Expression: the node which matches the criteria or None if no node matching
148            the criteria was found.
149        """
150        return next(self.find_all(*expression_types, bfs=bfs), None)
151
152    def find_all(self, *expression_types, bfs=True):
153        """
154        Returns a generator object which visits all nodes in this scope and only yields those that
155        match at least one of the specified expression types.
156
157        This does NOT traverse into subscopes.
158
159        Args:
160            expression_types (type): the expression type(s) to match.
161            bfs (bool): True to use breadth-first search, False to use depth-first.
162
163        Yields:
164            exp.Expression: nodes
165        """
166        for expression, *_ in self.walk(bfs=bfs):
167            if isinstance(expression, expression_types):
168                yield expression
169
170    def replace(self, old, new):
171        """
172        Replace `old` with `new`.
173
174        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
175
176        Args:
177            old (exp.Expression): old node
178            new (exp.Expression): new node
179        """
180        old.replace(new)
181        self.clear_cache()
182
183    @property
184    def tables(self):
185        """
186        List of tables in this scope.
187
188        Returns:
189            list[exp.Table]: tables
190        """
191        self._ensure_collected()
192        return self._tables
193
194    @property
195    def ctes(self):
196        """
197        List of CTEs in this scope.
198
199        Returns:
200            list[exp.CTE]: ctes
201        """
202        self._ensure_collected()
203        return self._ctes
204
205    @property
206    def derived_tables(self):
207        """
208        List of derived tables in this scope.
209
210        For example:
211            SELECT * FROM (SELECT ...) <- that's a derived table
212
213        Returns:
214            list[exp.Subquery]: derived tables
215        """
216        self._ensure_collected()
217        return self._derived_tables
218
219    @property
220    def udtfs(self):
221        """
222        List of "User Defined Tabular Functions" in this scope.
223
224        Returns:
225            list[exp.UDTF]: UDTFs
226        """
227        self._ensure_collected()
228        return self._udtfs
229
230    @property
231    def subqueries(self):
232        """
233        List of subqueries in this scope.
234
235        For example:
236            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
237
238        Returns:
239            list[exp.Subqueryable]: subqueries
240        """
241        self._ensure_collected()
242        return self._subqueries
243
244    @property
245    def columns(self):
246        """
247        List of columns in this scope.
248
249        Returns:
250            list[exp.Column]: Column instances in this scope, plus any
251                Columns that reference this scope from correlated subqueries.
252        """
253        if self._columns is None:
254            self._ensure_collected()
255            columns = self._raw_columns
256
257            external_columns = [
258                column
259                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
260                for column in scope.external_columns
261            ]
262
263            named_selects = set(self.expression.named_selects)
264
265            self._columns = []
266            for column in columns + external_columns:
267                ancestor = column.find_ancestor(
268                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
269                )
270                if (
271                    not ancestor
272                    or column.table
273                    or isinstance(ancestor, exp.Select)
274                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
275                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
276                ):
277                    self._columns.append(column)
278
279        return self._columns
280
281    @property
282    def selected_sources(self):
283        """
284        Mapping of nodes and sources that are actually selected from in this scope.
285
286        That is, all tables in a schema are selectable at any point. But a
287        table only becomes a selected source if it's included in a FROM or JOIN clause.
288
289        Returns:
290            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
291        """
292        if self._selected_sources is None:
293            referenced_names = []
294
295            for table in self.tables:
296                referenced_names.append((table.alias_or_name, table))
297            for expression in itertools.chain(self.derived_tables, self.udtfs):
298                referenced_names.append((expression.alias, expression.unnest()))
299            result = {}
300
301            for name, node in referenced_names:
302                if name in result:
303                    raise OptimizeError(f"Alias already used: {name}")
304                if name in self.sources:
305                    result[name] = (node, self.sources[name])
306
307            self._selected_sources = result
308        return self._selected_sources
309
310    @property
311    def cte_sources(self):
312        """
313        Sources that are CTEs.
314
315        Returns:
316            dict[str, Scope]: Mapping of source alias to Scope
317        """
318        return {
319            alias: scope
320            for alias, scope in self.sources.items()
321            if isinstance(scope, Scope) and scope.is_cte
322        }
323
324    @property
325    def selects(self):
326        """
327        Select expressions of this scope.
328
329        For example, for the following expression:
330            SELECT 1 as a, 2 as b FROM x
331
332        The outputs are the "1 as a" and "2 as b" expressions.
333
334        Returns:
335            list[exp.Expression]: expressions
336        """
337        if isinstance(self.expression, exp.Union):
338            return self.expression.unnest().selects
339        return self.expression.selects
340
341    @property
342    def external_columns(self):
343        """
344        Columns that appear to reference sources in outer scopes.
345
346        Returns:
347            list[exp.Column]: Column instances that don't reference
348                sources in the current scope.
349        """
350        if self._external_columns is None:
351            self._external_columns = [
352                c for c in self.columns if c.table not in self.selected_sources
353            ]
354        return self._external_columns
355
356    @property
357    def unqualified_columns(self):
358        """
359        Unqualified columns in the current scope.
360
361        Returns:
362             list[exp.Column]: Unqualified columns
363        """
364        return [c for c in self.columns if not c.table]
365
366    @property
367    def join_hints(self):
368        """
369        Hints that exist in the scope that reference tables
370
371        Returns:
372            list[exp.JoinHint]: Join hints that are referenced within the scope
373        """
374        if self._join_hints is None:
375            return []
376        return self._join_hints
377
378    @property
379    def pivots(self):
380        if not self._pivots:
381            self._pivots = [
382                pivot
383                for node in self.tables + self.derived_tables
384                for pivot in node.args.get("pivots") or []
385            ]
386
387        return self._pivots
388
389    def source_columns(self, source_name):
390        """
391        Get all columns in the current scope for a particular source.
392
393        Args:
394            source_name (str): Name of the source
395        Returns:
396            list[exp.Column]: Column instances that reference `source_name`
397        """
398        return [column for column in self.columns if column.table == source_name]
399
400    @property
401    def is_subquery(self):
402        """Determine if this scope is a subquery"""
403        return self.scope_type == ScopeType.SUBQUERY
404
405    @property
406    def is_derived_table(self):
407        """Determine if this scope is a derived table"""
408        return self.scope_type == ScopeType.DERIVED_TABLE
409
410    @property
411    def is_union(self):
412        """Determine if this scope is a union"""
413        return self.scope_type == ScopeType.UNION
414
415    @property
416    def is_cte(self):
417        """Determine if this scope is a common table expression"""
418        return self.scope_type == ScopeType.CTE
419
420    @property
421    def is_root(self):
422        """Determine if this is the root scope"""
423        return self.scope_type == ScopeType.ROOT
424
425    @property
426    def is_udtf(self):
427        """Determine if this scope is a UDTF (User Defined Table Function)"""
428        return self.scope_type == ScopeType.UDTF
429
430    @property
431    def is_correlated_subquery(self):
432        """Determine if this scope is a correlated subquery"""
433        return bool(self.is_subquery and self.external_columns)
434
435    def rename_source(self, old_name, new_name):
436        """Rename a source in this scope"""
437        columns = self.sources.pop(old_name or "", [])
438        self.sources[new_name] = columns
439
440    def add_source(self, name, source):
441        """Add a source to this scope"""
442        self.sources[name] = source
443        self.clear_cache()
444
445    def remove_source(self, name):
446        """Remove a source from this scope"""
447        self.sources.pop(name, None)
448        self.clear_cache()
449
450    def __repr__(self):
451        return f"Scope<{self.expression.sql()}>"
452
453    def traverse(self):
454        """
455        Traverse the scope tree from this node.
456
457        Yields:
458            Scope: scope instances in depth-first-search post-order
459        """
460        for child_scope in itertools.chain(
461            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
462        ):
463            yield from child_scope.traverse()
464        yield self
465
466    def ref_count(self):
467        """
468        Count the number of times each scope in this tree is referenced.
469
470        Returns:
471            dict[int, int]: Mapping of Scope instance ID to reference count
472        """
473        scope_ref_count = defaultdict(lambda: 0)
474
475        for scope in self.traverse():
476            for _, source in scope.selected_sources.values():
477                scope_ref_count[id(source)] += 1
478
479        return scope_ref_count

Selection scope.

Attributes:
  • expression (exp.Select|exp.Union): Root expression of this scope
  • sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` → {"x": Table(this="x")}; `SELECT * FROM x AS y` → {"y": Table(this="x")}; `SELECT * FROM (SELECT ...) AS y` → {"y": Scope(...)}
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list for its alias of this scope, this is that list of columns. For example, in `SELECT * FROM (SELECT ...) AS y(col1, col2)` the inner query would have ["col1", "col2"] for its outer_column_list.
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to its parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
52    def __init__(
53        self,
54        expression,
55        sources=None,
56        outer_column_list=None,
57        parent=None,
58        scope_type=ScopeType.ROOT,
59        lateral_sources=None,
60    ):
61        self.expression = expression
62        self.sources = sources or {}
63        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
64        self.sources.update(self.lateral_sources)
65        self.outer_column_list = outer_column_list or []
66        self.parent = parent
67        self.scope_type = scope_type
68        self.subquery_scopes = []
69        self.derived_table_scopes = []
70        self.table_scopes = []
71        self.cte_scopes = []
72        self.union_scopes = []
73        self.udtf_scopes = []
74        self.clear_cache()
def clear_cache(self):
76    def clear_cache(self):
77        self._collected = False
78        self._raw_columns = None
79        self._derived_tables = None
80        self._udtfs = None
81        self._tables = None
82        self._ctes = None
83        self._subqueries = None
84        self._selected_sources = None
85        self._columns = None
86        self._external_columns = None
87        self._join_hints = None
88        self._pivots = None
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
90    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
91        """Branch from the current scope to a new, inner scope"""
92        return Scope(
93            expression=expression.unnest(),
94            sources={**self.cte_sources, **(chain_sources or {})},
95            parent=self,
96            scope_type=scope_type,
97            **kwargs,
98        )

Branch from the current scope to a new, inner scope

def walk(self, bfs=True):
133    def walk(self, bfs=True):
134        return walk_in_scope(self.expression, bfs=bfs)
def find(self, *expression_types, bfs=True):
136    def find(self, *expression_types, bfs=True):
137        """
138        Returns the first node in this scope which matches at least one of the specified types.
139
140        This does NOT traverse into subscopes.
141
142        Args:
143            expression_types (type): the expression type(s) to match.
144            bfs (bool): True to use breadth-first search, False to use depth-first.
145
146        Returns:
147            exp.Expression: the node which matches the criteria or None if no node matching
148            the criteria was found.
149        """
150        return next(self.find_all(*expression_types, bfs=bfs), None)

Returns the first node in this scope which matches at least one of the specified types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:

exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.

def find_all(self, *expression_types, bfs=True):
152    def find_all(self, *expression_types, bfs=True):
153        """
154        Returns a generator object which visits all nodes in this scope and only yields those that
155        match at least one of the specified expression types.
156
157        This does NOT traverse into subscopes.
158
159        Args:
160            expression_types (type): the expression type(s) to match.
161            bfs (bool): True to use breadth-first search, False to use depth-first.
162
163        Yields:
164            exp.Expression: nodes
165        """
166        for expression, *_ in self.walk(bfs=bfs):
167            if isinstance(expression, expression_types):
168                yield expression

Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.

This does NOT traverse into subscopes.

Arguments:
  • expression_types (type): the expression type(s) to match.
  • bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:

exp.Expression: nodes

def replace(self, old, new):
170    def replace(self, old, new):
171        """
172        Replace `old` with `new`.
173
174        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
175
176        Args:
177            old (exp.Expression): old node
178            new (exp.Expression): new node
179        """
180        old.replace(new)
181        self.clear_cache()

Replace old with new.

This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.

Arguments:
  • old (exp.Expression): old node
  • new (exp.Expression): new node
tables

List of tables in this scope.

Returns:

list[exp.Table]: tables

ctes

List of CTEs in this scope.

Returns:

list[exp.CTE]: ctes

derived_tables

List of derived tables in this scope.

For example:

SELECT * FROM (SELECT ...) <- that's a derived table

Returns:

list[exp.Subquery]: derived tables

udtfs

List of "User Defined Tabular Functions" in this scope.

Returns:

list[exp.UDTF]: UDTFs

subqueries

List of subqueries in this scope.

For example:

SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

Returns:

list[exp.Subqueryable]: subqueries

columns

List of columns in this scope.

Returns:

list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.

selected_sources

Mapping of nodes and sources that are actually selected from in this scope.

That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.

Returns:

dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

cte_sources

Sources that are CTEs.

Returns:

dict[str, Scope]: Mapping of source alias to Scope

selects

Select expressions of this scope.

For example, for the following expression: SELECT 1 as a, 2 as b FROM x

The outputs are the "1 as a" and "2 as b" expressions.

Returns:

list[exp.Expression]: expressions

external_columns

Columns that appear to reference sources in outer scopes.

Returns:

list[exp.Column]: Column instances that don't reference sources in the current scope.

unqualified_columns

Unqualified columns in the current scope.

Returns:

list[exp.Column]: Unqualified columns

join_hints

Hints that exist in the scope that reference tables

Returns:

list[exp.JoinHint]: Join hints that are referenced within the scope

def source_columns(self, source_name):
389    def source_columns(self, source_name):
390        """
391        Get all columns in the current scope for a particular source.
392
393        Args:
394            source_name (str): Name of the source
395        Returns:
396            list[exp.Column]: Column instances that reference `source_name`
397        """
398        return [column for column in self.columns if column.table == source_name]

Get all columns in the current scope for a particular source.

Arguments:
  • source_name (str): Name of the source
Returns:

list[exp.Column]: Column instances that reference source_name

is_subquery

Determine if this scope is a subquery

is_derived_table

Determine if this scope is a derived table

is_union

Determine if this scope is a union

is_cte

Determine if this scope is a common table expression

is_root

Determine if this is the root scope

is_udtf

Determine if this scope is a UDTF (User Defined Table Function)

is_correlated_subquery

Determine if this scope is a correlated subquery

def rename_source(self, old_name, new_name):
435    def rename_source(self, old_name, new_name):
436        """Rename a source in this scope"""
437        columns = self.sources.pop(old_name or "", [])
438        self.sources[new_name] = columns

Rename a source in this scope

def add_source(self, name, source):
440    def add_source(self, name, source):
441        """Add a source to this scope"""
442        self.sources[name] = source
443        self.clear_cache()

Add a source to this scope

def remove_source(self, name):
445    def remove_source(self, name):
446        """Remove a source from this scope"""
447        self.sources.pop(name, None)
448        self.clear_cache()

Remove a source from this scope

def traverse(self):
453    def traverse(self):
454        """
455        Traverse the scope tree from this node.
456
457        Yields:
458            Scope: scope instances in depth-first-search post-order
459        """
460        for child_scope in itertools.chain(
461            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
462        ):
463            yield from child_scope.traverse()
464        yield self

Traverse the scope tree from this node.

Yields:

Scope: scope instances in depth-first-search post-order

def ref_count(self):
466    def ref_count(self):
467        """
468        Count the number of times each scope in this tree is referenced.
469
470        Returns:
471            dict[int, int]: Mapping of Scope instance ID to reference count
472        """
473        scope_ref_count = defaultdict(lambda: 0)
474
475        for scope in self.traverse():
476            for _, source in scope.selected_sources.values():
477                scope_ref_count[id(source)] += 1
478
479        return scope_ref_count

Count the number of times each scope in this tree is referenced.

Returns:

dict[int, int]: Mapping of Scope instance ID to reference count

def traverse_scope( expression: sqlglot.expressions.Expression) -> List[sqlglot.optimizer.scope.Scope]:
482def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
483    """
484    Traverse an expression by it's "scopes".
485
486    "Scope" represents the current context of a Select statement.
487
488    This is helpful for optimizing queries, where we need more information than
489    the expression tree itself. For example, we might care about the source
490    names within a subquery. Returns a list because a generator could result in
491    incomplete properties which is confusing.
492
493    Examples:
494        >>> import sqlglot
495        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
496        >>> scopes = traverse_scope(expression)
497        >>> scopes[0].expression.sql(), list(scopes[0].sources)
498        ('SELECT a FROM x', ['x'])
499        >>> scopes[1].expression.sql(), list(scopes[1].sources)
500        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
501
502    Args:
503        expression (exp.Expression): expression to traverse
504    Returns:
505        list[Scope]: scope instances
506    """
507    if not isinstance(expression, exp.Unionable):
508        return []
509    return list(_traverse_scope(Scope(expression)))

Traverse an expression by its "scopes".

"Scope" represents the current context of a Select statement.

This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.

Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
  • expression (exp.Expression): expression to traverse
Returns:

list[Scope]: scope instances

def build_scope( expression: sqlglot.expressions.Expression) -> Optional[sqlglot.optimizer.scope.Scope]:
512def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
513    """
514    Build a scope tree.
515
516    Args:
517        expression (exp.Expression): expression to build the scope tree for
518    Returns:
519        Scope: root scope
520    """
521    scopes = traverse_scope(expression)
522    if scopes:
523        return scopes[-1]
524    return None

Build a scope tree.

Arguments:
  • expression (exp.Expression): expression to build the scope tree for
Returns:

Scope: root scope

def walk_in_scope(expression, bfs=True):
680def walk_in_scope(expression, bfs=True):
681    """
682    Returns a generator object which visits all nodes in the syntrax tree, stopping at
683    nodes that start child scopes.
684
685    Args:
686        expression (exp.Expression):
687        bfs (bool): if set to True the BFS traversal order will be applied,
688            otherwise the DFS traversal will be used instead.
689
690    Yields:
691        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
692    """
693    # We'll use this variable to pass state into the dfs generator.
694    # Whenever we set it to True, we exclude a subtree from traversal.
695    prune = False
696
697    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
698        prune = False
699
700        yield node, parent, key
701
702        if node is expression:
703            continue
704        if (
705            isinstance(node, exp.CTE)
706            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
707            or isinstance(node, exp.UDTF)
708            or isinstance(node, exp.Subqueryable)
709        ):
710            prune = True

Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.

Arguments:
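  • expression (exp.Expression): the root expression to walk. The root itself is always yielded and is never pruned, even if it would otherwise start a child scope.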
  • bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:

tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key