sqlglot.optimizer.scope
1from __future__ import annotations 2 3import itertools 4import logging 5import typing as t 6from collections import defaultdict 7from enum import Enum, auto 8 9from sqlglot import exp 10from sqlglot.errors import OptimizeError 11from sqlglot.helper import ensure_collection, find_new_name, seq_get 12 13logger = logging.getLogger("sqlglot") 14 15TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) 16 17 18class ScopeType(Enum): 19 ROOT = auto() 20 SUBQUERY = auto() 21 DERIVED_TABLE = auto() 22 CTE = auto() 23 UNION = auto() 24 UDTF = auto() 25 26 27class Scope: 28 """ 29 Selection scope. 30 31 Attributes: 32 expression (exp.Select|exp.SetOperation): Root expression of this scope 33 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 34 a Table expression or another Scope instance. For example: 35 SELECT * FROM x {"x": Table(this="x")} 36 SELECT * FROM x AS y {"y": Table(this="x")} 37 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 38 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 39 For example: 40 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 41 The LATERAL VIEW EXPLODE gets x as a source. 42 cte_sources (dict[str, Scope]): Sources from CTES 43 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 44 defines a column list for the alias of this scope, this is that list of columns. 45 For example: 46 SELECT * FROM (SELECT ...) 
AS y(col1, col2) 47 The inner query would have `["col1", "col2"]` for its `outer_columns` 48 parent (Scope): Parent scope 49 scope_type (ScopeType): Type of this scope, relative to it's parent 50 subquery_scopes (list[Scope]): List of all child scopes for subqueries 51 cte_scopes (list[Scope]): List of all child scopes for CTEs 52 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 53 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 54 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 55 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 56 a list of the left and right child scopes. 57 """ 58 59 def __init__( 60 self, 61 expression, 62 sources=None, 63 outer_columns=None, 64 parent=None, 65 scope_type=ScopeType.ROOT, 66 lateral_sources=None, 67 cte_sources=None, 68 can_be_correlated=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.can_be_correlated = can_be_correlated 86 self.clear_cache() 87 88 def clear_cache(self): 89 self._collected = False 90 self._raw_columns = None 91 self._stars = None 92 self._derived_tables = None 93 self._udtfs = None 94 self._tables = None 95 self._ctes = None 96 self._subqueries = None 97 self._selected_sources = None 98 self._columns = None 99 self._external_columns = None 100 self._join_hints = None 101 self._pivots = None 102 self._references = None 103 104 def branch( 105 self, expression, scope_type, sources=None, 
cte_sources=None, lateral_sources=None, **kwargs 106 ): 107 """Branch from the current scope to a new, inner scope""" 108 return Scope( 109 expression=expression.unnest(), 110 sources=sources.copy() if sources else None, 111 parent=self, 112 scope_type=scope_type, 113 cte_sources={**self.cte_sources, **(cte_sources or {})}, 114 lateral_sources=lateral_sources.copy() if lateral_sources else None, 115 can_be_correlated=self.can_be_correlated 116 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 117 **kwargs, 118 ) 119 120 def _collect(self): 121 self._tables = [] 122 self._ctes = [] 123 self._subqueries = [] 124 self._derived_tables = [] 125 self._udtfs = [] 126 self._raw_columns = [] 127 self._stars = [] 128 self._join_hints = [] 129 130 for node in self.walk(bfs=False): 131 if node is self.expression: 132 continue 133 134 if isinstance(node, exp.Dot) and node.is_star: 135 self._stars.append(node) 136 elif isinstance(node, exp.Column): 137 if isinstance(node.this, exp.Star): 138 self._stars.append(node) 139 else: 140 self._raw_columns.append(node) 141 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 142 self._tables.append(node) 143 elif isinstance(node, exp.JoinHint): 144 self._join_hints.append(node) 145 elif isinstance(node, exp.UDTF): 146 self._udtfs.append(node) 147 elif isinstance(node, exp.CTE): 148 self._ctes.append(node) 149 elif _is_derived_table(node) and isinstance( 150 node.parent, (exp.From, exp.Join, exp.Subquery) 151 ): 152 self._derived_tables.append(node) 153 elif isinstance(node, exp.UNWRAPPED_QUERIES): 154 self._subqueries.append(node) 155 156 self._collected = True 157 158 def _ensure_collected(self): 159 if not self._collected: 160 self._collect() 161 162 def walk(self, bfs=True, prune=None): 163 return walk_in_scope(self.expression, bfs=bfs, prune=None) 164 165 def find(self, *expression_types, bfs=True): 166 return find_in_scope(self.expression, expression_types, bfs=bfs) 167 168 def find_all(self, 
*expression_types, bfs=True): 169 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 170 171 def replace(self, old, new): 172 """ 173 Replace `old` with `new`. 174 175 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 176 177 Args: 178 old (exp.Expression): old node 179 new (exp.Expression): new node 180 """ 181 old.replace(new) 182 self.clear_cache() 183 184 @property 185 def tables(self): 186 """ 187 List of tables in this scope. 188 189 Returns: 190 list[exp.Table]: tables 191 """ 192 self._ensure_collected() 193 return self._tables 194 195 @property 196 def ctes(self): 197 """ 198 List of CTEs in this scope. 199 200 Returns: 201 list[exp.CTE]: ctes 202 """ 203 self._ensure_collected() 204 return self._ctes 205 206 @property 207 def derived_tables(self): 208 """ 209 List of derived tables in this scope. 210 211 For example: 212 SELECT * FROM (SELECT ...) <- that's a derived table 213 214 Returns: 215 list[exp.Subquery]: derived tables 216 """ 217 self._ensure_collected() 218 return self._derived_tables 219 220 @property 221 def udtfs(self): 222 """ 223 List of "User Defined Tabular Functions" in this scope. 224 225 Returns: 226 list[exp.UDTF]: UDTFs 227 """ 228 self._ensure_collected() 229 return self._udtfs 230 231 @property 232 def subqueries(self): 233 """ 234 List of subqueries in this scope. 235 236 For example: 237 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 238 239 Returns: 240 list[exp.Select | exp.SetOperation]: subqueries 241 """ 242 self._ensure_collected() 243 return self._subqueries 244 245 @property 246 def stars(self) -> t.List[exp.Column | exp.Dot]: 247 """ 248 List of star expressions (columns or dots) in this scope. 249 """ 250 self._ensure_collected() 251 return self._stars 252 253 @property 254 def columns(self): 255 """ 256 List of columns in this scope. 
257 258 Returns: 259 list[exp.Column]: Column instances in this scope, plus any 260 Columns that reference this scope from correlated subqueries. 261 """ 262 if self._columns is None: 263 self._ensure_collected() 264 columns = self._raw_columns 265 266 external_columns = [ 267 column 268 for scope in itertools.chain( 269 self.subquery_scopes, 270 self.udtf_scopes, 271 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 272 ) 273 for column in scope.external_columns 274 ] 275 276 named_selects = set(self.expression.named_selects) 277 278 self._columns = [] 279 for column in columns + external_columns: 280 ancestor = column.find_ancestor( 281 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 282 ) 283 if ( 284 not ancestor 285 or column.table 286 or isinstance(ancestor, exp.Select) 287 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 288 or ( 289 isinstance(ancestor, exp.Order) 290 and ( 291 isinstance(ancestor.parent, exp.Window) 292 or column.name not in named_selects 293 ) 294 ) 295 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 296 ): 297 self._columns.append(column) 298 299 return self._columns 300 301 @property 302 def selected_sources(self): 303 """ 304 Mapping of nodes and sources that are actually selected from in this scope. 305 306 That is, all tables in a schema are selectable at any point. But a 307 table only becomes a selected source if it's included in a FROM or JOIN clause. 
308 309 Returns: 310 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 311 """ 312 if self._selected_sources is None: 313 result = {} 314 315 for name, node in self.references: 316 if name in result: 317 raise OptimizeError(f"Alias already used: {name}") 318 if name in self.sources: 319 result[name] = (node, self.sources[name]) 320 321 self._selected_sources = result 322 return self._selected_sources 323 324 @property 325 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 326 if self._references is None: 327 self._references = [] 328 329 for table in self.tables: 330 self._references.append((table.alias_or_name, table)) 331 for expression in itertools.chain(self.derived_tables, self.udtfs): 332 self._references.append( 333 ( 334 expression.alias, 335 expression if expression.args.get("pivots") else expression.unnest(), 336 ) 337 ) 338 339 return self._references 340 341 @property 342 def external_columns(self): 343 """ 344 Columns that appear to reference sources in outer scopes. 345 346 Returns: 347 list[exp.Column]: Column instances that don't reference 348 sources in the current scope. 349 """ 350 if self._external_columns is None: 351 if isinstance(self.expression, exp.SetOperation): 352 left, right = self.union_scopes 353 self._external_columns = left.external_columns + right.external_columns 354 else: 355 self._external_columns = [ 356 c for c in self.columns if c.table not in self.selected_sources 357 ] 358 359 return self._external_columns 360 361 @property 362 def unqualified_columns(self): 363 """ 364 Unqualified columns in the current scope. 
365 366 Returns: 367 list[exp.Column]: Unqualified columns 368 """ 369 return [c for c in self.columns if not c.table] 370 371 @property 372 def join_hints(self): 373 """ 374 Hints that exist in the scope that reference tables 375 376 Returns: 377 list[exp.JoinHint]: Join hints that are referenced within the scope 378 """ 379 if self._join_hints is None: 380 return [] 381 return self._join_hints 382 383 @property 384 def pivots(self): 385 if not self._pivots: 386 self._pivots = [ 387 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 388 ] 389 390 return self._pivots 391 392 def source_columns(self, source_name): 393 """ 394 Get all columns in the current scope for a particular source. 395 396 Args: 397 source_name (str): Name of the source 398 Returns: 399 list[exp.Column]: Column instances that reference `source_name` 400 """ 401 return [column for column in self.columns if column.table == source_name] 402 403 @property 404 def is_subquery(self): 405 """Determine if this scope is a subquery""" 406 return self.scope_type == ScopeType.SUBQUERY 407 408 @property 409 def is_derived_table(self): 410 """Determine if this scope is a derived table""" 411 return self.scope_type == ScopeType.DERIVED_TABLE 412 413 @property 414 def is_union(self): 415 """Determine if this scope is a union""" 416 return self.scope_type == ScopeType.UNION 417 418 @property 419 def is_cte(self): 420 """Determine if this scope is a common table expression""" 421 return self.scope_type == ScopeType.CTE 422 423 @property 424 def is_root(self): 425 """Determine if this is the root scope""" 426 return self.scope_type == ScopeType.ROOT 427 428 @property 429 def is_udtf(self): 430 """Determine if this scope is a UDTF (User Defined Table Function)""" 431 return self.scope_type == ScopeType.UDTF 432 433 @property 434 def is_correlated_subquery(self): 435 """Determine if this scope is a correlated subquery""" 436 return bool(self.can_be_correlated and self.external_columns) 
437 438 def rename_source(self, old_name, new_name): 439 """Rename a source in this scope""" 440 columns = self.sources.pop(old_name or "", []) 441 self.sources[new_name] = columns 442 443 def add_source(self, name, source): 444 """Add a source to this scope""" 445 self.sources[name] = source 446 self.clear_cache() 447 448 def remove_source(self, name): 449 """Remove a source from this scope""" 450 self.sources.pop(name, None) 451 self.clear_cache() 452 453 def __repr__(self): 454 return f"Scope<{self.expression.sql()}>" 455 456 def traverse(self): 457 """ 458 Traverse the scope tree from this node. 459 460 Yields: 461 Scope: scope instances in depth-first-search post-order 462 """ 463 stack = [self] 464 result = [] 465 while stack: 466 scope = stack.pop() 467 result.append(scope) 468 stack.extend( 469 itertools.chain( 470 scope.cte_scopes, 471 scope.union_scopes, 472 scope.table_scopes, 473 scope.subquery_scopes, 474 ) 475 ) 476 477 yield from reversed(result) 478 479 def ref_count(self): 480 """ 481 Count the number of times each scope in this tree is referenced. 482 483 Returns: 484 dict[int, int]: Mapping of Scope instance ID to reference count 485 """ 486 scope_ref_count = defaultdict(lambda: 0) 487 488 for scope in self.traverse(): 489 for _, source in scope.selected_sources.values(): 490 scope_ref_count[id(source)] += 1 491 492 return scope_ref_count 493 494 495def traverse_scope(expression: exp.Expression) -> t.List[Scope]: 496 """ 497 Traverse an expression by its "scopes". 498 499 "Scope" represents the current context of a Select statement. 500 501 This is helpful for optimizing queries, where we need more information than 502 the expression tree itself. For example, we might care about the source 503 names within a subquery. Returns a list because a generator could result in 504 incomplete properties which is confusing. 
505 506 Examples: 507 >>> import sqlglot 508 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 509 >>> scopes = traverse_scope(expression) 510 >>> scopes[0].expression.sql(), list(scopes[0].sources) 511 ('SELECT a FROM x', ['x']) 512 >>> scopes[1].expression.sql(), list(scopes[1].sources) 513 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 514 515 Args: 516 expression: Expression to traverse 517 518 Returns: 519 A list of the created scope instances 520 """ 521 if isinstance(expression, TRAVERSABLES): 522 return list(_traverse_scope(Scope(expression))) 523 return [] 524 525 526def build_scope(expression: exp.Expression) -> t.Optional[Scope]: 527 """ 528 Build a scope tree. 529 530 Args: 531 expression: Expression to build the scope tree for. 532 533 Returns: 534 The root scope 535 """ 536 return seq_get(traverse_scope(expression), -1) 537 538 539def _traverse_scope(scope): 540 expression = scope.expression 541 542 if isinstance(expression, exp.Select): 543 yield from _traverse_select(scope) 544 elif isinstance(expression, exp.SetOperation): 545 yield from _traverse_ctes(scope) 546 yield from _traverse_union(scope) 547 return 548 elif isinstance(expression, exp.Subquery): 549 if scope.is_root: 550 yield from _traverse_select(scope) 551 else: 552 yield from _traverse_subqueries(scope) 553 elif isinstance(expression, exp.Table): 554 yield from _traverse_tables(scope) 555 elif isinstance(expression, exp.UDTF): 556 yield from _traverse_udtfs(scope) 557 elif isinstance(expression, exp.DDL): 558 if isinstance(expression.expression, exp.Query): 559 yield from _traverse_ctes(scope) 560 yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) 561 return 562 elif isinstance(expression, exp.DML): 563 yield from _traverse_ctes(scope) 564 for query in find_all_in_scope(expression, exp.Query): 565 # This check ensures we don't yield the CTE queries twice 566 if not isinstance(query.parent, exp.CTE): 567 yield from 
_traverse_scope(Scope(query, cte_sources=scope.cte_sources)) 568 return 569 else: 570 logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) 571 return 572 573 yield scope 574 575 576def _traverse_select(scope): 577 yield from _traverse_ctes(scope) 578 yield from _traverse_tables(scope) 579 yield from _traverse_subqueries(scope) 580 581 582def _traverse_union(scope): 583 prev_scope = None 584 union_scope_stack = [scope] 585 expression_stack = [scope.expression.right, scope.expression.left] 586 587 while expression_stack: 588 expression = expression_stack.pop() 589 union_scope = union_scope_stack[-1] 590 591 new_scope = union_scope.branch( 592 expression, 593 outer_columns=union_scope.outer_columns, 594 scope_type=ScopeType.UNION, 595 ) 596 597 if isinstance(expression, exp.SetOperation): 598 yield from _traverse_ctes(new_scope) 599 600 union_scope_stack.append(new_scope) 601 expression_stack.extend([expression.right, expression.left]) 602 continue 603 604 for scope in _traverse_scope(new_scope): 605 yield scope 606 607 if prev_scope: 608 union_scope_stack.pop() 609 union_scope.union_scopes = [prev_scope, scope] 610 prev_scope = union_scope 611 612 yield union_scope 613 else: 614 prev_scope = scope 615 616 617def _traverse_ctes(scope): 618 sources = {} 619 620 for cte in scope.ctes: 621 cte_name = cte.alias 622 623 # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. 624 # thus the recursive scope is the first section of the union. 
625 with_ = scope.expression.args.get("with") 626 if with_ and with_.recursive: 627 union = cte.this 628 629 if isinstance(union, exp.SetOperation): 630 sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) 631 632 child_scope = None 633 634 for child_scope in _traverse_scope( 635 scope.branch( 636 cte.this, 637 cte_sources=sources, 638 outer_columns=cte.alias_column_names, 639 scope_type=ScopeType.CTE, 640 ) 641 ): 642 yield child_scope 643 644 # append the final child_scope yielded 645 if child_scope: 646 sources[cte_name] = child_scope 647 scope.cte_scopes.append(child_scope) 648 649 scope.sources.update(sources) 650 scope.cte_sources.update(sources) 651 652 653def _is_derived_table(expression: exp.Subquery) -> bool: 654 """ 655 We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", 656 as it doesn't introduce a new scope. If an alias is present, it shadows all names 657 under the Subquery, so that's one exception to this rule. 658 """ 659 return isinstance(expression, exp.Subquery) and bool( 660 expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) 661 ) 662 663 664def _traverse_tables(scope): 665 sources = {} 666 667 # Traverse FROMs, JOINs, and LATERALs in the order they are defined 668 expressions = [] 669 from_ = scope.expression.args.get("from") 670 if from_: 671 expressions.append(from_.this) 672 673 for join in scope.expression.args.get("joins") or []: 674 expressions.append(join.this) 675 676 if isinstance(scope.expression, exp.Table): 677 expressions.append(scope.expression) 678 679 expressions.extend(scope.expression.args.get("laterals") or []) 680 681 for expression in expressions: 682 if isinstance(expression, exp.Table): 683 table_name = expression.name 684 source_name = expression.alias_or_name 685 686 if table_name in scope.sources and not expression.db: 687 # This is a reference to a parent source (e.g. 
a CTE), not an actual table, unless 688 # it is pivoted, because then we get back a new table and hence a new source. 689 pivots = expression.args.get("pivots") 690 if pivots: 691 sources[pivots[0].alias] = expression 692 else: 693 sources[source_name] = scope.sources[table_name] 694 elif source_name in sources: 695 sources[find_new_name(sources, table_name)] = expression 696 else: 697 sources[source_name] = expression 698 699 # Make sure to not include the joins twice 700 if expression is not scope.expression: 701 expressions.extend(join.this for join in expression.args.get("joins") or []) 702 703 continue 704 705 if not isinstance(expression, exp.DerivedTable): 706 continue 707 708 if isinstance(expression, exp.UDTF): 709 lateral_sources = sources 710 scope_type = ScopeType.UDTF 711 scopes = scope.udtf_scopes 712 elif _is_derived_table(expression): 713 lateral_sources = None 714 scope_type = ScopeType.DERIVED_TABLE 715 scopes = scope.derived_table_scopes 716 expressions.extend(join.this for join in expression.args.get("joins") or []) 717 else: 718 # Makes sure we check for possible sources in nested table constructs 719 expressions.append(expression.this) 720 expressions.extend(join.this for join in expression.args.get("joins") or []) 721 continue 722 723 for child_scope in _traverse_scope( 724 scope.branch( 725 expression, 726 lateral_sources=lateral_sources, 727 outer_columns=expression.alias_column_names, 728 scope_type=scope_type, 729 ) 730 ): 731 yield child_scope 732 733 # Tables without aliases will be set as "" 734 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. 735 # Until then, this means that only a single, unaliased derived table is allowed (rather, 736 # the latest one wins. 
737 sources[expression.alias] = child_scope 738 739 # append the final child_scope yielded 740 scopes.append(child_scope) 741 scope.table_scopes.append(child_scope) 742 743 scope.sources.update(sources) 744 745 746def _traverse_subqueries(scope): 747 for subquery in scope.subqueries: 748 top = None 749 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): 750 yield child_scope 751 top = child_scope 752 scope.subquery_scopes.append(top) 753 754 755def _traverse_udtfs(scope): 756 if isinstance(scope.expression, exp.Unnest): 757 expressions = scope.expression.expressions 758 elif isinstance(scope.expression, exp.Lateral): 759 expressions = [scope.expression.this] 760 else: 761 expressions = [] 762 763 sources = {} 764 for expression in expressions: 765 if _is_derived_table(expression): 766 top = None 767 for child_scope in _traverse_scope( 768 scope.branch( 769 expression, 770 scope_type=ScopeType.SUBQUERY, 771 outer_columns=expression.alias_column_names, 772 ) 773 ): 774 yield child_scope 775 top = child_scope 776 sources[expression.alias] = child_scope 777 778 scope.subquery_scopes.append(top) 779 780 scope.sources.update(sources) 781 782 783def walk_in_scope(expression, bfs=True, prune=None): 784 """ 785 Returns a generator object which visits all nodes in the syntrax tree, stopping at 786 nodes that start child scopes. 787 788 Args: 789 expression (exp.Expression): 790 bfs (bool): if set to True the BFS traversal order will be applied, 791 otherwise the DFS traversal will be used instead. 792 prune ((node, parent, arg_key) -> bool): callable that returns True if 793 the generator should stop traversing this branch of the tree. 794 795 Yields: 796 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 797 """ 798 # We'll use this variable to pass state into the dfs generator. 799 # Whenever we set it to True, we exclude a subtree from traversal. 
800 crossed_scope_boundary = False 801 802 for node in expression.walk( 803 bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) 804 ): 805 crossed_scope_boundary = False 806 807 yield node 808 809 if node is expression: 810 continue 811 if ( 812 isinstance(node, exp.CTE) 813 or ( 814 isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) 815 and (_is_derived_table(node) or isinstance(node, exp.UDTF)) 816 ) 817 or isinstance(node, exp.UNWRAPPED_QUERIES) 818 ): 819 crossed_scope_boundary = True 820 821 if isinstance(node, (exp.Subquery, exp.UDTF)): 822 # The following args are not actually in the inner scope, so we should visit them 823 for key in ("joins", "laterals", "pivots"): 824 for arg in node.args.get(key) or []: 825 yield from walk_in_scope(arg, bfs=bfs) 826 827 828def find_all_in_scope(expression, expression_types, bfs=True): 829 """ 830 Returns a generator object which visits all nodes in this scope and only yields those that 831 match at least one of the specified expression types. 832 833 This does NOT traverse into subscopes. 834 835 Args: 836 expression (exp.Expression): 837 expression_types (tuple[type]|type): the expression type(s) to match. 838 bfs (bool): True to use breadth-first search, False to use depth-first. 839 840 Yields: 841 exp.Expression: nodes 842 """ 843 for expression in walk_in_scope(expression, bfs=bfs): 844 if isinstance(expression, tuple(ensure_collection(expression_types))): 845 yield expression 846 847 848def find_in_scope(expression, expression_types, bfs=True): 849 """ 850 Returns the first node in this scope which matches at least one of the specified types. 851 852 This does NOT traverse into subscopes. 853 854 Args: 855 expression (exp.Expression): 856 expression_types (tuple[type]|type): the expression type(s) to match. 857 bfs (bool): True to use breadth-first search, False to use depth-first. 
858 859 Returns: 860 exp.Expression: the node which matches the criteria or None if no node matching 861 the criteria was found. 862 """ 863 return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Default `scope_type` of a `Scope` built without a parent.
    ROOT = auto()
    # Assigned by `_traverse_subqueries` and `_traverse_udtfs`.
    SUBQUERY = auto()
    # Assigned by `_traverse_tables` to derived tables.
    DERIVED_TABLE = auto()
    # Assigned by `_traverse_ctes`.
    CTE = auto()
    # Assigned by `_traverse_union` to the operands of a set operation.
    UNION = auto()
    # Assigned by `_traverse_tables` to `exp.UDTF` expressions.
    UDTF = auto()
An enumeration.
Inherited Members
- enum.Enum
- name
- value
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.SetOperation): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 can_be_correlated=None, 70 ): 71 self.expression = expression 72 self.sources = sources or {} 73 self.lateral_sources = lateral_sources or {} 74 self.cte_sources = cte_sources or {} 75 self.sources.update(self.lateral_sources) 76 self.sources.update(self.cte_sources) 77 self.outer_columns = outer_columns or [] 78 self.parent = parent 79 self.scope_type = scope_type 80 self.subquery_scopes = [] 81 self.derived_table_scopes = [] 82 self.table_scopes = [] 83 self.cte_scopes = [] 84 self.union_scopes = [] 85 self.udtf_scopes = [] 86 self.can_be_correlated = can_be_correlated 87 self.clear_cache() 88 89 def clear_cache(self): 90 self._collected = False 91 self._raw_columns = None 92 self._stars = None 93 self._derived_tables = None 94 self._udtfs = None 95 self._tables = None 96 self._ctes = None 97 self._subqueries = None 98 self._selected_sources = None 99 self._columns = None 100 self._external_columns = None 101 self._join_hints = None 102 self._pivots = None 103 self._references = None 104 105 def branch( 106 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 107 ): 108 """Branch from the current scope to a new, inner scope""" 109 return Scope( 110 expression=expression.unnest(), 111 sources=sources.copy() if sources else None, 112 parent=self, 113 scope_type=scope_type, 114 cte_sources={**self.cte_sources, **(cte_sources or {})}, 115 lateral_sources=lateral_sources.copy() if lateral_sources else None, 116 can_be_correlated=self.can_be_correlated 117 or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), 118 **kwargs, 119 ) 120 121 def _collect(self): 122 self._tables = [] 123 self._ctes = [] 124 self._subqueries = [] 125 self._derived_tables = [] 126 self._udtfs = [] 127 self._raw_columns = [] 128 self._stars = [] 129 self._join_hints = [] 
130 131 for node in self.walk(bfs=False): 132 if node is self.expression: 133 continue 134 135 if isinstance(node, exp.Dot) and node.is_star: 136 self._stars.append(node) 137 elif isinstance(node, exp.Column): 138 if isinstance(node.this, exp.Star): 139 self._stars.append(node) 140 else: 141 self._raw_columns.append(node) 142 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 143 self._tables.append(node) 144 elif isinstance(node, exp.JoinHint): 145 self._join_hints.append(node) 146 elif isinstance(node, exp.UDTF): 147 self._udtfs.append(node) 148 elif isinstance(node, exp.CTE): 149 self._ctes.append(node) 150 elif _is_derived_table(node) and isinstance( 151 node.parent, (exp.From, exp.Join, exp.Subquery) 152 ): 153 self._derived_tables.append(node) 154 elif isinstance(node, exp.UNWRAPPED_QUERIES): 155 self._subqueries.append(node) 156 157 self._collected = True 158 159 def _ensure_collected(self): 160 if not self._collected: 161 self._collect() 162 163 def walk(self, bfs=True, prune=None): 164 return walk_in_scope(self.expression, bfs=bfs, prune=None) 165 166 def find(self, *expression_types, bfs=True): 167 return find_in_scope(self.expression, expression_types, bfs=bfs) 168 169 def find_all(self, *expression_types, bfs=True): 170 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 171 172 def replace(self, old, new): 173 """ 174 Replace `old` with `new`. 175 176 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 177 178 Args: 179 old (exp.Expression): old node 180 new (exp.Expression): new node 181 """ 182 old.replace(new) 183 self.clear_cache() 184 185 @property 186 def tables(self): 187 """ 188 List of tables in this scope. 189 190 Returns: 191 list[exp.Table]: tables 192 """ 193 self._ensure_collected() 194 return self._tables 195 196 @property 197 def ctes(self): 198 """ 199 List of CTEs in this scope. 
200 201 Returns: 202 list[exp.CTE]: ctes 203 """ 204 self._ensure_collected() 205 return self._ctes 206 207 @property 208 def derived_tables(self): 209 """ 210 List of derived tables in this scope. 211 212 For example: 213 SELECT * FROM (SELECT ...) <- that's a derived table 214 215 Returns: 216 list[exp.Subquery]: derived tables 217 """ 218 self._ensure_collected() 219 return self._derived_tables 220 221 @property 222 def udtfs(self): 223 """ 224 List of "User Defined Tabular Functions" in this scope. 225 226 Returns: 227 list[exp.UDTF]: UDTFs 228 """ 229 self._ensure_collected() 230 return self._udtfs 231 232 @property 233 def subqueries(self): 234 """ 235 List of subqueries in this scope. 236 237 For example: 238 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 239 240 Returns: 241 list[exp.Select | exp.SetOperation]: subqueries 242 """ 243 self._ensure_collected() 244 return self._subqueries 245 246 @property 247 def stars(self) -> t.List[exp.Column | exp.Dot]: 248 """ 249 List of star expressions (columns or dots) in this scope. 250 """ 251 self._ensure_collected() 252 return self._stars 253 254 @property 255 def columns(self): 256 """ 257 List of columns in this scope. 258 259 Returns: 260 list[exp.Column]: Column instances in this scope, plus any 261 Columns that reference this scope from correlated subqueries. 
262 """ 263 if self._columns is None: 264 self._ensure_collected() 265 columns = self._raw_columns 266 267 external_columns = [ 268 column 269 for scope in itertools.chain( 270 self.subquery_scopes, 271 self.udtf_scopes, 272 (dts for dts in self.derived_table_scopes if dts.can_be_correlated), 273 ) 274 for column in scope.external_columns 275 ] 276 277 named_selects = set(self.expression.named_selects) 278 279 self._columns = [] 280 for column in columns + external_columns: 281 ancestor = column.find_ancestor( 282 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 283 ) 284 if ( 285 not ancestor 286 or column.table 287 or isinstance(ancestor, exp.Select) 288 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 289 or ( 290 isinstance(ancestor, exp.Order) 291 and ( 292 isinstance(ancestor.parent, exp.Window) 293 or column.name not in named_selects 294 ) 295 ) 296 or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") 297 ): 298 self._columns.append(column) 299 300 return self._columns 301 302 @property 303 def selected_sources(self): 304 """ 305 Mapping of nodes and sources that are actually selected from in this scope. 306 307 That is, all tables in a schema are selectable at any point. But a 308 table only becomes a selected source if it's included in a FROM or JOIN clause. 
309 310 Returns: 311 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 312 """ 313 if self._selected_sources is None: 314 result = {} 315 316 for name, node in self.references: 317 if name in result: 318 raise OptimizeError(f"Alias already used: {name}") 319 if name in self.sources: 320 result[name] = (node, self.sources[name]) 321 322 self._selected_sources = result 323 return self._selected_sources 324 325 @property 326 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 327 if self._references is None: 328 self._references = [] 329 330 for table in self.tables: 331 self._references.append((table.alias_or_name, table)) 332 for expression in itertools.chain(self.derived_tables, self.udtfs): 333 self._references.append( 334 ( 335 expression.alias, 336 expression if expression.args.get("pivots") else expression.unnest(), 337 ) 338 ) 339 340 return self._references 341 342 @property 343 def external_columns(self): 344 """ 345 Columns that appear to reference sources in outer scopes. 346 347 Returns: 348 list[exp.Column]: Column instances that don't reference 349 sources in the current scope. 350 """ 351 if self._external_columns is None: 352 if isinstance(self.expression, exp.SetOperation): 353 left, right = self.union_scopes 354 self._external_columns = left.external_columns + right.external_columns 355 else: 356 self._external_columns = [ 357 c for c in self.columns if c.table not in self.selected_sources 358 ] 359 360 return self._external_columns 361 362 @property 363 def unqualified_columns(self): 364 """ 365 Unqualified columns in the current scope. 
366 367 Returns: 368 list[exp.Column]: Unqualified columns 369 """ 370 return [c for c in self.columns if not c.table] 371 372 @property 373 def join_hints(self): 374 """ 375 Hints that exist in the scope that reference tables 376 377 Returns: 378 list[exp.JoinHint]: Join hints that are referenced within the scope 379 """ 380 if self._join_hints is None: 381 return [] 382 return self._join_hints 383 384 @property 385 def pivots(self): 386 if not self._pivots: 387 self._pivots = [ 388 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 389 ] 390 391 return self._pivots 392 393 def source_columns(self, source_name): 394 """ 395 Get all columns in the current scope for a particular source. 396 397 Args: 398 source_name (str): Name of the source 399 Returns: 400 list[exp.Column]: Column instances that reference `source_name` 401 """ 402 return [column for column in self.columns if column.table == source_name] 403 404 @property 405 def is_subquery(self): 406 """Determine if this scope is a subquery""" 407 return self.scope_type == ScopeType.SUBQUERY 408 409 @property 410 def is_derived_table(self): 411 """Determine if this scope is a derived table""" 412 return self.scope_type == ScopeType.DERIVED_TABLE 413 414 @property 415 def is_union(self): 416 """Determine if this scope is a union""" 417 return self.scope_type == ScopeType.UNION 418 419 @property 420 def is_cte(self): 421 """Determine if this scope is a common table expression""" 422 return self.scope_type == ScopeType.CTE 423 424 @property 425 def is_root(self): 426 """Determine if this is the root scope""" 427 return self.scope_type == ScopeType.ROOT 428 429 @property 430 def is_udtf(self): 431 """Determine if this scope is a UDTF (User Defined Table Function)""" 432 return self.scope_type == ScopeType.UDTF 433 434 @property 435 def is_correlated_subquery(self): 436 """Determine if this scope is a correlated subquery""" 437 return bool(self.can_be_correlated and self.external_columns) 
438 439 def rename_source(self, old_name, new_name): 440 """Rename a source in this scope""" 441 columns = self.sources.pop(old_name or "", []) 442 self.sources[new_name] = columns 443 444 def add_source(self, name, source): 445 """Add a source to this scope""" 446 self.sources[name] = source 447 self.clear_cache() 448 449 def remove_source(self, name): 450 """Remove a source from this scope""" 451 self.sources.pop(name, None) 452 self.clear_cache() 453 454 def __repr__(self): 455 return f"Scope<{self.expression.sql()}>" 456 457 def traverse(self): 458 """ 459 Traverse the scope tree from this node. 460 461 Yields: 462 Scope: scope instances in depth-first-search post-order 463 """ 464 stack = [self] 465 result = [] 466 while stack: 467 scope = stack.pop() 468 result.append(scope) 469 stack.extend( 470 itertools.chain( 471 scope.cte_scopes, 472 scope.union_scopes, 473 scope.table_scopes, 474 scope.subquery_scopes, 475 ) 476 ) 477 478 yield from reversed(result) 479 480 def ref_count(self): 481 """ 482 Count the number of times each scope in this tree is referenced. 483 484 Returns: 485 dict[int, int]: Mapping of Scope instance ID to reference count 486 """ 487 scope_ref_count = defaultdict(lambda: 0) 488 489 for scope in self.traverse(): 490 for _, source in scope.selected_sources.values(): 491 scope_ref_count[id(source)] += 1 492 493 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.SetOperation): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example:
    SELECT * FROM x                     {"x": Table(this="x")}
    SELECT * FROM x AS y                {"y": Table(this="x")}
    SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example:
    SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
  The LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
    can_be_correlated=None,
):
    """
    Set up a scope rooted at `expression`.

    Falsy `sources`, `lateral_sources`, `cte_sources` and `outer_columns` are
    replaced by fresh empty containers. A non-empty `sources` dict is kept as-is
    (not copied) and is updated in place with the lateral and CTE sources.

    Args:
        expression: Root expression of this scope
        sources (dict): Mapping of source name to source
        outer_columns (list): Column list defined by the outer query for this scope's alias
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        lateral_sources (dict): Sources from laterals
        cte_sources (dict): Sources from CTEs
        can_be_correlated: Stored unchanged on the instance
    """
    # Identity of the scope and its position in the scope tree.
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.can_be_correlated = can_be_correlated
    self.outer_columns = outer_columns or []

    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}

    # Laterals are merged before CTEs, so on a name clash the CTE entry wins.
    merged_sources = sources or {}
    merged_sources.update(self.lateral_sources)
    merged_sources.update(self.cte_sources)
    self.sources = merged_sources

    # Child scopes start out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """Mark the scope as not collected and reset every cached attribute to None."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_stars",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The child scope is built around `expression.unnest()` with this scope as its
    parent. It sees this scope's CTE sources plus `cte_sources` (the latter win on
    a name clash), and receives copies of `sources` and `lateral_sources`. It is
    flagged `can_be_correlated` when this scope already is, or when `scope_type`
    is SUBQUERY or UDTF. Extra keyword arguments are passed on to `Scope`.
    """
    child_cte_sources = dict(self.cte_sources)
    child_cte_sources.update(cte_sources or {})

    correlated = self.can_be_correlated or scope_type in (
        ScopeType.SUBQUERY,
        ScopeType.UDTF,
    )

    return Scope(
        expression=expression.unnest(),
        sources=sources.copy() if sources else None,
        parent=self,
        scope_type=scope_type,
        cte_sources=child_cte_sources,
        lateral_sources=lateral_sources.copy() if lateral_sources else None,
        can_be_correlated=correlated,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` and drop this scope's cached state.

    Use this rather than calling `old.replace(new)` directly so that the `Scope`
    does not keep serving stale cached values.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Everything cached so far was computed against the old tree.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Tables collected for this scope (collection runs first if still pending).

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    return self._tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    CTEs collected for this scope (collection runs first if still pending).

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    return self._ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables collected for this scope (collection runs first if still pending).

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    return self._derived_tables
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" collected for this scope
    (collection runs first if still pending).

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    return self._udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries collected for this scope (collection runs first if still pending).

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    Returns:
        list[exp.Select | exp.SetOperation]: subqueries
    """
    self._ensure_collected()
    return self._subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.SetOperation]: subqueries
@property
def stars(self) -> t.List[exp.Column | exp.Dot]:
    """
    Star expressions (columns or dots) collected for this scope
    (collection runs first if still pending).
    """
    self._ensure_collected()
    return self._stars
List of star expressions (columns or dots) in this scope.
@property
def columns(self):
    """
    Columns that belong to this scope; computed once and then cached.

    The candidates are the columns collected in this scope plus the external
    columns of its subquery scopes, UDTF scopes, and those derived-table scopes
    flagged `can_be_correlated`.

    A qualified column is always kept. An unqualified one is classified by its
    nearest ancestor among Select, Qualify, Order, Having, Hint, Table and Star;
    it is dropped when that ancestor is a Qualify, Having or Hint, a Table whose
    `this` is a function, an Order outside a Window whose column name matches a
    named select, or a Star where the column is the "except" argument.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    own_columns = self._raw_columns

    correlatable_children = itertools.chain(
        self.subquery_scopes,
        self.udtf_scopes,
        (child for child in self.derived_table_scopes if child.can_be_correlated),
    )
    external_columns = [
        column for child in correlatable_children for column in child.external_columns
    ]

    named_selects = set(self.expression.named_selects)

    self._columns = []
    for column in own_columns + external_columns:
        ancestor = column.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
        )

        if not ancestor or column.table or isinstance(ancestor, exp.Select):
            keep = True
        elif isinstance(ancestor, exp.Table):
            keep = not isinstance(ancestor.this, exp.Func)
        elif isinstance(ancestor, exp.Order):
            # presumably an ORDER BY name matching a projection is an alias
            # reference rather than a source column — confirm with callers
            keep = isinstance(ancestor.parent, exp.Window) or column.name not in named_selects
        elif isinstance(ancestor, exp.Star):
            keep = column.arg_key != "except"
        else:
            # Qualify, Having and Hint ancestors
            keep = False

        if keep:
            self._columns.append(column)

    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    All sources known to the scope are selectable, but a source is only *selected*
    when one of this scope's references names it. References whose name is not a
    known source are skipped. The result is cached.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if a name that was already selected is referenced again
    """
    if self._selected_sources is None:
        selected = {}
        known_sources = self.sources

        for name, node in self.references:
            if name in selected:
                raise OptimizeError(f"Alias already used: {name}")
            if name in known_sources:
                selected[name] = (node, known_sources[name])

        self._selected_sources = selected

    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    (name, node) pairs for everything this scope selects from; cached after first use.

    Tables come first, paired with their `alias_or_name`. Derived tables and UDTFs
    follow, paired with their `alias`; the node is the expression itself when it
    carries pivots, otherwise its unnested form.
    """
    if self._references is None:
        refs = self._references = []

        refs.extend((table.alias_or_name, table) for table in self.tables)

        for expression in itertools.chain(self.derived_tables, self.udtfs):
            alias = expression.alias
            node = expression if expression.args.get("pivots") else expression.unnest()
            refs.append((alias, node))

    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes; cached after first use.

    For a set operation this is the left child's external columns followed by the
    right child's. Otherwise it is every column of this scope whose table is not
    one of the selected sources.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.SetOperation):
        left, right = self.union_scopes
        found = left.external_columns + right.external_columns
    else:
        found = [
            column for column in self.columns if column.table not in self.selected_sources
        ]

    self._external_columns = found
    return found
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    unqualified = []
    for column in self.columns:
        if not column.table:
            unqualified.append(column)
    return unqualified
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Collection is triggered first, as for the other collected lists, so the
    result does not depend on whether another property was accessed earlier.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Defensive: keep returning an empty list if collection left the slot unset.
    if self._join_hints is None:
        return []
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope whose table equals `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances that reference `source_name`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """True when this scope's type is `ScopeType.SUBQUERY`."""
    return self.scope_type == ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """True when this scope's type is `ScopeType.DERIVED_TABLE`."""
    return self.scope_type == ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """True when this scope's type is `ScopeType.UNION`."""
    return self.scope_type == ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """True when this scope's type is `ScopeType.CTE` (common table expression)."""
    return self.scope_type == ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """True when this scope's type is `ScopeType.ROOT`."""
    return self.scope_type == ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """True when this scope's type is `ScopeType.UDTF` (User Defined Table Function)."""
    return self.scope_type == ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Move the source registered under `old_name` to `new_name`.

    A falsy `old_name` is looked up as the empty string. When nothing is
    registered under the old name, `new_name` is mapped to an empty list.
    """
    lookup_key = old_name or ""
    self.sources[new_name] = self.sources.pop(lookup_key, [])
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any existing entry) and clear the cache."""
    self.sources[name] = source
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source registered under `name`, if any, and clear the cache."""
    # A default is supplied so an unknown name is ignored instead of raising.
    self.sources.pop(name, None)
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Scopes are first gathered parent-before-children using an explicit stack
    (children are pushed in the order CTE, union, table, subquery scopes), and
    that list is then emitted in reverse, so every scope is yielded only after
    all scopes nested inside it; this scope comes last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    visited = []

    while pending:
        current = pending.pop()
        visited.append(current)
        pending.extend(current.cte_scopes)
        pending.extend(current.union_scopes)
        pending.extend(current.table_scopes)
        pending.extend(current.subquery_scopes)

    yield from reversed(visited)
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source is selected anywhere in this scope tree.

    Every scope yielded by `traverse` contributes one count per entry of its
    `selected_sources`, keyed by the `id()` of the selected source object.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(lambda: 0)

    selected = (
        source for scope in self.traverse() for _, source in scope.selected_sources.values()
    )
    for source in selected:
        counts[id(source)] += 1

    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    Scopes carry information the expression tree alone does not, such as the
    source names visible inside a subquery, which is useful when optimizing
    queries. The scopes are returned as a list rather than a generator, because
    a generator could expose scopes whose properties are still incomplete.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        one of the `TRAVERSABLES` types
    """
    if not isinstance(expression, TRAVERSABLES):
        return []

    root_scope = Scope(expression)
    return list(_traverse_scope(root_scope))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, taken as the last scope produced by `traverse_scope`.
        The return annotation allows `None`; presumably that is what `seq_get`
        yields when no scopes were produced.
    """
    scopes = traverse_scope(expression)
    return seq_get(scopes, -1)
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope is still yielded, but the walk does not
    descend into it. `expression` itself is always descended into.

    Args:
        expression (exp.Expression): root of the walk
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that receives a node and returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: each visited node
    """
    # We'll use this variable to pass state into the walk's prune callback.
    # Whenever we set it to True, the subtree below the node that was just
    # yielded is excluded from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should
                # visit them. The caller's `prune` is forwarded so it applies there too.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression):
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that receives a node and returns True if the generator should stop traversing this branch of the tree.
Yields:
exp.Expression: each visited node
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalise the requested types once instead of once per visited node.
    wanted_types = tuple(ensure_collection(expression_types))

    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted_types):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): root of the scope to search
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if no node in the scope
            matches.
    """
    matches = find_all_in_scope(expression, expression_types, bfs=bfs)
    return next(matches, None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression):
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.