sqlglot.optimizer.scope
import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """The kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for the alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        # Lateral sources are also addressable as ordinary sources of this scope.
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """
        Branch from the current scope to a new, inner scope.

        The child sees this scope's CTE sources, extended (and overridden) by `chain_sources`.

        Args:
            expression (exp.Expression): expression of the child scope; it is unnested first.
            scope_type (ScopeType): type of the child scope.
            chain_sources (dict | None): additional sources visible to the child.
            **kwargs: forwarded to the `Scope` constructor.

        Returns:
            Scope: the new child scope.
        """
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (without entering child scopes) and bucket the nodes found."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join))
                and _is_subquery_scope(node)
            ):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if it has not run since the last cache reset."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope, stopping at nodes that start child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                # Unqualified columns under QUALIFY/HAVING/hints, table functions, or an
                # ORDER BY that names a projection are skipped -- presumably because they
                # may refer to select aliases rather than source columns.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func))
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF in this scope.

        Derived tables and UDTFs are unnested unless they carry pivots.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect on demand like the other node-list properties; previously this returned
        # [] whenever it was read before any other collecting property.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the nodes referenced in this scope.

        Returns:
            list[exp.Expression]: pivot nodes, in reference order
        """
        # `is None` (not falsiness) so that an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A falsy `old_name` refers to the unaliased source stored under "". If no source
        exists under the old name, an empty list is stored under `new_name`.
        """
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Keep cached lookups (e.g. selected_sources) in sync, like add_source/remove_source.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        return []
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None if the expression has no scopes
    """
    scopes = traverse_scope(expression)
    if scopes:
        # Traversal is post-order, so the root scope is yielded last.
        return scopes[-1]
    return None


def _traverse_scope(scope):
    """Yield all child scopes of `scope` (post-order), then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    """Traverse the CTEs, FROM/JOIN/LATERAL sources and subqueries of a SELECT scope."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Traverse the CTEs and the left and right branches of a UNION scope."""
    yield from _traverse_ctes(scope)

    # The last scope yielded for each side is that side's topmost scope.
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Traverse the CTEs of `scope`; each CTE can see the CTEs defined before it."""
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _is_subquery_scope(expression: exp.Subquery) -> bool:
    """
    We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope.
    If an alias is present, it shadows all names under the Subquery, so that's an
    exception to this rule.
    """
    return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias)


def _traverse_tables(scope):
    """Register the FROM/JOIN/LATERAL sources of `scope` and traverse derived tables and UDTFs."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    # `expressions` may grow while iterating (nested joins / table constructs).
    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression

            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        elif _is_subquery_scope(expression):
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes
        else:
            # Makes sure we check for possible sources in nested table constructs
            expressions.append(expression.this)
            expressions.extend(join.this for join in expression.args.get("joins") or [])
            continue

        # Reset per derived table so a previous iteration's scope is never reused.
        child_scope = None
        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

        if child_scope is None:
            # _traverse_scope yielded nothing (it logged a warning for an unsupported type).
            continue

        # Tables without aliases will be set as ""
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins).
        alias = expression.alias
        sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Traverse every subquery of `scope`, recording each subquery's topmost scope."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        # Never record None: traverse() and columns iterate subquery_scopes unguarded.
        if top is not None:
            scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join))
                and _is_subquery_scope(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """
    Kinds of scope, each describing a scope relative to its parent.

    Members are numbered by `auto()` in declaration order.
    """

    # The outermost query.
    ROOT = auto()
    # A subquery, e.g. `WHERE a IN (SELECT ...)`.
    SUBQUERY = auto()
    # A subquery used as a table in FROM/JOIN.
    DERIVED_TABLE = auto()
    # A common table expression.
    CTE = auto()
    # One side of a UNION.
    UNION = auto()
    # A user defined tabular function.
    UDTF = auto()
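An enumeration of scope kinds (ROOT, SUBQUERY, DERIVED_TABLE, CTE, UNION, UDTF), each describing a scope relative to its parent.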
Inherited Members
- enum.Enum
- name
- value
24class Scope: 25 """ 26 Selection scope. 27 28 Attributes: 29 expression (exp.Select|exp.Union): Root expression of this scope 30 sources (dict[str, exp.Table|Scope]): Mapping of source name to either 31 a Table expression or another Scope instance. For example: 32 SELECT * FROM x {"x": Table(this="x")} 33 SELECT * FROM x AS y {"y": Table(this="x")} 34 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} 35 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals 36 For example: 37 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; 38 The LATERAL VIEW EXPLODE gets x as a source. 39 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query 40 defines a column list of it's alias of this scope, this is that list of columns. 41 For example: 42 SELECT * FROM (SELECT ...) AS y(col1, col2) 43 The inner query would have `["col1", "col2"]` for its `outer_column_list` 44 parent (Scope): Parent scope 45 scope_type (ScopeType): Type of this scope, relative to it's parent 46 subquery_scopes (list[Scope]): List of all child scopes for subqueries 47 cte_scopes (list[Scope]): List of all child scopes for CTEs 48 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables 49 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions 50 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined 51 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be 52 a list of the left and right child scopes. 
53 """ 54 55 def __init__( 56 self, 57 expression, 58 sources=None, 59 outer_column_list=None, 60 parent=None, 61 scope_type=ScopeType.ROOT, 62 lateral_sources=None, 63 ): 64 self.expression = expression 65 self.sources = sources or {} 66 self.lateral_sources = lateral_sources.copy() if lateral_sources else {} 67 self.sources.update(self.lateral_sources) 68 self.outer_column_list = outer_column_list or [] 69 self.parent = parent 70 self.scope_type = scope_type 71 self.subquery_scopes = [] 72 self.derived_table_scopes = [] 73 self.table_scopes = [] 74 self.cte_scopes = [] 75 self.union_scopes = [] 76 self.udtf_scopes = [] 77 self.clear_cache() 78 79 def clear_cache(self): 80 self._collected = False 81 self._raw_columns = None 82 self._derived_tables = None 83 self._udtfs = None 84 self._tables = None 85 self._ctes = None 86 self._subqueries = None 87 self._selected_sources = None 88 self._columns = None 89 self._external_columns = None 90 self._join_hints = None 91 self._pivots = None 92 self._references = None 93 94 def branch(self, expression, scope_type, chain_sources=None, **kwargs): 95 """Branch from the current scope to a new, inner scope""" 96 return Scope( 97 expression=expression.unnest(), 98 sources={**self.cte_sources, **(chain_sources or {})}, 99 parent=self, 100 scope_type=scope_type, 101 **kwargs, 102 ) 103 104 def _collect(self): 105 self._tables = [] 106 self._ctes = [] 107 self._subqueries = [] 108 self._derived_tables = [] 109 self._udtfs = [] 110 self._raw_columns = [] 111 self._join_hints = [] 112 113 for node, parent, _ in self.walk(bfs=False): 114 if node is self.expression: 115 continue 116 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): 117 self._raw_columns.append(node) 118 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): 119 self._tables.append(node) 120 elif isinstance(node, exp.JoinHint): 121 self._join_hints.append(node) 122 elif isinstance(node, exp.UDTF): 123 
self._udtfs.append(node) 124 elif isinstance(node, exp.CTE): 125 self._ctes.append(node) 126 elif ( 127 isinstance(node, exp.Subquery) 128 and isinstance(parent, (exp.From, exp.Join)) 129 and _is_subquery_scope(node) 130 ): 131 self._derived_tables.append(node) 132 elif isinstance(node, exp.Subqueryable): 133 self._subqueries.append(node) 134 135 self._collected = True 136 137 def _ensure_collected(self): 138 if not self._collected: 139 self._collect() 140 141 def walk(self, bfs=True): 142 return walk_in_scope(self.expression, bfs=bfs) 143 144 def find(self, *expression_types, bfs=True): 145 """ 146 Returns the first node in this scope which matches at least one of the specified types. 147 148 This does NOT traverse into subscopes. 149 150 Args: 151 expression_types (type): the expression type(s) to match. 152 bfs (bool): True to use breadth-first search, False to use depth-first. 153 154 Returns: 155 exp.Expression: the node which matches the criteria or None if no node matching 156 the criteria was found. 157 """ 158 return next(self.find_all(*expression_types, bfs=bfs), None) 159 160 def find_all(self, *expression_types, bfs=True): 161 """ 162 Returns a generator object which visits all nodes in this scope and only yields those that 163 match at least one of the specified expression types. 164 165 This does NOT traverse into subscopes. 166 167 Args: 168 expression_types (type): the expression type(s) to match. 169 bfs (bool): True to use breadth-first search, False to use depth-first. 170 171 Yields: 172 exp.Expression: nodes 173 """ 174 for expression, *_ in self.walk(bfs=bfs): 175 if isinstance(expression, expression_types): 176 yield expression 177 178 def replace(self, old, new): 179 """ 180 Replace `old` with `new`. 181 182 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
183 184 Args: 185 old (exp.Expression): old node 186 new (exp.Expression): new node 187 """ 188 old.replace(new) 189 self.clear_cache() 190 191 @property 192 def tables(self): 193 """ 194 List of tables in this scope. 195 196 Returns: 197 list[exp.Table]: tables 198 """ 199 self._ensure_collected() 200 return self._tables 201 202 @property 203 def ctes(self): 204 """ 205 List of CTEs in this scope. 206 207 Returns: 208 list[exp.CTE]: ctes 209 """ 210 self._ensure_collected() 211 return self._ctes 212 213 @property 214 def derived_tables(self): 215 """ 216 List of derived tables in this scope. 217 218 For example: 219 SELECT * FROM (SELECT ...) <- that's a derived table 220 221 Returns: 222 list[exp.Subquery]: derived tables 223 """ 224 self._ensure_collected() 225 return self._derived_tables 226 227 @property 228 def udtfs(self): 229 """ 230 List of "User Defined Tabular Functions" in this scope. 231 232 Returns: 233 list[exp.UDTF]: UDTFs 234 """ 235 self._ensure_collected() 236 return self._udtfs 237 238 @property 239 def subqueries(self): 240 """ 241 List of subqueries in this scope. 242 243 For example: 244 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery 245 246 Returns: 247 list[exp.Subqueryable]: subqueries 248 """ 249 self._ensure_collected() 250 return self._subqueries 251 252 @property 253 def columns(self): 254 """ 255 List of columns in this scope. 256 257 Returns: 258 list[exp.Column]: Column instances in this scope, plus any 259 Columns that reference this scope from correlated subqueries. 
260 """ 261 if self._columns is None: 262 self._ensure_collected() 263 columns = self._raw_columns 264 265 external_columns = [ 266 column 267 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) 268 for column in scope.external_columns 269 ] 270 271 named_selects = set(self.expression.named_selects) 272 273 self._columns = [] 274 for column in columns + external_columns: 275 ancestor = column.find_ancestor( 276 exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table 277 ) 278 if ( 279 not ancestor 280 or column.table 281 or isinstance(ancestor, exp.Select) 282 or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) 283 or ( 284 isinstance(ancestor, exp.Order) 285 and ( 286 isinstance(ancestor.parent, exp.Window) 287 or column.name not in named_selects 288 ) 289 ) 290 ): 291 self._columns.append(column) 292 293 return self._columns 294 295 @property 296 def selected_sources(self): 297 """ 298 Mapping of nodes and sources that are actually selected from in this scope. 299 300 That is, all tables in a schema are selectable at any point. But a 301 table only becomes a selected source if it's included in a FROM or JOIN clause. 
302 303 Returns: 304 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes 305 """ 306 if self._selected_sources is None: 307 result = {} 308 309 for name, node in self.references: 310 if name in result: 311 raise OptimizeError(f"Alias already used: {name}") 312 if name in self.sources: 313 result[name] = (node, self.sources[name]) 314 315 self._selected_sources = result 316 return self._selected_sources 317 318 @property 319 def references(self) -> t.List[t.Tuple[str, exp.Expression]]: 320 if self._references is None: 321 self._references = [] 322 323 for table in self.tables: 324 self._references.append((table.alias_or_name, table)) 325 for expression in itertools.chain(self.derived_tables, self.udtfs): 326 self._references.append( 327 ( 328 expression.alias, 329 expression if expression.args.get("pivots") else expression.unnest(), 330 ) 331 ) 332 333 return self._references 334 335 @property 336 def cte_sources(self): 337 """ 338 Sources that are CTEs. 339 340 Returns: 341 dict[str, Scope]: Mapping of source alias to Scope 342 """ 343 return { 344 alias: scope 345 for alias, scope in self.sources.items() 346 if isinstance(scope, Scope) and scope.is_cte 347 } 348 349 @property 350 def external_columns(self): 351 """ 352 Columns that appear to reference sources in outer scopes. 353 354 Returns: 355 list[exp.Column]: Column instances that don't reference 356 sources in the current scope. 357 """ 358 if self._external_columns is None: 359 self._external_columns = [ 360 c for c in self.columns if c.table not in self.selected_sources 361 ] 362 return self._external_columns 363 364 @property 365 def unqualified_columns(self): 366 """ 367 Unqualified columns in the current scope. 
368 369 Returns: 370 list[exp.Column]: Unqualified columns 371 """ 372 return [c for c in self.columns if not c.table] 373 374 @property 375 def join_hints(self): 376 """ 377 Hints that exist in the scope that reference tables 378 379 Returns: 380 list[exp.JoinHint]: Join hints that are referenced within the scope 381 """ 382 if self._join_hints is None: 383 return [] 384 return self._join_hints 385 386 @property 387 def pivots(self): 388 if not self._pivots: 389 self._pivots = [ 390 pivot for _, node in self.references for pivot in node.args.get("pivots") or [] 391 ] 392 393 return self._pivots 394 395 def source_columns(self, source_name): 396 """ 397 Get all columns in the current scope for a particular source. 398 399 Args: 400 source_name (str): Name of the source 401 Returns: 402 list[exp.Column]: Column instances that reference `source_name` 403 """ 404 return [column for column in self.columns if column.table == source_name] 405 406 @property 407 def is_subquery(self): 408 """Determine if this scope is a subquery""" 409 return self.scope_type == ScopeType.SUBQUERY 410 411 @property 412 def is_derived_table(self): 413 """Determine if this scope is a derived table""" 414 return self.scope_type == ScopeType.DERIVED_TABLE 415 416 @property 417 def is_union(self): 418 """Determine if this scope is a union""" 419 return self.scope_type == ScopeType.UNION 420 421 @property 422 def is_cte(self): 423 """Determine if this scope is a common table expression""" 424 return self.scope_type == ScopeType.CTE 425 426 @property 427 def is_root(self): 428 """Determine if this is the root scope""" 429 return self.scope_type == ScopeType.ROOT 430 431 @property 432 def is_udtf(self): 433 """Determine if this scope is a UDTF (User Defined Table Function)""" 434 return self.scope_type == ScopeType.UDTF 435 436 @property 437 def is_correlated_subquery(self): 438 """Determine if this scope is a correlated subquery""" 439 return bool(self.is_subquery and self.external_columns) 440 
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str|None): current name of the source. Falsy values (e.g. None)
            are looked up under the key "".
        new_name (str): name to register the source under.
    """
    # NOTE(review): when `old_name` is not a known source, an empty list is stored
    # under `new_name` (behavior kept from the original implementation).
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Like add_source/remove_source, the sources mapping changed, so anything
    # cached from it (e.g. selected_sources, references) must be recomputed.
    self.clear_cache()

def add_source(self, name, source):
    """Add a source to this scope"""
    self.sources[name] = source
    self.clear_cache()

def remove_source(self, name):
    """Remove a source from this scope"""
    self.sources.pop(name, None)
    self.clear_cache()

def __repr__(self):
    return f"Scope<{self.expression.sql()}>"

def traverse(self):
    """
    Traverse the scope tree from this node.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    for child_scope in itertools.chain(
        self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
    ):
        yield from child_scope.traverse()
    yield self

def ref_count(self):
    """
    Count the number of times each scope in this tree is referenced.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
    """
    scope_ref_count = defaultdict(lambda: 0)

    for scope in self.traverse():
        for _, source in scope.selected_sources.values():
            scope_ref_count[id(source)] += 1

    return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example:
  SELECT * FROM x                     {"x": Table(this="x")}
  SELECT * FROM x AS y                {"y": Table(this="x")}
  SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example:
  SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
  The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list in its alias for this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression (exp.Select|exp.Union): root expression of this scope.
        sources (dict[str, exp.Table|Scope]): initial sources. The mapping is copied,
            so the caller's dict is never modified.
        outer_column_list (list[str]): column list the outer query defines in its
            alias for this scope (derived tables / CTEs), if any.
        parent (Scope): enclosing scope, or None for the root.
        scope_type (ScopeType): type of this scope, relative to its parent.
        lateral_sources (dict[str, exp.Table|Scope]): sources from laterals. The mapping
            is copied and also merged into `sources`.
    """
    self.expression = expression
    # Copy both mappings: merging the lateral sources below must not write into a
    # dict owned by the caller (previously only lateral_sources was copied).
    self.sources = sources.copy() if sources else {}
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources.update(self.lateral_sources)
    self.outer_column_list = outer_column_list or []
    self.parent = parent
    self.scope_type = scope_type
    # Child scopes, filled in while the scope tree is traversed.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []
    # Initialise every lazily computed attribute.
    self.clear_cache()
def clear_cache(self):
    """Reset all memoized state so every derived property is recomputed on next access."""
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cached_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Create a child scope beneath this one.

    The child is rooted at `expression` with any wrapping removed via `unnest()`.
    It starts with this scope's CTE sources, overlaid by `chain_sources` when given.
    Extra keyword arguments are forwarded to the `Scope` constructor.
    """
    root = expression.unnest()
    inherited_sources = dict(self.cte_sources)
    if chain_sources:
        inherited_sources.update(chain_sources)
    return Scope(
        expression=root,
        sources=inherited_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of `expression_types`.

    Child scopes are not searched.

    Args:
        expression_types (type): one or more expression classes to look for.
        bfs (bool): search breadth-first when True, depth-first otherwise.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that is an instance of any of `expression_types`.

    Child scopes are not entered.

    Args:
        expression_types (type): one or more expression classes to look for.
        bfs (bool): walk breadth-first when True, depth-first otherwise.

    Yields:
        exp.Expression: matching nodes, in walk order.
    """
    yield from (
        node for node, *_ in self.walk(bfs=bfs) if isinstance(node, expression_types)
    )
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap `old` for `new` in the syntax tree, then invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly, so that the `Scope`
    stays up-to-date with the modified tree.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: matching Column instances, in the order of `self.columns`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str|None): current name of the source. Falsy values (e.g. None)
            are looked up under the key "".
        new_name (str): name to register the source under.
    """
    # NOTE(review): when `old_name` is not a known source, an empty list is stored
    # under `new_name` (behavior kept from the original implementation).
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Like add_source/remove_source, the sources mapping changed, so anything
    # cached from it (e.g. selected_sources, references) must be recomputed.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` in this scope (overwriting any existing entry) and reset caches."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source registered under `name` (a no-op if it is absent), then reset caches."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group (CTEs, unions, tables, subqueries), each one
    fully before the next, and this scope is yielded last.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes)
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Tally how often each source object is selected across the scope tree.

    Returns:
        dict[int, int]: defaultdict mapping `id()` of each selected source
        (Scope or Table) to the number of times it is referenced.
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when `expression` is not an `exp.Unionable`
    """
    if isinstance(expression, exp.Unionable):
        # Materialize the generator so every scope is fully built before callers see it.
        return list(_traverse_scope(Scope(expression)))
    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build the scope tree for `expression` and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: the root scope (the last scope produced by `traverse_scope`),
        or None when `traverse_scope` yields no scopes.
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start a child scope are still yielded themselves; only their
    descendants are skipped.

    Args:
        expression (exp.Expression): the expression to walk
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator (BFS or DFS).
    # Whenever we set it to True, we exclude a subtree from traversal.
    # NOTE(review): this relies on `walk` invoking the prune callback for a node only
    # after the consumer resumes from that node's yield, so the lambda observes the
    # value assigned below. Keep the statement order intact.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        # Reset for every node; only a child-scope root sets it back to True.
        prune = False

        yield node, parent, key

        # The root of the walk is this scope's own expression; never prune it.
        if node is expression:
            continue
        # Nodes that begin a new child scope: CTEs, derived-table subqueries in
        # FROM/JOIN, UDTFs and nested subqueryables. Their subtrees are not walked.
        if (
            isinstance(node, exp.CTE)
            or (
                isinstance(node, exp.Subquery)
                and isinstance(parent, (exp.From, exp.Join))
                and _is_subquery_scope(node)
            )
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the expression to walk
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key