sqlglot.optimizer.scope
from __future__ import annotations

import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import ensure_collection, find_new_name, seq_get

logger = logging.getLogger("sqlglot")

# Root expression types that `traverse_scope` knows how to split into scopes.
TRAVERSABLES = (exp.Query, exp.DDL, exp.DML)


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        cte_sources (dict[str, Scope]): Sources from CTEs
        outer_columns (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_columns`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_columns=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
        cte_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources or {}
        self.cte_sources = cte_sources or {}
        # Lateral and CTE sources are also visible through `sources`; on a name clash the
        # CTE source wins because it is merged last.
        self.sources.update(self.lateral_sources)
        self.sources.update(self.cte_sources)
        self.outer_columns = outer_columns or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset all lazily computed properties so they are recomputed from `expression`."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(
        self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
    ):
        """
        Branch from the current scope to a new, inner scope.

        The child inherits this scope's CTE sources (extended/overridden by `cte_sources`);
        `sources` and `lateral_sources` are shallow-copied so the child cannot mutate them.
        Extra keyword arguments are forwarded to the `Scope` constructor.
        """
        return Scope(
            expression=expression.unnest(),
            sources=sources.copy() if sources else None,
            parent=self,
            scope_type=scope_type,
            cte_sources={**self.cte_sources, **(cte_sources or {})},
            lateral_sources=lateral_sources.copy() if lateral_sources else None,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope depth-first (not entering child scopes) and bucket nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node in self.walk(bfs=False):
            if node is self.expression:
                continue

            if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif _is_derived_table(node) and isinstance(
                node.parent, (exp.From, exp.Join, exp.Subquery)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.UNWRAPPED_QUERIES):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True, prune=None):
        """
        Walk the nodes of this scope without descending into child scopes.

        Args:
            bfs (bool): True for breadth-first traversal, False for depth-first.
            prune ((node) -> bool): optional callable; when it returns True for a node,
                traversal does not continue below that node.

        Returns:
            A generator of nodes, as produced by `walk_in_scope`.
        """
        return walk_in_scope(self.expression, bfs=bfs, prune=prune)

    def find(self, *expression_types, bfs=True):
        """Return the first node in this scope matching any of `expression_types`, or None."""
        return find_in_scope(self.expression, expression_types, bfs=bfs)

    def find_all(self, *expression_types, bfs=True):
        """Yield every node in this scope matching any of `expression_types`."""
        return find_all_in_scope(self.expression, expression_types, bfs=bfs)

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Select | exp.Union]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
                )
                # Unqualified columns are dropped when their nearest such ancestor is a
                # QUALIFY, HAVING, hint, star, table-valued function, or a non-window
                # ORDER BY that names one of this scope's projections.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope use the same alias.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Derived tables and UDTFs are unnested unless they carry pivots, in which case the
        wrapping node (which holds the pivots) is kept.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            if isinstance(self.expression, exp.Union):
                left, right = self.union_scopes
                self._external_columns = left.external_columns + right.external_columns
            else:
                self._external_columns = [
                    c for c in self.columns if c.table not in self.selected_sources
                ]

        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like every other node-derived property; previously an
        # uncollected scope silently reported no join hints.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the nodes referenced in this scope.

        Returns:
            list: the entries of each referenced node's `pivots` arg, in reference order.
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """
        Determine if this scope is a correlated subquery, i.e. a subquery (or a scope whose
        parent is a LATERAL) that references columns from outer scopes.
        """
        return bool(
            (self.is_subquery or (self.parent and isinstance(self.parent.expression, exp.Lateral)))
            and self.external_columns
        )

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A missing (or falsy, looked up as "") `old_name` registers `new_name` with an
        empty list as its source, matching the historical behavior.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Keep cached views (e.g. `selected_sources`) consistent, like add/remove_source.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        stack = [self]
        result = []
        while stack:
            scope = stack.pop()
            result.append(scope)
            stack.extend(
                itertools.chain(
                    scope.cte_scopes,
                    scope.union_scopes,
                    scope.table_scopes,
                    scope.subquery_scopes,
                )
            )

        # Reversing this pre-order yields children before their parents.
        yield from reversed(result)

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances (empty if `expression` is not one of `TRAVERSABLES`)
    """
    if isinstance(expression, TRAVERSABLES):
        return list(_traverse_scope(Scope(expression)))
    return []


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope (the last scope yielded by the traversal), or None if there is none.
    """
    return seq_get(traverse_scope(expression), -1)


def _traverse_scope(scope):
    """Yield all scopes under `scope`, children first, dispatching on the expression type."""
    expression = scope.expression

    if isinstance(expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(expression, exp.Union):
        yield from _traverse_ctes(scope)
        yield from _traverse_union(scope)
        return
    elif isinstance(expression, exp.Subquery):
        if scope.is_root:
            yield from _traverse_select(scope)
        else:
            yield from _traverse_subqueries(scope)
    elif isinstance(expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(expression, exp.UDTF):
        yield from _traverse_udtfs(scope)
    elif isinstance(expression, exp.DDL):
        if isinstance(expression.expression, exp.Query):
            yield from _traverse_ctes(scope)
            yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources))
        return
    elif isinstance(expression, exp.DML):
        yield from _traverse_ctes(scope)
        for query in find_all_in_scope(expression, exp.Query):
            # This check ensures we don't yield the CTE queries twice
            if not isinstance(query.parent, exp.CTE):
                yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources))
        return
    else:
        logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression))
        return

    yield scope


def _traverse_select(scope):
    """Yield the CTE, table and subquery scopes of a SELECT-like scope."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Iteratively traverse (possibly nested) set operations, wiring up `union_scopes`."""
    prev_scope = None
    union_scope_stack = [scope]
    expression_stack = [scope.expression.right, scope.expression.left]

    while expression_stack:
        expression = expression_stack.pop()
        union_scope = union_scope_stack[-1]

        new_scope = union_scope.branch(
            expression,
            outer_columns=union_scope.outer_columns,
            scope_type=ScopeType.UNION,
        )

        if isinstance(expression, exp.Union):
            yield from _traverse_ctes(new_scope)

            union_scope_stack.append(new_scope)
            expression_stack.extend([expression.right, expression.left])
            continue

        for scope in _traverse_scope(new_scope):
            yield scope

        # `scope` is now the last scope yielded for this operand.
        if prev_scope:
            union_scope_stack.pop()
            union_scope.union_scopes = [prev_scope, scope]
            prev_scope = union_scope

            yield union_scope
        else:
            prev_scope = scope


def _traverse_ctes(scope):
    """Yield the scopes of each CTE in `scope` and register the CTEs as sources."""
    sources = {}

    for cte in scope.ctes:
        cte_name = cte.alias

        # if the scope is a recursive cte, it must be in the form of base_case UNION recursive.
        # thus the recursive scope is the first section of the union.
        with_ = scope.expression.args.get("with")
        if with_ and with_.recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                cte_sources=sources,
                outer_columns=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        # append the final child_scope yielded
        if child_scope:
            sources[cte_name] = child_scope
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)
    scope.cte_sources.update(sources)


def _is_derived_table(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table",
    as it doesn't introduce a new scope. If an alias is present, it shadows all names
    under the Subquery, so that's one exception to this rule.
    """
    return isinstance(expression, exp.Subquery) and bool(
        expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES)
    )


def _traverse_tables(scope):
    """Yield derived-table and UDTF scopes of `scope` and register its FROM/JOIN sources."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # NOTE: `expressions` is extended while being iterated so nested joins are visited too.
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            # Make sure to not include the joins twice
            if expression is not scope.expression:
                expressions.extend(join.this for join in expression.args.get("joins") or [])

            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_derived_table(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
            expressions.extend(join.this for join in expression.args.get("joins") or [])
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_columns=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            sources[expression.alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Yield the scopes of each subquery in `scope`, recording the top one of each."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def _traverse_udtfs(scope):
    """Yield scopes for derived tables nested inside an UNNEST or LATERAL UDTF."""
    if isinstance(scope.expression, exp.Unnest):
        expressions = scope.expression.expressions
    elif isinstance(scope.expression, exp.Lateral):
        expressions = [scope.expression.this]
    else:
        expressions = []

    sources = {}
    for expression in expressions:
        if _is_derived_table(expression):
            top = None
            for child_scope in _traverse_scope(
                scope.branch(
                    expression,
                    scope_type=ScopeType.SUBQUERY,
                    outer_columns=expression.alias_column_names,
                )
            ):
                yield child_scope
                top = child_scope
                sources[expression.alias] = child_scope

            scope.subquery_scopes.append(top)

    scope.sources.update(sources)


def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if
            the generator should stop traversing this branch of the tree.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should visit them
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)


def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalize once instead of per visited node.
    types = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, types):
            yield node


def find_in_scope(expression, expression_types, bfs=True):
    """
    Returns the first node in this scope which matches at least one of the specified types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression):
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
    """
    return next(find_all_in_scope(expression, expression_types, bfs=bfs), None)
class ScopeType(Enum):
    """Kind of a scope, relative to the parent scope that created it."""

    # Explicit values equal what ``auto()`` assigns: 1-based, in declaration order.
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
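An enumeration of scope kinds (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF), each describing a scope relative to its parent scope.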
Inherited Members
- enum.Enum
- name
- value
28class Scope: 29 """ 30 Selection scope. 31 32 Attributes: 33 expression (exp.Select|exp.Union): Root expression of this scope 34 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 35 a Table expression or another Scope instance. For example: 36 SELECT * FROM x {"x": Table(this="x")} 37 SELECT * FROM x AS y {"y": Table(this="x")} 38 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 39 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 40 For example: 41 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 42 The LATERAL VIEW EXPLODE gets x as a source. 43 cte_sources (dict[str, Scope]): Sources from CTES 44 outer_columns (list[str]): If this is a derived table or CTE, and the outer query 45 defines a column list for the alias of this scope, this is that list of columns. 46 For example: 47 SELECT * FROM (SELECT ...) AS y(col1, col2) 48 The inner query would have `["col1", "col2"]` for its `outer_columns` 49 parent (Scope): Parent scope 50 scope_type (ScopeType): Type of this scope, relative to it's parent 51 subquery_scopes (list[Scope]): List of all child scopes for subqueries 52 cte_scopes (list[Scope]): List of all child scopes for CTEs 53 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 54 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 55 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 56 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 57 a list of the left and right child scopes. 
58 """ 59 60 def __init__( 61 self, 62 expression, 63 sources=None, 64 outer_columns=None, 65 parent=None, 66 scope_type=ScopeType.ROOT, 67 lateral_sources=None, 68 cte_sources=None, 69 ): 70 self.expression = expression 71 self.sources = sources or {} 72 self.lateral_sources = lateral_sources or {} 73 self.cte_sources = cte_sources or {} 74 self.sources.update(self.lateral_sources) 75 self.sources.update(self.cte_sources) 76 self.outer_columns = outer_columns or [] 77 self.parent = parent 78 self.scope_type = scope_type 79 self.subquery_scopes = [] 80 self.derived_table_scopes = [] 81 self.table_scopes = [] 82 self.cte_scopes = [] 83 self.union_scopes = [] 84 self.udtf_scopes = [] 85 self.clear_cache() 86 87 def clear_cache(self): 88 self._collected = False 89 self._raw_columns = None 90 self._derived_tables = None 91 self._udtfs = None 92 self._tables = None 93 self._ctes = None 94 self._subqueries = None 95 self._selected_sources = None 96 self._columns = None 97 self._external_columns = None 98 self._join_hints = None 99 self._pivots = None 100 self._references = None 101 102 def branch( 103 self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs 104 ): 105 """Branch from the current scope to a new, inner scope""" 106 return Scope( 107 expression=expression.unnest(), 108 sources=sources.copy() if sources else None, 109 parent=self, 110 scope_type=scope_type, 111 cte_sources={**self.cte_sources, **(cte_sources or {})}, 112 lateral_sources=lateral_sources.copy() if lateral_sources else None, 113 **kwargs, 114 ) 115 116 def _collect(self): 117 self._tables = [] 118 self._ctes = [] 119 self._subqueries = [] 120 self._derived_tables = [] 121 self._udtfs = [] 122 self._raw_columns = [] 123 self._join_hints = [] 124 125 for node in self.walk(bfs=False): 126 if node is self.expression: 127 continue 128 129 if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 130 self._raw_columns.append(node) 131 elif 
isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 132 self._tables.append(node) 133 elif isinstance(node, exp.JoinHint): 134 self._join_hints.append(node) 135 elif isinstance(node, exp.UDTF): 136 self._udtfs.append(node) 137 elif isinstance(node, exp.CTE): 138 self._ctes.append(node) 139 elif _is_derived_table(node) and isinstance( 140 node.parent, (exp.From, exp.Join, exp.Subquery) 141 ): 142 self._derived_tables.append(node) 143 elif isinstance(node, exp.UNWRAPPED_QUERIES): 144 self._subqueries.append(node) 145 146 self._collected = True 147 148 def _ensure_collected(self): 149 if not self._collected: 150 self._collect() 151 152 def walk(self, bfs=True, prune=None): 153 return walk_in_scope(self.expression, bfs=bfs, prune=None) 154 155 def find(self, *expression_types, bfs=True): 156 return find_in_scope(self.expression, expression_types, bfs=bfs) 157 158 def find_all(self, *expression_types, bfs=True): 159 return find_all_in_scope(self.expression, expression_types, bfs=bfs) 160 161 def replace(self, old, new): 162 """ 163 Replace `old` with `new`. 164 165 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 166 167 Args: 168 old (exp.Expression): old node 169 new (exp.Expression): new node 170 """ 171 old.replace(new) 172 self.clear_cache() 173 174 @property 175 def tables(self): 176 """ 177 List of tables in this scope. 178 179 Returns: 180 list[exp.Table]: tables 181 """ 182 self._ensure_collected() 183 return self._tables 184 185 @property 186 def ctes(self): 187 """ 188 List of CTEs in this scope. 189 190 Returns: 191 list[exp.CTE]: ctes 192 """ 193 self._ensure_collected() 194 return self._ctes 195 196 @property 197 def derived_tables(self): 198 """ 199 List of derived tables in this scope. 200 201 For example: 202 SELECT * FROM (SELECT ...) 
<- that's a derived table 203 204 Returns: 205 list[exp.Subquery]: derived tables 206 """ 207 self._ensure_collected() 208 return self._derived_tables 209 210 @property 211 def udtfs(self): 212 """ 213 List of "User Defined Tabular Functions" in this scope. 214 215 Returns: 216 list[exp.UDTF]: UDTFs 217 """ 218 self._ensure_collected() 219 return self._udtfs 220 221 @property 222 def subqueries(self): 223 """ 224 List of subqueries in this scope. 225 226 For example: 227 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 228 229 Returns: 230 list[exp.Select | exp.Union]: subqueries 231 """ 232 self._ensure_collected() 233 return self._subqueries 234 235 @property 236 def columns(self): 237 """ 238 List of columns in this scope. 239 240 Returns: 241 list[exp.Column]: Column instances in this scope, plus any 242 Columns that reference this scope from correlated subqueries. 243 """ 244 if self._columns is None: 245 self._ensure_collected() 246 columns = self._raw_columns 247 248 external_columns = [ 249 column 250 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 251 for column in scope.external_columns 252 ] 253 254 named_selects = set(self.expression.named_selects) 255 256 self._columns = [] 257 for column in columns + external_columns: 258 ancestor = column.find_ancestor( 259 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star 260 ) 261 if ( 262 not ancestor 263 or column.table 264 or isinstance(ancestor, exp.Select) 265 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 266 or ( 267 isinstance(ancestor, exp.Order) 268 and ( 269 isinstance(ancestor.parent, exp.Window) 270 or column.name not in named_selects 271 ) 272 ) 273 ): 274 self._columns.append(column) 275 276 return self._columns 277 278 @property 279 def selected_sources(self): 280 """ 281 Mapping of nodes and sources that are actually selected from in this scope. 
282 283 That is, all tables in a schema are selectable at any point. But a 284 table only becomes a selected source if it's included in a FROM or JOIN clause. 285 286 Returns: 287 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 288 """ 289 if self._selected_sources is None: 290 result = {} 291 292 for name, node in self.references: 293 if name in result: 294 raise OptimizeError(f"Alias already used: {name}") 295 if name in self.sources: 296 result[name] = (node, self.sources[name]) 297 298 self._selected_sources = result 299 return self._selected_sources 300 301 @property 302 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 303 if self._references is None: 304 self._references = [] 305 306 for table in self.tables: 307 self._references.append((table.alias_or_name, table)) 308 for expression in itertools.chain(self.derived_tables, self.udtfs): 309 self._references.append( 310 ( 311 expression.alias, 312 expression if expression.args.get("pivots") else expression.unnest(), 313 ) 314 ) 315 316 return self._references 317 318 @property 319 def external_columns(self): 320 """ 321 Columns that appear to reference sources in outer scopes. 322 323 Returns: 324 list[exp.Column]: Column instances that don't reference 325 sources in the current scope. 326 """ 327 if self._external_columns is None: 328 if isinstance(self.expression, exp.Union): 329 left, right = self.union_scopes 330 self._external_columns = left.external_columns + right.external_columns 331 else: 332 self._external_columns = [ 333 c for c in self.columns if c.table not in self.selected_sources 334 ] 335 336 return self._external_columns 337 338 @property 339 def unqualified_columns(self): 340 """ 341 Unqualified columns in the current scope. 
342 343 Returns: 344 list[exp.Column]: Unqualified columns 345 """ 346 return [c for c in self.columns if not c.table] 347 348 @property 349 def join_hints(self): 350 """ 351 Hints that exist in the scope that reference tables 352 353 Returns: 354 list[exp.JoinHint]: Join hints that are referenced within the scope 355 """ 356 if self._join_hints is None: 357 return [] 358 return self._join_hints 359 360 @property 361 def pivots(self): 362 if not self._pivots: 363 self._pivots = [ 364 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 365 ] 366 367 return self._pivots 368 369 def source_columns(self, source_name): 370 """ 371 Get all columns in the current scope for a particular source. 372 373 Args: 374 source_name (str): Name of the source 375 Returns: 376 list[exp.Column]: Column instances that reference `source_name` 377 """ 378 return [column for column in self.columns if column.table == source_name] 379 380 @property 381 def is_subquery(self): 382 """Determine if this scope is a subquery""" 383 return self.scope_type == ScopeType.SUBQUERY 384 385 @property 386 def is_derived_table(self): 387 """Determine if this scope is a derived table""" 388 return self.scope_type == ScopeType.DERIVED_TABLE 389 390 @property 391 def is_union(self): 392 """Determine if this scope is a union""" 393 return self.scope_type == ScopeType.UNION 394 395 @property 396 def is_cte(self): 397 """Determine if this scope is a common table expression""" 398 return self.scope_type == ScopeType.CTE 399 400 @property 401 def is_root(self): 402 """Determine if this is the root scope""" 403 return self.scope_type == ScopeType.ROOT 404 405 @property 406 def is_udtf(self): 407 """Determine if this scope is a UDTF (User Defined Table Function)""" 408 return self.scope_type == ScopeType.UDTF 409 410 @property 411 def is_correlated_subquery(self): 412 """Determine if this scope is a correlated subquery""" 413 return bool( 414 (self.is_subquery or (self.parent and 
isinstance(self.parent.expression, exp.Lateral))) 415 and self.external_columns 416 ) 417 418 def rename_source(self, old_name, new_name): 419 """Rename a source in this scope""" 420 columns = self.sources.pop(old_name or "", []) 421 self.sources[new_name] = columns 422 423 def add_source(self, name, source): 424 """Add a source to this scope""" 425 self.sources[name] = source 426 self.clear_cache() 427 428 def remove_source(self, name): 429 """Remove a source from this scope""" 430 self.sources.pop(name, None) 431 self.clear_cache() 432 433 def __repr__(self): 434 return f"Scope<{self.expression.sql()}>" 435 436 def traverse(self): 437 """ 438 Traverse the scope tree from this node. 439 440 Yields: 441 Scope: scope instances in depth-first-search post-order 442 """ 443 stack = [self] 444 result = [] 445 while stack: 446 scope = stack.pop() 447 result.append(scope) 448 stack.extend( 449 itertools.chain( 450 scope.cte_scopes, 451 scope.union_scopes, 452 scope.table_scopes, 453 scope.subquery_scopes, 454 ) 455 ) 456 457 yield from reversed(result) 458 459 def ref_count(self): 460 """ 461 Count the number of times each scope in this tree is referenced. 462 463 Returns: 464 dict[int, int]: Mapping of Scope instance ID to reference count 465 """ 466 scope_ref_count = defaultdict(lambda: 0) 467 468 for scope in self.traverse(): 469 for _, source in scope.selected_sources.values(): 470 scope_ref_count[id(source)] += 1 471 472 return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example, `SELECT * FROM x` gives `{"x": Table(this="x")}`, `SELECT * FROM x AS y` gives `{"y": Table(this="x")}`, and `SELECT * FROM (SELECT ...) AS y` gives `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- cte_sources (dict[str, Scope]): Sources from CTEs
- outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_columns`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_columns=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
    cte_sources=None,
):
    """
    Create a scope rooted at `expression`.

    `lateral_sources` and then `cte_sources` are merged into `sources`, so a CTE
    entry replaces any same-named entry coming from the other two. A non-empty
    `sources` dict passed in is updated in place rather than copied. All
    child-scope lists start out empty and the lazy caches are reset.
    """
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_columns = outer_columns or []

    self.lateral_sources = lateral_sources or {}
    self.cte_sources = cte_sources or {}
    merged = sources or {}
    merged.update(self.lateral_sources)
    merged.update(self.cte_sources)
    self.sources = merged

    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """
    Reset every lazily computed attribute of this scope.

    Marks the scope as not collected and sets each cached collection back to
    None, so the next property access recomputes it from the expression tree.
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(
    self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
):
    """
    Branch from the current scope to a new, inner scope.

    The child scope wraps `expression.unnest()` and has this scope as its parent.
    It inherits this scope's CTE sources, with entries from `cte_sources` taking
    precedence. It receives shallow copies of `sources` and `lateral_sources`
    when they are non-empty. Any extra keyword arguments are forwarded to `Scope`.
    """
    inherited_ctes = dict(self.cte_sources)
    inherited_ctes.update(cte_sources or {})
    own_sources = sources.copy() if sources else None
    own_laterals = lateral_sources.copy() if lateral_sources else None
    return Scope(
        expression=expression.unnest(),
        sources=own_sources,
        parent=self,
        scope_type=scope_type,
        cte_sources=inherited_ctes,
        lateral_sources=own_laterals,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and keep this scope consistent.

    Prefer this over calling `exp.Expression.replace` directly: every cached
    collection of the scope describes the tree before the swap, so the caches
    are dropped right after the replacement.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node that takes its place
    """
    # The tree edit must happen first; clearing the cache afterwards ensures
    # the next property access sees the new node.
    old.replace(new)
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
@property
def tables(self):
    """
    Table expressions found directly in this scope.

    The scope's nodes are collected on first access.

    Returns:
        list[exp.Table]: tables
    """
    self._ensure_collected()
    found_tables = self._tables
    return found_tables
List of tables in this scope.
Returns:
list[exp.Table]: tables
@property
def ctes(self):
    """
    Common table expressions defined directly in this scope.

    The scope's nodes are collected on first access.

    Returns:
        list[exp.CTE]: ctes
    """
    self._ensure_collected()
    found_ctes = self._ctes
    return found_ctes
List of CTEs in this scope.
Returns:
list[exp.CTE]: ctes
@property
def derived_tables(self):
    """
    Derived tables that appear directly in this scope.

    For example:
        SELECT * FROM (SELECT ...) <- that's a derived table

    The scope's nodes are collected on first access.

    Returns:
        list[exp.Subquery]: derived tables
    """
    self._ensure_collected()
    found_derived = self._derived_tables
    return found_derived
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
@property
def udtfs(self):
    """
    "User Defined Tabular Functions" that appear directly in this scope.

    The scope's nodes are collected on first access.

    Returns:
        list[exp.UDTF]: UDTFs
    """
    self._ensure_collected()
    found_udtfs = self._udtfs
    return found_udtfs
List of "User Defined Tabular Functions" in this scope.
Returns:
list[exp.UDTF]: UDTFs
@property
def subqueries(self):
    """
    Subqueries that appear directly in this scope.

    For example:
        SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

    The scope's nodes are collected on first access.

    Returns:
        list[exp.Select | exp.Union]: subqueries
    """
    self._ensure_collected()
    found_subqueries = self._subqueries
    return found_subqueries
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Select | exp.Union]: subqueries
@property
def columns(self):
    """
    List of columns in this scope.

    Starts from the columns collected directly in this scope, followed by the
    external columns of the subquery and UDTF child scopes. A column is kept
    when any of the following holds:
    - it is qualified with a table;
    - it has no Select/Qualify/Order/Having/Hint/Table/Star ancestor;
    - that nearest ancestor is a Select;
    - that ancestor is a Table whose `this` is not a Func;
    - that ancestor is an ORDER BY that belongs to a Window, or the column's
      name is not one of this scope's `named_selects`.

    The result is computed once and cached.

    Returns:
        list[exp.Column]: Column instances in this scope, plus any
            Columns that reference this scope from correlated subqueries.
    """
    if self._columns is not None:
        return self._columns

    self._ensure_collected()
    local_columns = self._raw_columns
    correlated = [
        column
        for child in itertools.chain(self.subquery_scopes, self.udtf_scopes)
        for column in child.external_columns
    ]
    named_selects = set(self.expression.named_selects)

    def belongs_here(column):
        ancestor = column.find_ancestor(
            exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table, exp.Star
        )
        if not ancestor or column.table:
            return True
        if isinstance(ancestor, exp.Select):
            return True
        if isinstance(ancestor, exp.Table):
            return not isinstance(ancestor.this, exp.Func)
        if isinstance(ancestor, exp.Order):
            # Inside a window an ORDER BY cannot refer to projection aliases.
            return isinstance(ancestor.parent, exp.Window) or column.name not in named_selects
        return False

    self._columns = [column for column in local_columns + correlated if belongs_here(column)]
    return self._columns
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
@property
def selected_sources(self):
    """
    Mapping of nodes and sources that are actually selected from in this scope.

    Every table in a schema is selectable at any point, but a table only becomes
    a selected source once it appears in a FROM or JOIN clause. Only reference
    names that are also present in `sources` are included. The mapping is
    computed once and cached.

    Returns:
        dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

    Raises:
        OptimizeError: if the same reference name appears twice.
    """
    if self._selected_sources is not None:
        return self._selected_sources

    selected = {}
    for ref_name, ref_node in self.references:
        if ref_name in selected:
            raise OptimizeError(f"Alias already used: {ref_name}")
        if ref_name in self.sources:
            selected[ref_name] = (ref_node, self.sources[ref_name])

    self._selected_sources = selected
    return self._selected_sources
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
@property
def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
    """
    `(name, node)` pairs for everything this scope reads from, in order.

    Tables come first, keyed by alias or name. Derived tables and then UDTFs
    follow, keyed by alias. A derived table or UDTF node is stored unnested,
    unless it has pivots, in which case it is stored as-is. The list is computed
    once and cached.
    """
    if self._references is None:
        self._references = refs = []
        refs.extend((table.alias_or_name, table) for table in self.tables)
        for node in itertools.chain(self.derived_tables, self.udtfs):
            keep_as_is = bool(node.args.get("pivots"))
            refs.append((node.alias, node if keep_as_is else node.unnest()))
    return self._references
@property
def external_columns(self):
    """
    Columns that appear to reference sources in outer scopes.

    For a Union scope, this is the left branch's external columns followed by
    the right branch's. Otherwise, it is every column in `columns` whose table
    is not a key of `selected_sources`. The result is computed once and cached.

    Returns:
        list[exp.Column]: Column instances that don't reference
            sources in the current scope.
    """
    if self._external_columns is not None:
        return self._external_columns

    if isinstance(self.expression, exp.Union):
        left, right = self.union_scopes
        found = left.external_columns + right.external_columns
    else:
        found = [
            column for column in self.columns if column.table not in self.selected_sources
        ]
    self._external_columns = found
    return self._external_columns
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
@property
def unqualified_columns(self):
    """
    Columns of this scope that carry no table qualifier.

    Returns:
        list[exp.Column]: Unqualified columns
    """
    return list(filter(lambda column: not column.table, self.columns))
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
@property
def join_hints(self):
    """
    Hints that exist in the scope that reference tables.

    Like `tables`, `ctes` and the other collected properties, the scope is
    collected first. The result therefore no longer depends on whether some
    other property happened to trigger collection earlier. Previously, an
    uncollected scope always reported no hints.

    Returns:
        list[exp.JoinHint]: Join hints that are referenced within the scope
    """
    self._ensure_collected()
    # Defensive: still return a list if collection left the attribute unset.
    if self._join_hints is None:
        return []
    return self._join_hints
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Columns of this scope that are qualified with a given source name.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances whose `table` equals `source_name`
    """
    return list(filter(lambda column: column.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
@property
def is_subquery(self):
    """True when this scope's type is ScopeType.SUBQUERY."""
    return self.scope_type is ScopeType.SUBQUERY
Determine if this scope is a subquery
@property
def is_derived_table(self):
    """True when this scope's type is ScopeType.DERIVED_TABLE."""
    return self.scope_type is ScopeType.DERIVED_TABLE
Determine if this scope is a derived table
@property
def is_union(self):
    """True when this scope's type is ScopeType.UNION."""
    return self.scope_type is ScopeType.UNION
Determine if this scope is a union
@property
def is_cte(self):
    """True when this scope's type is ScopeType.CTE (a common table expression)."""
    return self.scope_type is ScopeType.CTE
Determine if this scope is a common table expression
@property
def is_root(self):
    """True when this scope's type is ScopeType.ROOT."""
    return self.scope_type is ScopeType.ROOT
Determine if this is the root scope
@property
def is_udtf(self):
    """True when this scope's type is ScopeType.UDTF (User Defined Table Function)."""
    return self.scope_type is ScopeType.UDTF
Determine if this scope is a UDTF (User Defined Table Function)
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The entry registered under `old_name` is re-registered under `new_name`.
    A falsy `old_name` is looked up as "", and if nothing is registered under
    it, an empty list is stored instead. Cached properties are then cleared,
    as `add_source` and `remove_source` already do, so `selected_sources` and
    the properties built on it stop serving the old name.

    Args:
        old_name (str): current name of the source
        new_name (str): name to register the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any previous entry) and drop cached state."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Forget the source called `name` (a missing name is ignored) and drop cached state."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Traverse the scope tree from this node.

    Every scope is yielded after all of its descendants. Sibling subtrees are
    yielded in cte, union, table, subquery order.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    pending = [self]
    preorder = []
    while pending:
        current = pending.pop()
        preorder.append(current)
        pending.extend(current.cte_scopes)
        pending.extend(current.union_scopes)
        pending.extend(current.table_scopes)
        pending.extend(current.subquery_scopes)

    # Reversing the stack-driven pre-order puts children before their parents.
    for scope in preorder[::-1]:
        yield scope
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how many times each source object is selected across this scope tree.

    Every `(node, source)` pair in each scope's `selected_sources` adds one to
    the count for `id(source)`. Ids that were never seen read as 0.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    counts = defaultdict(int)
    selected = (
        source
        for scope in self.traverse()
        for _node, source in scope.selected_sources.values()
    )
    for source in selected:
        counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression: Expression to traverse

    Returns:
        A list of the created scope instances; empty when `expression` is not
        one of the traversable node types.
    """
    if not isinstance(expression, TRAVERSABLES):
        return []
    return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression: Expression to traverse
Returns:
A list of the created scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression: Expression to build the scope tree for.

    Returns:
        The root scope, i.e. the last scope produced by `traverse_scope`,
        or None when no scopes were produced.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression: Expression to build the scope tree for.
Returns:
The root scope
def walk_in_scope(expression, bfs=True, prune=None):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the node to start walking from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.
        prune ((node) -> bool): callable that returns True if the generator should
            stop traversing below `node`. It is also applied while walking the joins,
            laterals and pivots attached to a child-scope Subquery/UDTF node.

    Yields:
        exp.Expression: the visited nodes
    """
    # We'll use this variable to pass state into the walk's prune callback.
    # Whenever we set it to True right after yielding a node, that node's
    # subtree is excluded from traversal.
    crossed_scope_boundary = False

    for node in expression.walk(
        bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n))
    ):
        crossed_scope_boundary = False

        yield node

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
                and (_is_derived_table(node) or isinstance(node, exp.UDTF))
            )
            or isinstance(node, exp.UNWRAPPED_QUERIES)
        ):
            crossed_scope_boundary = True

            if isinstance(node, (exp.Subquery, exp.UDTF)):
                # The following args are not actually in the inner scope, so we should
                # visit them, honouring the caller's prune callback there as well.
                for key in ("joins", "laterals", "pivots"):
                    for arg in node.args.get(key) or []:
                        yield from walk_in_scope(arg, bfs=bfs, prune=prune)
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the node to start walking from.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
- prune ((node) -> bool): callable that returns True if the generator should stop traversing below the given node.
Yields:
exp.Expression: the visited nodes
def find_all_in_scope(expression, expression_types, bfs=True):
    """
    Returns a generator object which visits all nodes in this scope and only yields those that
    match at least one of the specified expression types.

    This does NOT traverse into subscopes.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: nodes
    """
    # Normalise the requested types once, instead of rebuilding the tuple
    # for every node the walk visits.
    wanted = tuple(ensure_collection(expression_types))
    for node in walk_in_scope(expression, bfs=bfs):
        if isinstance(node, wanted):
            yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def find_in_scope(expression, expression_types, bfs=True):
    """
    Return the first node in this scope that matches any of the given types.

    Subscopes are NOT searched.

    Args:
        expression (exp.Expression): the node whose scope is searched.
        expression_types (tuple[type]|type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if nothing matches.
    """
    for match in find_all_in_scope(expression, expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression (exp.Expression): the node whose scope is searched.
- expression_types (tuple[type]|type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.