sqlglot.optimizer.scope
import itertools
import typing as t
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name


class ScopeType(Enum):
    """Kind of scope a query part forms, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list on the alias it gives this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copy the lateral sources: the caller (_traverse_tables) keeps mutating the dict it
        # passed in. They are also exposed through the regular `sources` mapping.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope depth-first (without entering child scopes) and bucket the
        nodes it contains into tables, CTEs, subqueries, derived tables, UDTFs,
        columns and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if the node caches are not populated."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope without descending into child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # Unqualified columns under QUALIFY / HAVING / a non-window ORDER BY whose
                # name matches one of this scope's named selects are skipped, as are
                # unqualified columns under a hint.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two referenced sources share the same name.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like `tables` / `ctes`; otherwise a freshly created or
        # cache-cleared scope would report no join hints at all.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the tables and derived tables of this scope.

        Returns:
            list: the nodes stored in the `pivots` arg of each table / derived table
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot
                for node in self.tables + self.derived_tables
                for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # `selected_sources` (and the column caches built on it) are keyed by source
        # name, so they must be recomputed, as in add_source / remove_source.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if not isinstance(expression, exp.Unionable):
        return []
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope, or None if `expression` produces no scopes
    """
    scopes = traverse_scope(expression)
    if scopes:
        return scopes[-1]
    return None


def _traverse_scope(scope):
    """Yield all scopes nested in `scope` (post-order), then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.Table):
        # This case corresponds to a "join construct", i.e. (tbl1 JOIN tbl2 ON ..)
        yield from _traverse_tables(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
    yield scope


def _traverse_select(scope):
    """Traverse the CTEs, table sources and subqueries of a SELECT scope."""
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    """Traverse the CTEs and both branches of a UNION scope."""
    yield from _traverse_ctes(scope)

    # The last scope to be yield should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    """Traverse each CTE of `scope`, registering it as a source under its alias."""
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # if the scope is a recursive cte, it must be in the form of
        # base_case UNION recursive. thus the recursive scope is the first
        # section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    """Traverse FROM / JOIN / LATERAL sources of `scope` and register them as sources."""
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.append(from_.this)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    if isinstance(scope.expression, exp.Table):
        expressions.append(scope.expression)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources and not expression.db:
                # This is a reference to a parent source (e.g. a CTE), not an actual table, unless
                # it is pivoted, because then we get back a new table and hence a new source.
                pivots = expression.args.get("pivots")
                if pivots:
                    sources[pivots[0].alias] = expression
                else:
                    sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if isinstance(expression, exp.UDTF):
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins.
            alias = expression.alias
            sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    """Traverse every subquery of `scope`, recording each top-most child scope."""
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): root of the tree to walk
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """Kind of scope a query part forms, relative to its parent scope."""

    # Values spelled out explicitly; they match what `auto()` assigns (1, 2, 3, ...).
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
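An enumeration of the kinds of scope a query part can form relative to its parent: root, subquery, derived table, CTE, union, or UDTF.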
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list on the alias it gives this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copy the lateral sources so later changes to the caller's dict do not leak in;
        # they are also exposed through the regular `sources` mapping.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Reset every lazily computed attribute so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None
        self._pivots = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """
        Walk this scope depth-first (without entering child scopes) and bucket the
        nodes it contains into tables, CTEs, subqueries, derived tables, UDTFs,
        columns and join hints.
        """
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        """Run `_collect` if the node caches are not populated."""
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope without descending into child scopes."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(
                    exp.Select, exp.Qualify, exp.Order, exp.Having, exp.Hint
                )
                # Unqualified columns under QUALIFY / HAVING / a non-window ORDER BY whose
                # name matches one of this scope's named selects are skipped, as are
                # unqualified columns under a hint.
                if (
                    not ancestor
                    or column.table
                    or isinstance(ancestor, exp.Select)
                    or (isinstance(ancestor, exp.Order) and isinstance(ancestor.parent, exp.Window))
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two referenced sources share the same name.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like `tables` / `ctes`; otherwise a freshly created or
        # cache-cleared scope would report no join hints at all.
        self._ensure_collected()
        return self._join_hints

    @property
    def pivots(self):
        """
        Pivots attached to the tables and derived tables of this scope.

        Returns:
            list: the nodes stored in the `pivots` arg of each table / derived table
        """
        # Compare against None so that an empty result is cached as well.
        if self._pivots is None:
            self._pivots = [
                pivot
                for node in self.tables + self.derived_tables
                for pivot in node.args.get("pivots") or []
            ]

        return self._pivots

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # `selected_sources` (and the column caches built on it) are keyed by source
        # name, so they must be recomputed, as in add_source / remove_source.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example: SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; The LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for this scope's alias, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` as its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Create a scope for `expression`.

    Args:
        expression (exp.Select|exp.Union): root expression of this scope.
        sources (dict[str, exp.Table|Scope], optional): initial source mapping.
            It is copied, so merging the lateral sources below and later calls to
            `add_source`/`remove_source`/`rename_source` never mutate the
            caller's dict.
        outer_column_list (list[str], optional): column list the outer query
            declares for this scope's alias, if any.
        parent (Scope, optional): enclosing scope.
        scope_type (ScopeType): type of this scope relative to its parent.
        lateral_sources (dict[str, exp.Table|Scope], optional): sources coming
            from laterals; copied and also merged into `sources`.
    """
    self.expression = expression
    # Copy rather than alias: this scope mutates its own `sources` mapping
    # (the update below, add_source, remove_source, rename_source).
    self.sources = sources.copy() if sources else {}
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources.update(self.lateral_sources)
    self.outer_column_list = outer_column_list or []
    self.parent = parent
    self.scope_type = scope_type
    # Child scopes by kind; `traverse` visits these lists.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []
    # Initialise every lazily computed attribute to its "not computed" state.
    self.clear_cache()
def clear_cache(self):
    """
    Drop every lazily computed attribute so it is rebuilt on next access.

    Invoked from `__init__` and after any change to the scope's tree or
    sources (`replace`, `add_source`, `remove_source`).
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
        "_pivots",
    ):
        setattr(self, cached_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Create a child scope nested inside this one.

    The child starts with this scope's CTE sources, overlaid with
    `chain_sources` (which win on name clashes).

    Args:
        expression (exp.Expression): expression for the child; it is unnested first.
        scope_type (ScopeType): type of the child relative to this scope.
        chain_sources (dict, optional): extra sources visible to the child.
        **kwargs: forwarded unchanged to the `Scope` constructor.

    Returns:
        Scope: the new child scope, whose parent is this scope.
    """
    child_expression = expression.unnest()
    inherited_sources = dict(self.cte_sources)
    inherited_sources.update(chain_sources or {})
    return Scope(
        expression=child_expression,
        sources=inherited_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of the given types.

    Child scopes are NOT searched.

    Args:
        expression_types (type): one or more expression types to match.
        bfs (bool): True for breadth-first search, False for depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that is an instance of any given type.

    Child scopes are NOT searched.

    Args:
        expression_types (type): one or more expression types to match.
        bfs (bool): True for breadth-first search, False for depth-first.

    Yields:
        exp.Expression: matching nodes, in walk order.
    """
    yield from (
        candidate
        for candidate, *_ in self.walk(bfs=bfs)
        if isinstance(candidate, expression_types)
    )
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so the
    `Scope` does not keep serving stale cached data.

    Args:
        old (exp.Expression): node being replaced.
        new (exp.Expression): node to put in its place.
    """
    # Mutate the tree first, then drop cached lookups that may now be stale.
    old.replace(new)
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope that are qualified with `source_name`.

    Args:
        source_name (str): name of the source to filter by.

    Returns:
        list[exp.Column]: matching columns, in the order of `self.columns`.
    """
    matching = []
    for col in self.columns:
        if col.table == source_name:
            matching.append(col)
    return matching
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    Args:
        old_name (str): current name of the source; a falsy name is looked up as "".
        new_name (str): name to store the source under.

    If `old_name` is not present, `new_name` is mapped to an empty list
    (kept as-is for backward compatibility).
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Cached lookups such as `selected_sources` are keyed by source name and are
    # stale after a rename; invalidate them just like add_source/remove_source do.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name`, overwriting any existing entry.

    Cached scope data is invalidated afterwards.
    """
    self.sources.update({name: source})
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Forget the source called `name`; unknown names are ignored.

    Cached scope data is invalidated in either case.
    """
    if name in self.sources:
        del self.sources[name]
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree rooted at this scope.

    Children are visited group by group (CTEs, union branches, table scopes,
    subqueries), each fully before the next, and this scope comes last.

    Yields:
        Scope: scopes in depth-first post-order.
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how often each source object is selected from across this scope tree.

    Returns:
        dict[int, int]: defaultdict mapping `id(source)` to its reference count
            (missing ids read as 0).
    """
    counts = defaultdict(int)
    referenced = (
        source
        for scope in self.traverse()
        for _node, source in scope.selected_sources.values()
    )
    for source in referenced:
        counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression: exp.Expression) -> t.List[Scope]:
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing. Expressions that are not
    `exp.Unionable` produce an empty list.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    if isinstance(expression, exp.Unionable):
        root = Scope(expression)
        return list(_traverse_scope(root))
    return []
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") >>> scopes = traverse_scope(expression) >>> scopes[0].expression.sql(), list(scopes[0].sources) ('SELECT a FROM x', ['x']) >>> scopes[1].expression.sql(), list(scopes[1].sources) ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression: exp.Expression) -> t.Optional[Scope]:
    """
    Build a scope tree and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: the last scope produced by `traverse_scope` (the root), or None
            when the expression yields no scopes.
    """
    all_scopes = traverse_scope(expression)
    return all_scopes[-1] if all_scopes else None
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Walk the syntax tree rooted at `expression` without descending into child scopes.

    Nodes that start a child scope (CTEs, UDTFs, subqueryables, and subqueries
    used as a FROM/JOIN source) are still yielded, but their descendants are
    excluded. The root `expression` itself is never pruned.

    Args:
        expression (exp.Expression): root of the syntax tree to walk.
        bfs (bool): True for breadth-first order, False for depth-first order.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # State shared with the prune callback handed to `expression.walk`:
    # whenever it is True, the walk excludes the current node's subtree.
    skip_subtree = False

    def _prune(*_):
        return skip_subtree

    for node, parent, key in expression.walk(bfs=bfs, prune=_prune):
        skip_subtree = False

        yield node, parent, key

        if node is not expression:
            skip_subtree = isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable)) or (
                isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))
            )
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): root of the syntax tree to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key