sqlglot.optimizer.scope
import itertools
import logging
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name

logger = logging.getLogger("sqlglot")


class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for its alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope (without entering child scopes) and bucket the nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                # Unqualified names under ORDER BY may refer to projection aliases rather
                # than source columns; only keep them when they cannot be such an alias.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Pivoted derived tables/UDTFs are kept as-is; others are unnested.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like the other node-list properties, so the result does not
        # depend on whether some other property happened to trigger collection first.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivot expressions attached to any node referenced in this scope.

        Returns:
            list[exp.Expression]: the entries of each referenced node's "pivots" arg
        """
        # `is None` (not falsiness) so that an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A missing `old_name` (None is treated as "") maps `new_name` to an empty list.
        """
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns
        # selected_sources / external_columns are derived from `sources`; drop the
        # stale caches, as add_source and remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        return []
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None if `expression` produces no scopes
    """
    scopes = traverse_scope(expression)
    if scopes:
        return scopes[-1]
    return None


def _traverse_scope(scope):
    """Yield the child scopes of `scope` in post-order, then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        logger.warning(
            "Cannot traverse scope %s with type '%s'", scope.expression, type(scope.expression)
        )
        return

    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        child_scope = None

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        if child_scope:
            scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if not isinstance(expression, exp.DerivedTable):
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            alias = expression.alias
            sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression):
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()
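An enumeration of the kinds of scope (root, subquery, derived table, CTE, union, and UDTF), each describing a scope relative to its parent.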
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for its alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None
        self._references = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope (without entering child scopes) and bucket the nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint, exp.Table
                )
                # Unqualified names under ORDER BY may refer to projection aliases rather
                # than source columns; only keep them when they cannot be such an alias.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (
                        isinstance(ancestor, exp.Order)
                        and (
                            isinstance(ancestor.parent, exp.Window)
                            or column.name not in named_selects
                        )
                    )
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two references in this scope share the same name.
        """
        if self._selected_sources is None:
            result = {}

            for name, node in self.references:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def references(self) -> t.List[t.Tuple[str, exp.Expression]]:
        """
        (name, node) pairs for every table, derived table and UDTF referenced in this scope.

        Pivoted derived tables/UDTFs are kept as-is; others are unnested.
        """
        if self._references is None:
            self._references = []

            for table in self.tables:
                self._references.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                self._references.append(
                    (
                        expression.alias,
                        expression if expression.args.get("pivots") else expression.unnest(),
                    )
                )

        return self._references

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like the other node-list properties, so the result does not
        # depend on whether some other property happened to trigger collection first.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivot expressions attached to any node referenced in this scope.

        Returns:
            list[exp.Expression]: the entries of each referenced node's "pivots" arg
        """
        # `is None` (not falsiness) so that an empty result is cached too.
        if self._pivots is None:
            self._pivots = [
                pivot for _, node in self.references for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        A missing `old_name` (None is treated as "") maps `new_name` to an empty list.
        """
        columns = self.sources.pop(old_name or "", [])
        self.sources[new_name] = columns
        # selected_sources / external_columns are derived from `sources`; drop the
        # stale caches, as add_source and remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for this scope's alias, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Lateral sources are copied and merged into `sources`. All child-scope lists
    start out empty, and every cached property is reset via `clear_cache`.

    NOTE: a non-empty `sources` mapping is stored by reference (not copied) and is
    updated in place with the lateral sources.
    """
    self.expression = expression
    self.sources = sources or {}
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    # Lateral sources are also visible as ordinary sources of this scope.
    self.sources.update(self.lateral_sources)

    self.outer_column_list = outer_column_list or []
    self.parent = parent
    self.scope_type = scope_type

    # Child scopes; populated later by whoever builds the scope tree.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    self.clear_cache()
def clear_cache(self):
    """
    Invalidate every lazily computed attribute of this scope.

    `_collected` is reset to False and all other cache slots are reset to None,
    so they are recomputed on next access.
    """
    self._collected = False
    for cache_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
        "_references",
    ):
        setattr(self, cache_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Branch from the current scope to a new, inner scope.

    The child scope is built on the unnested `expression`, and it sees this
    scope's CTE sources plus any `chain_sources`. Entries from `chain_sources`
    win on name clashes. Extra keyword arguments are forwarded to `Scope`.
    """
    inner_expression = expression.unnest()
    inherited_sources = dict(self.cte_sources)
    inherited_sources.update(chain_sources or {})
    return Scope(
        expression=inner_expression,
        sources=inherited_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that matches any of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Returns:
        exp.Expression: the first matching node, or None if nothing matches.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that matches any of the given types.

    This does NOT traverse into subscopes.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True to use breadth-first search, False to use depth-first.

    Yields:
        exp.Expression: matching nodes, in walk order.
    """
    # `walk` yields tuples whose first element is the node; the rest is ignored here.
    yield from (
        node for node, *_ in self.walk(bfs=bfs) if isinstance(node, expression_types)
    )
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and keep this `Scope` consistent.

    Prefer this over calling `exp.Expression.replace` directly, since it also
    invalidates the scope's cached properties.

    Args:
        old (exp.Expression): node to be replaced
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # Cached data was computed from the tree before the swap; drop it.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables.
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): Name of the source
    Returns:
        list[exp.Column]: Column instances whose `table` equals `source_name`,
            in the order they appear in `self.columns`
    """
    matching = []
    for column in self.columns:
        if column.table == source_name:
            matching.append(column)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The source registered under `old_name` (or under "" when `old_name` is falsy)
    is moved to `new_name`. As in `add_source` and `remove_source`, cached
    properties are cleared afterwards. This ensures derived data such as
    `selected_sources` is rebuilt against the updated `sources` mapping instead of
    returning stale entries.

    NOTE(review): if no source exists under the old name, an empty list is stored
    under `new_name`. This preserves the original behaviour, although a list is
    not a valid source value.
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """Register `source` under `name` (overwriting any existing entry) and reset caches."""
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """Drop the source called `name` if it exists (no error otherwise) and reset caches."""
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group: CTE scopes, then union scopes, then
    table scopes, then subquery scopes. Each child's subtree is fully yielded
    before this scope itself.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how many times each source is selected across this scope tree.

    Every scope yielded by `traverse` contributes one count per entry in its
    `selected_sources`. Counts are keyed by the `id()` of the source object.

    Returns:
        dict[int, int]: Mapping of Scope instance ID to reference count
            (a defaultdict, so unseen IDs read as 0)
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances; empty when `expression` is not an
            `exp.Unionable`
    """
    if isinstance(expression, exp.Unionable):
        return list(_traverse_scope(Scope(expression)))
    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope (the last scope produced by `traverse_scope`), or None
            if `traverse_scope` produced no scopes
    """
    scopes = traverse_scope(expression)
    return scopes[-1] if scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Nodes that start child scopes are still yielded, but their subtrees are not walked.

    Args:
        expression (exp.Expression): the root node to walk; it is always yielded and
            is never pruned itself.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator: the `prune`
    # callback reads it after each yielded node, so setting it to True excludes that
    # node's subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        # These nodes begin a child scope (CTE, table function, nested query, or a
        # subquery used as a table in FROM/JOIN); their contents belong to that scope.
        if isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
            isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))
        ):
            prune = True
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root node to walk; it is always yielded and is never pruned itself.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key