sqlglot.optimizer.scope
import itertools
from collections import defaultdict
from enum import Enum, auto

from sqlglot import exp
from sqlglot.errors import OptimizeError


class ScopeType(Enum):
    ROOT = auto()
    SUBQUERY = auto()
    DERIVED_TABLE = auto()
    CTE = auto()
    UNION = auto()
    UDTF = auto()


class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list of its alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copied so later additions to the caller's mapping don't leak into this scope.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        # Child-scope lists, populated by the _traverse_* functions.
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Drop all lazily computed node lists and mappings so they're rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (without entering child scopes) and bucket nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, _, _ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                # Unqualified names in this SELECT's QUALIFY / ORDER BY / HAVING that match a
                # projection alias, and unqualified names in its hints, are not source columns.
                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
                if (
                    not ancestor
                    # Window functions can have an ORDER BY clause
                    or not isinstance(ancestor.parent, exp.Select)
                    or column.table
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like the other node-list properties, so hints are found even
        # when this is the first property accessed on the scope.
        self._ensure_collected()
        return self._join_hints

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Cached results such as `selected_sources` are keyed by source name, so they
        # must be rebuilt, as add_source/remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count


def traverse_scope(expression):
    """
    Traverse an expression by its "scopes".

    "Scope" represents the current context of a Select statement.

    This is helpful for optimizing queries, where we need more information than
    the expression tree itself. For example, we might care about the source
    names within a subquery. Returns a list because a generator could result in
    incomplete properties which is confusing.

    Examples:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
        >>> scopes = traverse_scope(expression)
        >>> scopes[0].expression.sql(), list(scopes[0].sources)
        ('SELECT a FROM x', ['x'])
        >>> scopes[1].expression.sql(), list(scopes[1].sources)
        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])

    Args:
        expression (exp.Expression): expression to traverse
    Returns:
        list[Scope]: scope instances
    """
    return list(_traverse_scope(Scope(expression)))


def build_scope(expression):
    """
    Build a scope tree.

    Args:
        expression (exp.Expression): expression to build the scope tree for
    Returns:
        Scope: root scope
    """
    # Traversal is post-order, so the root scope is yielded last.
    return traverse_scope(expression)[-1]


def _traverse_scope(scope):
    if isinstance(scope.expression, exp.Select):
        yield from _traverse_select(scope)
    elif isinstance(scope.expression, exp.Union):
        yield from _traverse_union(scope)
    elif isinstance(scope.expression, exp.Subquery):
        yield from _traverse_subqueries(scope)
    elif isinstance(scope.expression, exp.UDTF):
        pass
    else:
        raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
    yield scope


def _traverse_select(scope):
    yield from _traverse_ctes(scope)
    yield from _traverse_tables(scope)
    yield from _traverse_subqueries(scope)


def _traverse_union(scope):
    yield from _traverse_ctes(scope)

    # The last scope to be yielded should be the top most scope
    left = None
    for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)):
        yield left

    right = None
    for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)):
        yield right

    scope.union_scopes = [left, right]


def _traverse_ctes(scope):
    sources = {}

    for cte in scope.ctes:
        recursive_scope = None

        # If the scope is a recursive CTE, it must be in the form of
        # base_case UNION recursive. The left-hand side (base case) is branched
        # as its own scope and registered under the CTE's alias below, so the
        # self-reference in the recursive part resolves to it.
        if scope.expression.args["with"].recursive:
            union = cte.this

            if isinstance(union, exp.Union):
                recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)

        for child_scope in _traverse_scope(
            scope.branch(
                cte.this,
                chain_sources=sources,
                outer_column_list=cte.alias_column_names,
                scope_type=ScopeType.CTE,
            )
        ):
            yield child_scope

        alias = cte.alias
        sources[alias] = child_scope

        if recursive_scope:
            child_scope.add_source(alias, recursive_scope)

        # append the final child_scope yielded
        scope.cte_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_tables(scope):
    sources = {}

    # Traverse FROMs, JOINs, and LATERALs in the order they are defined
    expressions = []
    from_ = scope.expression.args.get("from")
    if from_:
        expressions.extend(from_.expressions)

    for join in scope.expression.args.get("joins") or []:
        expressions.append(join.this)

    expressions.extend(scope.expression.args.get("laterals") or [])

    for expression in expressions:
        if isinstance(expression, exp.Table):
            table_name = expression.name
            source_name = expression.alias_or_name

            if table_name in scope.sources:
                # This is a reference to a parent source (e.g. a CTE), not an actual table.
                sources[source_name] = scope.sources[table_name]
            else:
                sources[source_name] = expression
            continue

        if isinstance(expression, exp.UDTF):
            # UDTFs can see the sources defined before them in this FROM/JOIN list.
            lateral_sources = sources
            scope_type = ScopeType.UDTF
            scopes = scope.udtf_scopes
        else:
            lateral_sources = None
            scope_type = ScopeType.DERIVED_TABLE
            scopes = scope.derived_table_scopes

        for child_scope in _traverse_scope(
            scope.branch(
                expression,
                lateral_sources=lateral_sources,
                outer_column_list=expression.alias_column_names,
                scope_type=scope_type,
            )
        ):
            yield child_scope

        # Tables without aliases will be set as ""
        # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
        # Until then, this means that only a single, unaliased derived table is allowed (rather,
        # the latest one wins).
        alias = expression.alias
        sources[alias] = child_scope

        # append the final child_scope yielded
        scopes.append(child_scope)
        scope.table_scopes.append(child_scope)

    scope.sources.update(sources)


def _traverse_subqueries(scope):
    for subquery in scope.subqueries:
        top = None
        for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)):
            yield child_scope
            top = child_scope
        scope.subquery_scopes.append(top)


def walk_in_scope(expression, bfs=True):
    """
    Returns a generator object which visits all nodes in the syntax tree, stopping at
    nodes that start child scopes.

    Args:
        expression (exp.Expression): the expression whose scope should be walked.
        bfs (bool): if set to True the BFS traversal order will be applied,
            otherwise the DFS traversal will be used instead.

    Yields:
        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    """
    # We'll use this variable to pass state into the dfs generator.
    # Whenever we set it to True, we exclude a subtree from traversal.
    prune = False

    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
        prune = False

        yield node, parent, key

        if node is expression:
            continue
        if (
            isinstance(node, exp.CTE)
            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
            or isinstance(node, exp.UDTF)
            or isinstance(node, exp.Subqueryable)
        ):
            prune = True
class ScopeType(Enum):
    """Kind of a scope, relative to its parent scope."""

    # Explicit values match what auto() assigns (1, 2, 3, ...) in declaration order.
    ROOT = 1
    SUBQUERY = 2
    DERIVED_TABLE = 3
    CTE = 4
    UNION = 5
    UDTF = 6
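An enumeration of the kinds of scope (root, subquery, derived table, CTE, union, UDTF), relative to the parent scope.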
Inherited Members
- enum.Enum
- name
- value
class Scope:
    """
    Selection scope.

    Attributes:
        expression (exp.Select|exp.Union): Root expression of this scope
        sources (dict[str, exp.Table|Scope]): Mapping of source name to either
            a Table expression or another Scope instance. For example:
                SELECT * FROM x                     {"x": Table(this="x")}
                SELECT * FROM x AS y                {"y": Table(this="x")}
                SELECT * FROM (SELECT ...) AS y     {"y": Scope(...)}
        lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals.
            For example:
                SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
            The LATERAL VIEW EXPLODE gets x as a source.
        outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
            defines a column list of its alias of this scope, this is that list of columns.
            For example:
                SELECT * FROM (SELECT ...) AS y(col1, col2)
            The inner query would have `["col1", "col2"]` for its `outer_column_list`
        parent (Scope): Parent scope
        scope_type (ScopeType): Type of this scope, relative to its parent
        subquery_scopes (list[Scope]): List of all child scopes for subqueries
        cte_scopes (list[Scope]): List of all child scopes for CTEs
        derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
        udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
        table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
        union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
            a list of the left and right child scopes.
    """

    def __init__(
        self,
        expression,
        sources=None,
        outer_column_list=None,
        parent=None,
        scope_type=ScopeType.ROOT,
        lateral_sources=None,
    ):
        self.expression = expression
        self.sources = sources or {}
        # Copied so later additions to the caller's mapping don't leak into this scope.
        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
        self.sources.update(self.lateral_sources)
        self.outer_column_list = outer_column_list or []
        self.parent = parent
        self.scope_type = scope_type
        # Child-scope lists, filled in while the scope tree is traversed.
        self.subquery_scopes = []
        self.derived_table_scopes = []
        self.table_scopes = []
        self.cte_scopes = []
        self.union_scopes = []
        self.udtf_scopes = []
        self.clear_cache()

    def clear_cache(self):
        """Drop all lazily computed node lists and mappings so they're rebuilt on next access."""
        self._collected = False
        self._raw_columns = None
        self._derived_tables = None
        self._udtfs = None
        self._tables = None
        self._ctes = None
        self._subqueries = None
        self._selected_sources = None
        self._columns = None
        self._external_columns = None
        self._join_hints = None

    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
        """Branch from the current scope to a new, inner scope"""
        return Scope(
            expression=expression.unnest(),
            sources={**self.cte_sources, **(chain_sources or {})},
            parent=self,
            scope_type=scope_type,
            **kwargs,
        )

    def _collect(self):
        """Walk this scope once (without entering child scopes) and bucket nodes by kind."""
        self._tables = []
        self._ctes = []
        self._subqueries = []
        self._derived_tables = []
        self._udtfs = []
        self._raw_columns = []
        self._join_hints = []

        for node, parent, _ in self.walk(bfs=False):
            if node is self.expression:
                continue
            elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
                self._raw_columns.append(node)
            elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
                self._tables.append(node)
            elif isinstance(node, exp.JoinHint):
                self._join_hints.append(node)
            elif isinstance(node, exp.UDTF):
                self._udtfs.append(node)
            elif isinstance(node, exp.CTE):
                self._ctes.append(node)
            elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
                self._derived_tables.append(node)
            elif isinstance(node, exp.Subqueryable):
                self._subqueries.append(node)

        self._collected = True

    def _ensure_collected(self):
        if not self._collected:
            self._collect()

    def walk(self, bfs=True):
        return walk_in_scope(self.expression, bfs=bfs)

    def find(self, *expression_types, bfs=True):
        """
        Returns the first node in this scope which matches at least one of the specified types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Returns:
            exp.Expression: the node which matches the criteria or None if no node matching
            the criteria was found.
        """
        return next(self.find_all(*expression_types, bfs=bfs), None)

    def find_all(self, *expression_types, bfs=True):
        """
        Returns a generator object which visits all nodes in this scope and only yields those that
        match at least one of the specified expression types.

        This does NOT traverse into subscopes.

        Args:
            expression_types (type): the expression type(s) to match.
            bfs (bool): True to use breadth-first search, False to use depth-first.

        Yields:
            exp.Expression: nodes
        """
        for expression, _, _ in self.walk(bfs=bfs):
            if isinstance(expression, expression_types):
                yield expression

    def replace(self, old, new):
        """
        Replace `old` with `new`.

        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.

        Args:
            old (exp.Expression): old node
            new (exp.Expression): new node
        """
        old.replace(new)
        self.clear_cache()

    @property
    def tables(self):
        """
        List of tables in this scope.

        Returns:
            list[exp.Table]: tables
        """
        self._ensure_collected()
        return self._tables

    @property
    def ctes(self):
        """
        List of CTEs in this scope.

        Returns:
            list[exp.CTE]: ctes
        """
        self._ensure_collected()
        return self._ctes

    @property
    def derived_tables(self):
        """
        List of derived tables in this scope.

        For example:
            SELECT * FROM (SELECT ...) <- that's a derived table

        Returns:
            list[exp.Subquery]: derived tables
        """
        self._ensure_collected()
        return self._derived_tables

    @property
    def udtfs(self):
        """
        List of "User Defined Tabular Functions" in this scope.

        Returns:
            list[exp.UDTF]: UDTFs
        """
        self._ensure_collected()
        return self._udtfs

    @property
    def subqueries(self):
        """
        List of subqueries in this scope.

        For example:
            SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery

        Returns:
            list[exp.Subqueryable]: subqueries
        """
        self._ensure_collected()
        return self._subqueries

    @property
    def columns(self):
        """
        List of columns in this scope.

        Returns:
            list[exp.Column]: Column instances in this scope, plus any
                Columns that reference this scope from correlated subqueries.
        """
        if self._columns is None:
            self._ensure_collected()
            columns = self._raw_columns

            external_columns = [
                column
                for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
                for column in scope.external_columns
            ]

            named_selects = set(self.expression.named_selects)

            self._columns = []
            for column in columns + external_columns:
                # Unqualified names in this SELECT's QUALIFY / ORDER BY / HAVING that match a
                # projection alias, and unqualified names in its hints, are not source columns.
                ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint)
                if (
                    not ancestor
                    # Window functions can have an ORDER BY clause
                    or not isinstance(ancestor.parent, exp.Select)
                    or column.table
                    or (column.name not in named_selects and not isinstance(ancestor, exp.Hint))
                ):
                    self._columns.append(column)

        return self._columns

    @property
    def selected_sources(self):
        """
        Mapping of nodes and sources that are actually selected from in this scope.

        That is, all tables in a schema are selectable at any point. But a
        table only becomes a selected source if it's included in a FROM or JOIN clause.

        Returns:
            dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
        """
        if self._selected_sources is None:
            referenced_names = []

            for table in self.tables:
                referenced_names.append((table.alias_or_name, table))
            for expression in itertools.chain(self.derived_tables, self.udtfs):
                referenced_names.append((expression.alias, expression.unnest()))
            result = {}

            for name, node in referenced_names:
                if name in self.sources:
                    result[name] = (node, self.sources[name])

            self._selected_sources = result
        return self._selected_sources

    @property
    def cte_sources(self):
        """
        Sources that are CTEs.

        Returns:
            dict[str, Scope]: Mapping of source alias to Scope
        """
        return {
            alias: scope
            for alias, scope in self.sources.items()
            if isinstance(scope, Scope) and scope.is_cte
        }

    @property
    def selects(self):
        """
        Select expressions of this scope.

        For example, for the following expression:
            SELECT 1 as a, 2 as b FROM x

        The outputs are the "1 as a" and "2 as b" expressions.

        Returns:
            list[exp.Expression]: expressions
        """
        if isinstance(self.expression, exp.Union):
            return self.expression.unnest().selects
        return self.expression.selects

    @property
    def external_columns(self):
        """
        Columns that appear to reference sources in outer scopes.

        Returns:
            list[exp.Column]: Column instances that don't reference
                sources in the current scope.
        """
        if self._external_columns is None:
            self._external_columns = [
                c for c in self.columns if c.table not in self.selected_sources
            ]
        return self._external_columns

    @property
    def unqualified_columns(self):
        """
        Unqualified columns in the current scope.

        Returns:
            list[exp.Column]: Unqualified columns
        """
        return [c for c in self.columns if not c.table]

    @property
    def join_hints(self):
        """
        Hints that exist in the scope that reference tables

        Returns:
            list[exp.JoinHint]: Join hints that are referenced within the scope
        """
        # Collect lazily like the other node-list properties, so hints are found even
        # when this is the first property accessed on the scope.
        self._ensure_collected()
        return self._join_hints

    def source_columns(self, source_name):
        """
        Get all columns in the current scope for a particular source.

        Args:
            source_name (str): Name of the source
        Returns:
            list[exp.Column]: Column instances that reference `source_name`
        """
        return [column for column in self.columns if column.table == source_name]

    @property
    def is_subquery(self):
        """Determine if this scope is a subquery"""
        return self.scope_type == ScopeType.SUBQUERY

    @property
    def is_derived_table(self):
        """Determine if this scope is a derived table"""
        return self.scope_type == ScopeType.DERIVED_TABLE

    @property
    def is_union(self):
        """Determine if this scope is a union"""
        return self.scope_type == ScopeType.UNION

    @property
    def is_cte(self):
        """Determine if this scope is a common table expression"""
        return self.scope_type == ScopeType.CTE

    @property
    def is_root(self):
        """Determine if this is the root scope"""
        return self.scope_type == ScopeType.ROOT

    @property
    def is_udtf(self):
        """Determine if this scope is a UDTF (User Defined Table Function)"""
        return self.scope_type == ScopeType.UDTF

    @property
    def is_correlated_subquery(self):
        """Determine if this scope is a correlated subquery"""
        return bool(self.is_subquery and self.external_columns)

    def rename_source(self, old_name, new_name):
        """Rename a source in this scope"""
        source = self.sources.pop(old_name or "", [])
        self.sources[new_name] = source
        # Cached results such as `selected_sources` are keyed by source name, so they
        # must be rebuilt, as add_source/remove_source already do.
        self.clear_cache()

    def add_source(self, name, source):
        """Add a source to this scope"""
        self.sources[name] = source
        self.clear_cache()

    def remove_source(self, name):
        """Remove a source from this scope"""
        self.sources.pop(name, None)
        self.clear_cache()

    def __repr__(self):
        return f"Scope<{self.expression.sql()}>"

    def traverse(self):
        """
        Traverse the scope tree from this node.

        Yields:
            Scope: scope instances in depth-first-search post-order
        """
        for child_scope in itertools.chain(
            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
        ):
            yield from child_scope.traverse()
        yield self

    def ref_count(self):
        """
        Count the number of times each scope in this tree is referenced.

        Returns:
            dict[int, int]: Mapping of Scope instance ID to reference count
        """
        scope_ref_count = defaultdict(int)

        for scope in self.traverse():
            for _, source in scope.selected_sources.values():
                scope_ref_count[id(source)] += 1

        return scope_ref_count
Selection scope.
Attributes:
- expression (exp.Select|exp.Union): Root expression of this scope
- sources (dict[str, exp.Table|Scope]): Mapping of source name to either a Table expression or another Scope instance. For example:
  SELECT * FROM x -> {"x": Table(this="x")}
  SELECT * FROM x AS y -> {"y": Table(this="x")}
  SELECT * FROM (SELECT ...) AS y -> {"y": Scope(...)}
- lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals. For example, in "SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;" the LATERAL VIEW EXPLODE gets x as a source.
- outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for its alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have ["col1", "col2"] for its outer_column_list.
- parent (Scope): Parent scope
- scope_type (ScopeType): Type of this scope, relative to its parent
- subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes (list[Scope]): List of all child scopes for CTEs
- derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
- udtf_scopes (list[Scope]): List of all child scopes for user-defined table functions (UDTFs)
- table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
- union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
50 def __init__( 51 self, 52 expression, 53 sources=None, 54 outer_column_list=None, 55 parent=None, 56 scope_type=ScopeType.ROOT, 57 lateral_sources=None, 58 ): 59 self.expression = expression 60 self.sources = sources or {} 61 self.lateral_sources = lateral_sources.copy() if lateral_sources else {} 62 self.sources.update(self.lateral_sources) 63 self.outer_column_list = outer_column_list or [] 64 self.parent = parent 65 self.scope_type = scope_type 66 self.subquery_scopes = [] 67 self.derived_table_scopes = [] 68 self.table_scopes = [] 69 self.cte_scopes = [] 70 self.union_scopes = [] 71 self.udtf_scopes = [] 72 self.clear_cache()
74 def clear_cache(self): 75 self._collected = False 76 self._raw_columns = None 77 self._derived_tables = None 78 self._udtfs = None 79 self._tables = None 80 self._ctes = None 81 self._subqueries = None 82 self._selected_sources = None 83 self._columns = None 84 self._external_columns = None 85 self._join_hints = None
87 def branch(self, expression, scope_type, chain_sources=None, **kwargs): 88 """Branch from the current scope to a new, inner scope""" 89 return Scope( 90 expression=expression.unnest(), 91 sources={**self.cte_sources, **(chain_sources or {})}, 92 parent=self, 93 scope_type=scope_type, 94 **kwargs, 95 )
Branch from the current scope to a new, inner scope
133 def find(self, *expression_types, bfs=True): 134 """ 135 Returns the first node in this scope which matches at least one of the specified types. 136 137 This does NOT traverse into subscopes. 138 139 Args: 140 expression_types (type): the expression type(s) to match. 141 bfs (bool): True to use breadth-first search, False to use depth-first. 142 143 Returns: 144 exp.Expression: the node which matches the criteria or None if no node matching 145 the criteria was found. 146 """ 147 return next(self.find_all(*expression_types, bfs=bfs), None)
Returns the first node in this scope which matches at least one of the specified types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Returns:
exp.Expression: the node which matches the criteria or None if no node matching the criteria was found.
149 def find_all(self, *expression_types, bfs=True): 150 """ 151 Returns a generator object which visits all nodes in this scope and only yields those that 152 match at least one of the specified expression types. 153 154 This does NOT traverse into subscopes. 155 156 Args: 157 expression_types (type): the expression type(s) to match. 158 bfs (bool): True to use breadth-first search, False to use depth-first. 159 160 Yields: 161 exp.Expression: nodes 162 """ 163 for expression, _, _ in self.walk(bfs=bfs): 164 if isinstance(expression, expression_types): 165 yield expression
Returns a generator object which visits all nodes in this scope and only yields those that match at least one of the specified expression types.
This does NOT traverse into subscopes.
Arguments:
- expression_types (type): the expression type(s) to match.
- bfs (bool): True to use breadth-first search, False to use depth-first.
Yields:
exp.Expression: nodes
167 def replace(self, old, new): 168 """ 169 Replace `old` with `new`. 170 171 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 172 173 Args: 174 old (exp.Expression): old node 175 new (exp.Expression): new node 176 """ 177 old.replace(new) 178 self.clear_cache()
Replace old with new.
This can be used instead of exp.Expression.replace to ensure the Scope is kept up-to-date.
Arguments:
- old (exp.Expression): old node
- new (exp.Expression): new node
List of derived tables in this scope.
For example:
SELECT * FROM (SELECT ...) <- that's a derived table
Returns:
list[exp.Subquery]: derived tables
List of subqueries in this scope.
For example:
SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
Returns:
list[exp.Subqueryable]: subqueries
List of columns in this scope.
Returns:
list[exp.Column]: Column instances in this scope, plus any Columns that reference this scope from correlated subqueries.
Mapping of nodes and sources that are actually selected from in this scope.
That is, all tables in a schema are selectable at any point. But a table only becomes a selected source if it's included in a FROM or JOIN clause.
Returns:
dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes
Select expressions of this scope.
For example, for the following expression: SELECT 1 as a, 2 as b FROM x
The outputs are the "1 as a" and "2 as b" expressions.
Returns:
list[exp.Expression]: expressions
Columns that appear to reference sources in outer scopes.
Returns:
list[exp.Column]: Column instances that don't reference sources in the current scope.
Unqualified columns in the current scope.
Returns:
list[exp.Column]: Unqualified columns
Join hints that exist in the scope and reference tables.
Returns:
list[exp.JoinHint]: Join hints that are referenced within the scope
371 def source_columns(self, source_name): 372 """ 373 Get all columns in the current scope for a particular source. 374 375 Args: 376 source_name (str): Name of the source 377 Returns: 378 list[exp.Column]: Column instances that reference `source_name` 379 """ 380 return [column for column in self.columns if column.table == source_name]
Get all columns in the current scope for a particular source.
Arguments:
- source_name (str): Name of the source
Returns:
list[exp.Column]: Column instances that reference source_name
417 def rename_source(self, old_name, new_name): 418 """Rename a source in this scope""" 419 columns = self.sources.pop(old_name or "", []) 420 self.sources[new_name] = columns
Rename a source in this scope
422 def add_source(self, name, source): 423 """Add a source to this scope""" 424 self.sources[name] = source 425 self.clear_cache()
Add a source to this scope
427 def remove_source(self, name): 428 """Remove a source from this scope""" 429 self.sources.pop(name, None) 430 self.clear_cache()
Remove a source from this scope
435 def traverse(self): 436 """ 437 Traverse the scope tree from this node. 438 439 Yields: 440 Scope: scope instances in depth-first-search post-order 441 """ 442 for child_scope in itertools.chain( 443 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes 444 ): 445 yield from child_scope.traverse() 446 yield self
Traverse the scope tree from this node.
Yields:
Scope: scope instances in depth-first-search post-order
448 def ref_count(self): 449 """ 450 Count the number of times each scope in this tree is referenced. 451 452 Returns: 453 dict[int, int]: Mapping of Scope instance ID to reference count 454 """ 455 scope_ref_count = defaultdict(lambda: 0) 456 457 for scope in self.traverse(): 458 for _, source in scope.selected_sources.values(): 459 scope_ref_count[id(source)] += 1 460 461 return scope_ref_count
Count the number of times each scope in this tree is referenced.
Returns:
dict[int, int]: Mapping of Scope instance ID to reference count
464def traverse_scope(expression): 465 """ 466 Traverse an expression by it's "scopes". 467 468 "Scope" represents the current context of a Select statement. 469 470 This is helpful for optimizing queries, where we need more information than 471 the expression tree itself. For example, we might care about the source 472 names within a subquery. Returns a list because a generator could result in 473 incomplete properties which is confusing. 474 475 Examples: 476 >>> import sqlglot 477 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") 478 >>> scopes = traverse_scope(expression) 479 >>> scopes[0].expression.sql(), list(scopes[0].sources) 480 ('SELECT a FROM x', ['x']) 481 >>> scopes[1].expression.sql(), list(scopes[1].sources) 482 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) 483 484 Args: 485 expression (exp.Expression): expression to traverse 486 Returns: 487 list[Scope]: scope instances 488 """ 489 return list(_traverse_scope(Scope(expression)))
Traverse an expression by its "scopes".
"Scope" represents the current context of a Select statement.
This is helpful for optimizing queries, where we need more information than the expression tree itself. For example, we might care about the source names within a subquery. Returns a list because a generator could result in incomplete properties which is confusing.
Examples:
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
>>> scopes = traverse_scope(expression)
>>> scopes[0].expression.sql(), list(scopes[0].sources)
('SELECT a FROM x', ['x'])
>>> scopes[1].expression.sql(), list(scopes[1].sources)
('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
Arguments:
- expression (exp.Expression): expression to traverse
Returns:
list[Scope]: scope instances
492def build_scope(expression): 493 """ 494 Build a scope tree. 495 496 Args: 497 expression (exp.Expression): expression to build the scope tree for 498 Returns: 499 Scope: root scope 500 """ 501 return traverse_scope(expression)[-1]
Build a scope tree.
Arguments:
- expression (exp.Expression): expression to build the scope tree for
Returns:
Scope: root scope
644def walk_in_scope(expression, bfs=True): 645 """ 646 Returns a generator object which visits all nodes in the syntrax tree, stopping at 647 nodes that start child scopes. 648 649 Args: 650 expression (exp.Expression): 651 bfs (bool): if set to True the BFS traversal order will be applied, 652 otherwise the DFS traversal will be used instead. 653 654 Yields: 655 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key 656 """ 657 # We'll use this variable to pass state into the dfs generator. 658 # Whenever we set it to True, we exclude a subtree from traversal. 659 prune = False 660 661 for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): 662 prune = False 663 664 yield node, parent, key 665 666 if node is expression: 667 continue 668 if ( 669 isinstance(node, exp.CTE) 670 or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) 671 or isinstance(node, exp.UDTF) 672 or isinstance(node, exp.Subqueryable) 673 ): 674 prune = True
Returns a generator object which visits all nodes in the syntax tree, stopping at nodes that start child scopes.
Arguments:
- expression (exp.Expression): the root expression to walk.
- bfs (bool): if set to True the BFS traversal order will be applied, otherwise the DFS traversal will be used instead.
Yields:
tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key