sqlglot.optimizer.scope
import itertools
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError
from sqlglot.helper import find_new_name


class ScopeType(Enum):
    """The kind of a scope, relative to its parent scope."""

    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for this scope's alias, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copy so later additions by the caller don't leak into this scope.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        # Lateral sources are also resolvable as ordinary sources.
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Drop every lazily computed node list so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Inner scopes can see this scope's CTEs plus any explicitly chained sources.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope (not child scopes) once and bucket the nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            # Tables named inside a join hint are hint arguments, not sources.
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            # A subquery directly under FROM/JOIN is a derived table, not a subquery.
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope only; see `walk_in_scope`."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
                # Skip unqualified names in this SELECT's QUALIFY/ORDER/HAVING that refer to
                # projection aliases, and unqualified names inside a hint of this SELECT.
                if (
                    not ancestor
                    # Window functions can have an ORDER BY clause
                    or not isinstance(ancestor.parent, exp.Select)
                    or column.table
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two selected nodes share the same alias.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily, like `tables`/`ctes`, so the result does not depend on whether
        # some other property happened to trigger collection first.
        self._ensure_collected()
        return self._join_hints

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        Args:
            old_name (str|None): current name of the source; None is treated as "",
                the key under which unaliased derived tables are registered.
            new_name (str): name to register the source under.
        """
        # If `old_name` is not registered, an empty list is stored under `new_name`
        # (kept for backward compatibility).
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Like add_source/remove_source, invalidate cached lookups such as
        # selected_sources that were computed against the old name.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression):
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # Scopes are produced in post-order, so the root scope comes last.
    return traverse_scope(expression)[-1]


def _traverse_scope(scope):
    """Yield all descendant scopes of `scope` in post-order, then `scope` itself."""
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # If the scope is a recursive CTE, it must be in the form of
        # base_case UNION recursive. Thus the recursive scope is built from the
        # first (left-hand) section of the union.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

            alias = cte.alias
            sources[alias] = child_scope

            # Register on every yielded scope (including the union's branches) so the
            # recursive term can resolve references to the CTE's own name.
            if recursive_scope:
                child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.extend(from_.expressions)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources:
                # This is a reference to a parent source (e.g. a CTE), not an actual table.
                sources[source_name] = scope.sources[table_name]
            elif source_name in sources:
                sources[find_new_name(sources, table_name)] = expression
            else:
                sources[source_name] = expression
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs may refer to the sources defined before them in the FROM clause.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

            # Tables without aliases will be set as ""
            # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
            # Until then, this means that only a single, unaliased derived table is allowed (rather,
            # the latest one wins).
            alias = expression.alias
            sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the root node to walk from.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the walk generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        # These nodes open a child scope: yield them, but don't descend into them.
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """The kind of a scope, relative to its parent scope."""

    # The outermost query.
    ROOT = auto()
    # A nested query used as an expression, e.g. in a WHERE ... IN (...) clause.
    SUBQUERY = auto()
    # A nested query used as a FROM/JOIN source.
    DERIVED_TABLE = auto()
    # A common table expression.
    CTE = auto()
    # One side of a UNION.
    UNION = auto()
    # A user defined tabular function.
    UDTF = auto()
An enumeration of scope types; a scope's type describes its role relative to its parent scope.
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list for this scope's alias, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copy so later additions by the caller don't leak into this scope.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        # Lateral sources are also resolvable as ordinary sources.
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Drop every lazily computed node list so it is recomputed on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            # Inner scopes can see this scope's CTEs plus any explicitly chained sources.
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope (not child scopes) once and bucket the nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            # Tables named inside a join hint are hint arguments, not sources.
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            # A subquery directly under FROM/JOIN is a derived table, not a subquery.
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        """Walk the nodes of this scope only; see `walk_in_scope`."""
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, *_ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
                # Skip unqualified names in this SELECT's QUALIFY/ORDER/HAVING that refer to
                # projection aliases, and unqualified names inside a hint of this SELECT.
                if (
                    not ancestor
                    # Window functions can have an ORDER BY clause
                    or not isinstance(ancestor.parent, exp.Select)
                    or column.table
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes

        Raises:
            OptimizeError: if two selected nodes share the same alias.
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in result:
                    raise OptimizeError(f"Alias already used: {name}")
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
             list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily, like `tables`/`ctes`, so the result does not depend on whether
        # some other property happened to trigger collection first.
        self._ensure_collected()
        return self._join_hints

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """
        Rename a source in this scope.

        Args:
            old_name (str|None): current name of the source; None is treated as "",
                the key under which unaliased derived tables are registered.
            new_name (str): name to register the source under.
        """
        # If `old_name` is not registered, an empty list is stored under `new_name`
        # (kept for backward compatibility).
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Like add_source/remove_source, invalidate cached lookups such as
        # selected_sources that were computed against the old name.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example: `SELECT * FROM x` gives `{"x": Table(this="x")}`; `SELECT * FROM x AS y` gives `{"y": Table(this="x")}`; `SELECT * FROM (SELECT ...) AS y` gives `{"y": Scope(...)}`.
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in `SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;` the LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for its alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
def __init__(
    self,
    expression,
    sources=None,
    outer_column_list=None,
    parent=None,
    scope_type=ScopeType.ROOT,
    lateral_sources=None,
):
    """
    Create a scope rooted at `expression`.

    Args:
        expression (exp.Expression): root expression of this scope
        sources (dict[str, exp.Table|Scope]): initial mapping of source names.
            NOTE(review): a non-empty dict is stored by reference (not copied) and
            the lateral sources are merged into it, so the caller's dict is mutated.
        outer_column_list (list[str]): column list the outer query declares for this
            scope's alias, if any; defaults to an empty list
        parent (Scope): enclosing scope, or None (the default, used for the root)
        scope_type (ScopeType): how this scope relates to its parent
        lateral_sources (dict[str, exp.Table|Scope]): sources coming from laterals;
            a shallow copy is kept and also merged into `sources`
    """
    self.expression = expression
    self.parent = parent
    self.scope_type = scope_type
    self.outer_column_list = outer_column_list or []

    # Keep a private copy of the lateral sources, then expose them through the
    # regular source mapping as well.
    self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    self.sources = sources or {}
    self.sources.update(self.lateral_sources)

    # Child scopes, grouped by kind; every list starts out empty.
    self.subquery_scopes = []
    self.derived_table_scopes = []
    self.table_scopes = []
    self.cte_scopes = []
    self.union_scopes = []
    self.udtf_scopes = []

    # Put every lazily-computed attribute into its "not yet computed" state.
    self.clear_cache()
def clear_cache(self):
    """
    Drop every lazily-computed attribute of this scope.

    `_collected` is reset to False and each cached attribute is reset to None,
    so it is recomputed the next time it is needed.
    """
    self._collected = False
    for cached_attr in (
        "_raw_columns",
        "_derived_tables",
        "_udtfs",
        "_tables",
        "_ctes",
        "_subqueries",
        "_selected_sources",
        "_columns",
        "_external_columns",
        "_join_hints",
    ):
        setattr(self, cached_attr, None)
def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    """
    Branch from the current scope to a new, inner scope.

    Args:
        expression (exp.Expression): expression for the new scope; the scope is
            rooted at `expression.unnest()`
        scope_type (ScopeType): type of the new scope relative to this one
        chain_sources (dict): extra sources layered on top of the CTE sources
            inherited from this scope; on a name clash, these win
        **kwargs: forwarded to the `Scope` constructor
            (e.g. `outer_column_list`, `lateral_sources`)

    Returns:
        Scope: the new child scope, whose parent is this scope
    """
    inner_expression = expression.unnest()
    inherited_sources = {**self.cte_sources, **(chain_sources or {})}
    return Scope(
        expression=inner_expression,
        sources=inherited_sources,
        parent=self,
        scope_type=scope_type,
        **kwargs,
    )
Branch from the current scope to a new, inner scope
def find(self, *expression_types, bfs=True):
    """
    Return the first node in this scope that is an instance of any of the given types.

    Subscopes are NOT searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True for breadth-first search, False for depth-first.

    Returns:
        exp.Expression: the first matching node, or None when nothing matches.
    """
    for match in self.find_all(*expression_types, bfs=bfs):
        return match
    return None
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
def find_all(self, *expression_types, bfs=True):
    """
    Lazily yield every node in this scope that is an instance of any of the given types.

    Subscopes are NOT searched.

    Args:
        expression_types (type): the expression type(s) to match.
        bfs (bool): True for breadth-first search, False for depth-first.

    Yields:
        exp.Expression: matching nodes, in walk order
    """
    for node, *_rest in self.walk(bfs=bfs):
        if not isinstance(node, expression_types):
            continue
        yield node
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
def replace(self, old, new):
    """
    Swap `old` for `new` in the tree and invalidate this scope's caches.

    Prefer this over calling `exp.Expression.replace` directly so that the `Scope`
    does not keep serving results computed from the old tree.

    Args:
        old (exp.Expression): node currently in the tree
        new (exp.Expression): node to put in its place
    """
    old.replace(new)
    # The tree changed underneath us, so cached lookups may no longer be valid.
    self.clear_cache()
Replace `old` with `new`.
This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Hints that exist in the scope that reference tables.
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
def source_columns(self, source_name):
    """
    Collect the columns of this scope whose table qualifier equals `source_name`.

    Args:
        source_name (str): name of the source to filter on
    Returns:
        list[exp.Column]: matching columns, in the order of `self.columns`
    """
    return list(filter(lambda column: column.table == source_name, self.columns))
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference `source_name`
def rename_source(self, old_name, new_name):
    """
    Rename a source in this scope.

    The source registered under `old_name` (or under "" when `old_name` is falsy)
    is moved to `new_name`.
    NOTE(review): if no such source exists, an empty list is registered under
    `new_name`; this looks unintended but is kept for backward compatibility.

    Args:
        old_name (str): current name of the source
        new_name (str): name to register the source under
    """
    source = self.sources.pop(old_name or "", [])
    self.sources[new_name] = source
    # Cached lookups (e.g. `selected_sources`, and `external_columns` which is
    # derived from it) depend on source names; reset them, as add_source and
    # remove_source do, so they are not served stale under the old name.
    self.clear_cache()
Rename a source in this scope
def add_source(self, name, source):
    """
    Register `source` under `name`, replacing any existing entry, and invalidate caches.

    Args:
        name (str): name to register the source under
        source (exp.Table|Scope): the source itself
    """
    self.sources[name] = source
    # The set of known sources changed; cached lookups may be stale.
    self.clear_cache()
Add a source to this scope
def remove_source(self, name):
    """
    Forget the source registered under `name` (a no-op if absent) and invalidate caches.

    Args:
        name (str): name of the source to drop
    """
    if name in self.sources:
        del self.sources[name]
    # Caches are reset even when nothing was removed.
    self.clear_cache()
Remove a source from this scope
def traverse(self):
    """
    Walk the scope tree below (and including) this scope.

    Children are visited group by group — CTE scopes, union scopes, table scopes,
    then subquery scopes — each fully before its parent.

    Yields:
        Scope: scope instances in depth-first-search post-order
    """
    child_groups = (
        self.cte_scopes,
        self.union_scopes,
        self.table_scopes,
        self.subquery_scopes,
    )
    for group in child_groups:
        for child in group:
            yield from child.traverse()
    yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
def ref_count(self):
    """
    Count how often each source is selected across this scope tree.

    Every scope produced by `traverse()` adds one reference for each entry of its
    `selected_sources`.

    Returns:
        dict[int, int]: mapping of `id()` of each selected source (a Scope, or an
            exp.Table for plain tables) to its reference count; unknown ids read as 0
    """
    counts = defaultdict(int)
    for scope in self.traverse():
        for _node, source in scope.selected_sources.values():
            counts[id(source)] += 1
    return counts
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    A "Scope" represents the context of a single Select statement.

    Scopes carry information the bare expression tree does not, such as the source
    names visible inside a subquery, which makes them useful for optimization.
    The result is a list rather than a generator, because consuming a generator
    part-way could expose scopes whose properties are not yet fully populated.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    root_scope = Scope(expression)
    return [*_traverse_scope(root_scope)]
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
def build_scope(expression):
    """
    Build the scope tree for `expression` and return its root.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    scopes = traverse_scope(expression)
    # Scopes come out children-first, so the root is the final element.
    return scopes[-1]
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    A node that starts a child scope (a CTE, a UDTF, a Subqueryable, or a Subquery
    directly under FROM/JOIN) is itself yielded, but its descendants are not. The
    root `expression` is never pruned.

    Args:
        expression (exp.Expression): root of the subtree to walk
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # The walker reads `prune` through this closure right after each yielded node;
    # a truthy value excludes that node's subtree from the traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        yield node, parent, key

        prune = node is not expression and (
            isinstance(node, (exp.CTE, exp.UDTF, exp.Subqueryable))
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
        )
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): root of the subtree to walk
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key