From 639a208fa57ea674d165c4837e96f3ae4d7e3e61 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Feb 2023 14:45:09 +0100 Subject: Merging upstream version 11.1.3. Signed-off-by: Daniel Baumann --- docs/sqlglot/optimizer/scope.html | 2407 +++++++++++++++++++------------------ 1 file changed, 1254 insertions(+), 1153 deletions(-) (limited to 'docs/sqlglot/optimizer/scope.html') diff --git a/docs/sqlglot/optimizer/scope.html b/docs/sqlglot/optimizer/scope.html index 94b5f5b..1f751c9 100644 --- a/docs/sqlglot/optimizer/scope.html +++ b/docs/sqlglot/optimizer/scope.html @@ -87,6 +87,9 @@
  • derived_tables
  • +
  • + udtfs +
  • subqueries
  • @@ -213,610 +216,651 @@ 26 SELECT * FROM x {"x": Table(this="x")} 27 SELECT * FROM x AS y {"y": Table(this="x")} 28 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} - 29 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query - 30 defines a column list of it's alias of this scope, this is that list of columns. - 31 For example: - 32 SELECT * FROM (SELECT ...) AS y(col1, col2) - 33 The inner query would have `["col1", "col2"]` for its `outer_column_list` - 34 parent (Scope): Parent scope - 35 scope_type (ScopeType): Type of this scope, relative to it's parent - 36 subquery_scopes (list[Scope]): List of all child scopes for subqueries - 37 cte_scopes = (list[Scope]) List of all child scopes for CTEs - 38 derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables - 39 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be - 40 a list of the left and right child scopes. - 41 """ - 42 - 43 def __init__( - 44 self, - 45 expression, - 46 sources=None, - 47 outer_column_list=None, - 48 parent=None, - 49 scope_type=ScopeType.ROOT, - 50 ): - 51 self.expression = expression - 52 self.sources = sources or {} - 53 self.outer_column_list = outer_column_list or [] - 54 self.parent = parent - 55 self.scope_type = scope_type - 56 self.subquery_scopes = [] - 57 self.derived_table_scopes = [] - 58 self.cte_scopes = [] - 59 self.union_scopes = [] - 60 self.clear_cache() - 61 - 62 def clear_cache(self): - 63 self._collected = False - 64 self._raw_columns = None - 65 self._derived_tables = None - 66 self._tables = None - 67 self._ctes = None - 68 self._subqueries = None - 69 self._selected_sources = None - 70 self._columns = None - 71 self._external_columns = None - 72 self._join_hints = None - 73 - 74 def branch(self, expression, scope_type, chain_sources=None, **kwargs): - 75 """Branch from the current scope to a new, inner scope""" - 76 return Scope( - 77 expression=expression.unnest(), - 78 sources={**self.cte_sources, **(chain_sources or {})}, - 79 parent=self, - 80 scope_type=scope_type, - 81 **kwargs, - 82 ) - 83 - 84 def _collect(self): - 85 self._tables = [] - 86 self._ctes = [] - 87 self._subqueries = [] - 88 self._derived_tables = [] - 89 self._raw_columns = [] - 90 self._join_hints = [] - 91 - 92 for node, parent, _ in self.walk(bfs=False): - 93 if node is self.expression: - 94 continue - 95 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): - 96 self._raw_columns.append(node) - 97 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): - 98 self._tables.append(node) - 99 elif isinstance(node, exp.JoinHint): -100 self._join_hints.append(node) -101 elif isinstance(node, exp.UDTF): -102 self._derived_tables.append(node) -103 elif isinstance(node, exp.CTE): -104 self._ctes.append(node) -105 elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): -106 self._derived_tables.append(node) -107 elif isinstance(node, exp.Subqueryable): -108 self._subqueries.append(node) -109 -110 self._collected = True -111 -112 def _ensure_collected(self): -113 if not self._collected: -114 self._collect() -115 -116 def walk(self, bfs=True): -117 return walk_in_scope(self.expression, bfs=bfs) -118 -119 def find(self, *expression_types, bfs=True): -120 """ -121 Returns the first node in this scope which matches at least one of the specified types. + 29 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + 30 For example: + 31 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + 32 The LATERAL VIEW EXPLODE gets x as a source. + 33 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query + 34 defines a column list of it's alias of this scope, this is that list of columns. + 35 For example: + 36 SELECT * FROM (SELECT ...) AS y(col1, col2) + 37 The inner query would have `["col1", "col2"]` for its `outer_column_list` + 38 parent (Scope): Parent scope + 39 scope_type (ScopeType): Type of this scope, relative to it's parent + 40 subquery_scopes (list[Scope]): List of all child scopes for subqueries + 41 cte_scopes (list[Scope]): List of all child scopes for CTEs + 42 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + 43 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + 44 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined + 45 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + 46 a list of the left and right child scopes. + 47 """ + 48 + 49 def __init__( + 50 self, + 51 expression, + 52 sources=None, + 53 outer_column_list=None, + 54 parent=None, + 55 scope_type=ScopeType.ROOT, + 56 lateral_sources=None, + 57 ): + 58 self.expression = expression + 59 self.sources = sources or {} + 60 self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + 61 self.sources.update(self.lateral_sources) + 62 self.outer_column_list = outer_column_list or [] + 63 self.parent = parent + 64 self.scope_type = scope_type + 65 self.subquery_scopes = [] + 66 self.derived_table_scopes = [] + 67 self.table_scopes = [] + 68 self.cte_scopes = [] + 69 self.union_scopes = [] + 70 self.udtf_scopes = [] + 71 self.clear_cache() + 72 + 73 def clear_cache(self): + 74 self._collected = False + 75 self._raw_columns = None + 76 self._derived_tables = None + 77 self._udtfs = None + 78 self._tables = None + 79 self._ctes = None + 80 self._subqueries = None + 81 self._selected_sources = None + 82 self._columns = None + 83 self._external_columns = None + 84 self._join_hints = None + 85 + 86 def branch(self, expression, scope_type, chain_sources=None, **kwargs): + 87 """Branch from the current scope to a new, inner scope""" + 88 return Scope( + 89 expression=expression.unnest(), + 90 sources={**self.cte_sources, **(chain_sources or {})}, + 91 parent=self, + 92 scope_type=scope_type, + 93 **kwargs, + 94 ) + 95 + 96 def _collect(self): + 97 self._tables = [] + 98 self._ctes = [] + 99 self._subqueries = [] +100 self._derived_tables = [] +101 self._udtfs = [] +102 self._raw_columns = [] +103 self._join_hints = [] +104 +105 for node, parent, _ in self.walk(bfs=False): +106 if node is self.expression: +107 continue +108 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): +109 self._raw_columns.append(node) +110 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): +111 self._tables.append(node) +112 elif isinstance(node, exp.JoinHint): +113 self._join_hints.append(node) +114 elif isinstance(node, exp.UDTF): +115 self._udtfs.append(node) +116 elif isinstance(node, exp.CTE): +117 self._ctes.append(node) +118 elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): +119 self._derived_tables.append(node) +120 elif isinstance(node, exp.Subqueryable): +121 self._subqueries.append(node) 122 -123 This does NOT traverse into subscopes. +123 self._collected = True 124 -125 Args: -126 expression_types (type): the expression type(s) to match. -127 bfs (bool): True to use breadth-first search, False to use depth-first. +125 def _ensure_collected(self): +126 if not self._collected: +127 self._collect() 128 -129 Returns: -130 exp.Expression: the node which matches the criteria or None if no node matching -131 the criteria was found. -132 """ -133 return next(self.find_all(*expression_types, bfs=bfs), None) -134 -135 def find_all(self, *expression_types, bfs=True): -136 """ -137 Returns a generator object which visits all nodes in this scope and only yields those that -138 match at least one of the specified expression types. -139 -140 This does NOT traverse into subscopes. +129 def walk(self, bfs=True): +130 return walk_in_scope(self.expression, bfs=bfs) +131 +132 def find(self, *expression_types, bfs=True): +133 """ +134 Returns the first node in this scope which matches at least one of the specified types. +135 +136 This does NOT traverse into subscopes. +137 +138 Args: +139 expression_types (type): the expression type(s) to match. +140 bfs (bool): True to use breadth-first search, False to use depth-first. 141 -142 Args: -143 expression_types (type): the expression type(s) to match. -144 bfs (bool): True to use breadth-first search, False to use depth-first. -145 -146 Yields: -147 exp.Expression: nodes -148 """ -149 for expression, _, _ in self.walk(bfs=bfs): -150 if isinstance(expression, expression_types): -151 yield expression +142 Returns: +143 exp.Expression: the node which matches the criteria or None if no node matching +144 the criteria was found. +145 """ +146 return next(self.find_all(*expression_types, bfs=bfs), None) +147 +148 def find_all(self, *expression_types, bfs=True): +149 """ +150 Returns a generator object which visits all nodes in this scope and only yields those that +151 match at least one of the specified expression types. 152 -153 def replace(self, old, new): -154 """ -155 Replace `old` with `new`. -156 -157 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. +153 This does NOT traverse into subscopes. +154 +155 Args: +156 expression_types (type): the expression type(s) to match. +157 bfs (bool): True to use breadth-first search, False to use depth-first. 158 -159 Args: -160 old (exp.Expression): old node -161 new (exp.Expression): new node -162 """ -163 old.replace(new) -164 self.clear_cache() +159 Yields: +160 exp.Expression: nodes +161 """ +162 for expression, _, _ in self.walk(bfs=bfs): +163 if isinstance(expression, expression_types): +164 yield expression 165 -166 @property -167 def tables(self): -168 """ -169 List of tables in this scope. -170 -171 Returns: -172 list[exp.Table]: tables -173 """ -174 self._ensure_collected() -175 return self._tables -176 -177 @property -178 def ctes(self): -179 """ -180 List of CTEs in this scope. -181 -182 Returns: -183 list[exp.CTE]: ctes -184 """ -185 self._ensure_collected() -186 return self._ctes -187 -188 @property -189 def derived_tables(self): -190 """ -191 List of derived tables in this scope. -192 -193 For example: -194 SELECT * FROM (SELECT ...) <- that's a derived table -195 -196 Returns: -197 list[exp.Subquery]: derived tables -198 """ -199 self._ensure_collected() -200 return self._derived_tables -201 -202 @property -203 def subqueries(self): -204 """ -205 List of subqueries in this scope. -206 -207 For example: -208 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery -209 -210 Returns: -211 list[exp.Subqueryable]: subqueries -212 """ -213 self._ensure_collected() -214 return self._subqueries -215 -216 @property -217 def columns(self): -218 """ -219 List of columns in this scope. -220 -221 Returns: -222 list[exp.Column]: Column instances in this scope, plus any -223 Columns that reference this scope from correlated subqueries. -224 """ -225 if self._columns is None: -226 self._ensure_collected() -227 columns = self._raw_columns -228 -229 external_columns = [ -230 column for scope in self.subquery_scopes for column in scope.external_columns -231 ] -232 -233 named_selects = set(self.expression.named_selects) -234 -235 self._columns = [] -236 for column in columns + external_columns: -237 ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) -238 if ( -239 not ancestor -240 # Window functions can have an ORDER BY clause -241 or not isinstance(ancestor.parent, exp.Select) -242 or column.table -243 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) -244 ): -245 self._columns.append(column) -246 -247 return self._columns -248 -249 @property -250 def selected_sources(self): -251 """ -252 Mapping of nodes and sources that are actually selected from in this scope. -253 -254 That is, all tables in a schema are selectable at any point. But a -255 table only becomes a selected source if it's included in a FROM or JOIN clause. -256 -257 Returns: -258 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes -259 """ -260 if self._selected_sources is None: -261 referenced_names = [] -262 -263 for table in self.tables: -264 referenced_names.append((table.alias_or_name, table)) -265 for derived_table in self.derived_tables: -266 referenced_names.append((derived_table.alias, derived_table.unnest())) -267 -268 result = {} -269 -270 for name, node in referenced_names: -271 if name in self.sources: -272 result[name] = (node, self.sources[name]) -273 -274 self._selected_sources = result -275 return self._selected_sources -276 -277 @property -278 def cte_sources(self): -279 """ -280 Sources that are CTEs. -281 -282 Returns: -283 dict[str, Scope]: Mapping of source alias to Scope -284 """ -285 return { -286 alias: scope -287 for alias, scope in self.sources.items() -288 if isinstance(scope, Scope) and scope.is_cte -289 } -290 -291 @property -292 def selects(self): -293 """ -294 Select expressions of this scope. -295 -296 For example, for the following expression: -297 SELECT 1 as a, 2 as b FROM x +166 def replace(self, old, new): +167 """ +168 Replace `old` with `new`. +169 +170 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. +171 +172 Args: +173 old (exp.Expression): old node +174 new (exp.Expression): new node +175 """ +176 old.replace(new) +177 self.clear_cache() +178 +179 @property +180 def tables(self): +181 """ +182 List of tables in this scope. +183 +184 Returns: +185 list[exp.Table]: tables +186 """ +187 self._ensure_collected() +188 return self._tables +189 +190 @property +191 def ctes(self): +192 """ +193 List of CTEs in this scope. +194 +195 Returns: +196 list[exp.CTE]: ctes +197 """ +198 self._ensure_collected() +199 return self._ctes +200 +201 @property +202 def derived_tables(self): +203 """ +204 List of derived tables in this scope. +205 +206 For example: +207 SELECT * FROM (SELECT ...) <- that's a derived table +208 +209 Returns: +210 list[exp.Subquery]: derived tables +211 """ +212 self._ensure_collected() +213 return self._derived_tables +214 +215 @property +216 def udtfs(self): +217 """ +218 List of "User Defined Tabular Functions" in this scope. +219 +220 Returns: +221 list[exp.UDTF]: UDTFs +222 """ +223 self._ensure_collected() +224 return self._udtfs +225 +226 @property +227 def subqueries(self): +228 """ +229 List of subqueries in this scope. +230 +231 For example: +232 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery +233 +234 Returns: +235 list[exp.Subqueryable]: subqueries +236 """ +237 self._ensure_collected() +238 return self._subqueries +239 +240 @property +241 def columns(self): +242 """ +243 List of columns in this scope. +244 +245 Returns: +246 list[exp.Column]: Column instances in this scope, plus any +247 Columns that reference this scope from correlated subqueries. +248 """ +249 if self._columns is None: +250 self._ensure_collected() +251 columns = self._raw_columns +252 +253 external_columns = [ +254 column +255 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) +256 for column in scope.external_columns +257 ] +258 +259 named_selects = set(self.expression.named_selects) +260 +261 self._columns = [] +262 for column in columns + external_columns: +263 ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) +264 if ( +265 not ancestor +266 # Window functions can have an ORDER BY clause +267 or not isinstance(ancestor.parent, exp.Select) +268 or column.table +269 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) +270 ): +271 self._columns.append(column) +272 +273 return self._columns +274 +275 @property +276 def selected_sources(self): +277 """ +278 Mapping of nodes and sources that are actually selected from in this scope. +279 +280 That is, all tables in a schema are selectable at any point. But a +281 table only becomes a selected source if it's included in a FROM or JOIN clause. +282 +283 Returns: +284 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes +285 """ +286 if self._selected_sources is None: +287 referenced_names = [] +288 +289 for table in self.tables: +290 referenced_names.append((table.alias_or_name, table)) +291 for expression in itertools.chain(self.derived_tables, self.udtfs): +292 referenced_names.append((expression.alias, expression.unnest())) +293 result = {} +294 +295 for name, node in referenced_names: +296 if name in self.sources: +297 result[name] = (node, self.sources[name]) 298 -299 The outputs are the "1 as a" and "2 as b" expressions. -300 -301 Returns: -302 list[exp.Expression]: expressions -303 """ -304 if isinstance(self.expression, exp.Union): -305 return self.expression.unnest().selects -306 return self.expression.selects -307 -308 @property -309 def external_columns(self): -310 """ -311 Columns that appear to reference sources in outer scopes. -312 -313 Returns: -314 list[exp.Column]: Column instances that don't reference -315 sources in the current scope. -316 """ -317 if self._external_columns is None: -318 self._external_columns = [ -319 c for c in self.columns if c.table not in self.selected_sources -320 ] -321 return self._external_columns -322 -323 @property -324 def unqualified_columns(self): -325 """ -326 Unqualified columns in the current scope. -327 -328 Returns: -329 list[exp.Column]: Unqualified columns -330 """ -331 return [c for c in self.columns if not c.table] +299 self._selected_sources = result +300 return self._selected_sources +301 +302 @property +303 def cte_sources(self): +304 """ +305 Sources that are CTEs. +306 +307 Returns: +308 dict[str, Scope]: Mapping of source alias to Scope +309 """ +310 return { +311 alias: scope +312 for alias, scope in self.sources.items() +313 if isinstance(scope, Scope) and scope.is_cte +314 } +315 +316 @property +317 def selects(self): +318 """ +319 Select expressions of this scope. +320 +321 For example, for the following expression: +322 SELECT 1 as a, 2 as b FROM x +323 +324 The outputs are the "1 as a" and "2 as b" expressions. +325 +326 Returns: +327 list[exp.Expression]: expressions +328 """ +329 if isinstance(self.expression, exp.Union): +330 return self.expression.unnest().selects +331 return self.expression.selects 332 333 @property -334 def join_hints(self): +334 def external_columns(self): 335 """ -336 Hints that exist in the scope that reference tables +336 Columns that appear to reference sources in outer scopes. 337 338 Returns: -339 list[exp.JoinHint]: Join hints that are referenced within the scope -340 """ -341 if self._join_hints is None: -342 return [] -343 return self._join_hints -344 -345 def source_columns(self, source_name): -346 """ -347 Get all columns in the current scope for a particular source. -348 -349 Args: -350 source_name (str): Name of the source -351 Returns: -352 list[exp.Column]: Column instances that reference `source_name` -353 """ -354 return [column for column in self.columns if column.table == source_name] -355 -356 @property -357 def is_subquery(self): -358 """Determine if this scope is a subquery""" -359 return self.scope_type == ScopeType.SUBQUERY -360 -361 @property -362 def is_derived_table(self): -363 """Determine if this scope is a derived table""" -364 return self.scope_type == ScopeType.DERIVED_TABLE -365 -366 @property -367 def is_union(self): -368 """Determine if this scope is a union""" -369 return self.scope_type == ScopeType.UNION -370 -371 @property -372 def is_cte(self): -373 """Determine if this scope is a common table expression""" -374 return self.scope_type == ScopeType.CTE -375 -376 @property -377 def is_root(self): -378 """Determine if this is the root scope""" -379 return self.scope_type == ScopeType.ROOT +339 list[exp.Column]: Column instances that don't reference +340 sources in the current scope. +341 """ +342 if self._external_columns is None: +343 self._external_columns = [ +344 c for c in self.columns if c.table not in self.selected_sources +345 ] +346 return self._external_columns +347 +348 @property +349 def unqualified_columns(self): +350 """ +351 Unqualified columns in the current scope. +352 +353 Returns: +354 list[exp.Column]: Unqualified columns +355 """ +356 return [c for c in self.columns if not c.table] +357 +358 @property +359 def join_hints(self): +360 """ +361 Hints that exist in the scope that reference tables +362 +363 Returns: +364 list[exp.JoinHint]: Join hints that are referenced within the scope +365 """ +366 if self._join_hints is None: +367 return [] +368 return self._join_hints +369 +370 def source_columns(self, source_name): +371 """ +372 Get all columns in the current scope for a particular source. +373 +374 Args: +375 source_name (str): Name of the source +376 Returns: +377 list[exp.Column]: Column instances that reference `source_name` +378 """ +379 return [column for column in self.columns if column.table == source_name] 380 381 @property -382 def is_udtf(self): -383 """Determine if this scope is a UDTF (User Defined Table Function)""" -384 return self.scope_type == ScopeType.UDTF +382 def is_subquery(self): +383 """Determine if this scope is a subquery""" +384 return self.scope_type == ScopeType.SUBQUERY 385 386 @property -387 def is_correlated_subquery(self): -388 """Determine if this scope is a correlated subquery""" -389 return bool(self.is_subquery and self.external_columns) +387 def is_derived_table(self): +388 """Determine if this scope is a derived table""" +389 return self.scope_type == ScopeType.DERIVED_TABLE 390 -391 def rename_source(self, old_name, new_name): -392 """Rename a source in this scope""" -393 columns = self.sources.pop(old_name or "", []) -394 self.sources[new_name] = columns +391 @property +392 def is_union(self): +393 """Determine if this scope is a union""" +394 return self.scope_type == ScopeType.UNION 395 -396 def add_source(self, name, source): -397 """Add a source to this scope""" -398 self.sources[name] = source -399 self.clear_cache() +396 @property +397 def is_cte(self): +398 """Determine if this scope is a common table expression""" +399 return self.scope_type == ScopeType.CTE 400 -401 def remove_source(self, name): -402 """Remove a source from this scope""" -403 self.sources.pop(name, None) -404 self.clear_cache() +401 @property +402 def is_root(self): +403 """Determine if this is the root scope""" +404 return self.scope_type == ScopeType.ROOT 405 -406 def __repr__(self): -407 return f"Scope<{self.expression.sql()}>" -408 -409 def traverse(self): -410 """ -411 Traverse the scope tree from this node. -412 -413 Yields: -414 Scope: scope instances in depth-first-search post-order -415 """ -416 for child_scope in itertools.chain( -417 self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes -418 ): -419 yield from child_scope.traverse() -420 yield self -421 -422 def ref_count(self): -423 """ -424 Count the number of times each scope in this tree is referenced. +406 @property +407 def is_udtf(self): +408 """Determine if this scope is a UDTF (User Defined Table Function)""" +409 return self.scope_type == ScopeType.UDTF +410 +411 @property +412 def is_correlated_subquery(self): +413 """Determine if this scope is a correlated subquery""" +414 return bool(self.is_subquery and self.external_columns) +415 +416 def rename_source(self, old_name, new_name): +417 """Rename a source in this scope""" +418 columns = self.sources.pop(old_name or "", []) +419 self.sources[new_name] = columns +420 +421 def add_source(self, name, source): +422 """Add a source to this scope""" +423 self.sources[name] = source +424 self.clear_cache() 425 -426 Returns: -427 dict[int, int]: Mapping of Scope instance ID to reference count -428 """ -429 scope_ref_count = defaultdict(lambda: 0) +426 def remove_source(self, name): +427 """Remove a source from this scope""" +428 self.sources.pop(name, None) +429 self.clear_cache() 430 -431 for scope in self.traverse(): -432 for _, source in scope.selected_sources.values(): -433 scope_ref_count[id(source)] += 1 -434 -435 return scope_ref_count -436 +431 def __repr__(self): +432 return f"Scope<{self.expression.sql()}>" +433 +434 def traverse(self): +435 """ +436 Traverse the scope tree from this node. 437 -438def traverse_scope(expression): -439 """ -440 Traverse an expression by it's "scopes". -441 -442 "Scope" represents the current context of a Select statement. -443 -444 This is helpful for optimizing queries, where we need more information than -445 the expression tree itself. For example, we might care about the source -446 names within a subquery. Returns a list because a generator could result in -447 incomplete properties which is confusing. -448 -449 Examples: -450 >>> import sqlglot -451 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") -452 >>> scopes = traverse_scope(expression) -453 >>> scopes[0].expression.sql(), list(scopes[0].sources) -454 ('SELECT a FROM x', ['x']) -455 >>> scopes[1].expression.sql(), list(scopes[1].sources) -456 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) -457 -458 Args: -459 expression (exp.Expression): expression to traverse -460 Returns: -461 list[Scope]: scope instances -462 """ -463 return list(_traverse_scope(Scope(expression))) -464 -465 -466def build_scope(expression): -467 """ -468 Build a scope tree. -469 -470 Args: -471 expression (exp.Expression): expression to build the scope tree for -472 Returns: -473 Scope: root scope -474 """ -475 return traverse_scope(expression)[-1] -476 -477 -478def _traverse_scope(scope): -479 if isinstance(scope.expression, exp.Select): -480 yield from _traverse_select(scope) -481 elif isinstance(scope.expression, exp.Union): -482 yield from _traverse_union(scope) -483 elif isinstance(scope.expression, exp.UDTF): -484 _set_udtf_scope(scope) -485 elif isinstance(scope.expression, exp.Subquery): -486 yield from _traverse_subqueries(scope) -487 else: -488 raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") -489 yield scope +438 Yields: +439 Scope: scope instances in depth-first-search post-order +440 """ +441 for child_scope in itertools.chain( +442 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes +443 ): +444 yield from child_scope.traverse() +445 yield self +446 +447 def ref_count(self): +448 """ +449 Count the number of times each scope in this tree is referenced. +450 +451 Returns: +452 dict[int, int]: Mapping of Scope instance ID to reference count +453 """ +454 scope_ref_count = defaultdict(lambda: 0) +455 +456 for scope in self.traverse(): +457 for _, source in scope.selected_sources.values(): +458 scope_ref_count[id(source)] += 1 +459 +460 return scope_ref_count +461 +462 +463def traverse_scope(expression): +464 """ +465 Traverse an expression by it's "scopes". +466 +467 "Scope" represents the current context of a Select statement. +468 +469 This is helpful for optimizing queries, where we need more information than +470 the expression tree itself. For example, we might care about the source +471 names within a subquery. Returns a list because a generator could result in +472 incomplete properties which is confusing. +473 +474 Examples: +475 >>> import sqlglot +476 >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") +477 >>> scopes = traverse_scope(expression) +478 >>> scopes[0].expression.sql(), list(scopes[0].sources) +479 ('SELECT a FROM x', ['x']) +480 >>> scopes[1].expression.sql(), list(scopes[1].sources) +481 ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) +482 +483 Args: +484 expression (exp.Expression): expression to traverse +485 Returns: +486 list[Scope]: scope instances +487 """ +488 return list(_traverse_scope(Scope(expression))) +489 490 -491 -492def _traverse_select(scope): -493 yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) -494 yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) -495 yield from _traverse_subqueries(scope) -496 _add_table_sources(scope) -497 -498 -499def _traverse_union(scope): -500 yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) +491def build_scope(expression): +492 """ +493 Build a scope tree. +494 +495 Args: +496 expression (exp.Expression): expression to build the scope tree for +497 Returns: +498 Scope: root scope +499 """ +500 return traverse_scope(expression)[-1] 501 -502 # The last scope to be yield should be the top most scope -503 left = None -504 for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): -505 yield left -506 -507 right = None -508 for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): -509 yield right -510 -511 scope.union_scopes = [left, right] -512 -513 -514def _set_udtf_scope(scope): -515 parent = scope.expression.parent -516 from_ = parent.args.get("from") -517 -518 if not from_: -519 return -520 -521 for table in from_.expressions: -522 if isinstance(table, exp.Table): -523 scope.tables.append(table) -524 elif isinstance(table, exp.Subquery): -525 scope.subqueries.append(table) -526 _add_table_sources(scope) -527 _traverse_subqueries(scope) -528 -529 -530def _traverse_derived_tables(derived_tables, scope, scope_type): -531 sources = {} -532 is_cte = scope_type == ScopeType.CTE -533 -534 for derived_table in derived_tables: -535 recursive_scope = None +502 +503def _traverse_scope(scope): +504 if isinstance(scope.expression, exp.Select): +505 yield from _traverse_select(scope) +506 elif isinstance(scope.expression, exp.Union): +507 yield from _traverse_union(scope) +508 elif isinstance(scope.expression, exp.Subquery): +509 yield from _traverse_subqueries(scope) +510 elif isinstance(scope.expression, exp.UDTF): +511 pass +512 else: +513 raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") +514 yield scope +515 +516 +517def _traverse_select(scope): +518 yield from _traverse_ctes(scope) +519 yield from _traverse_tables(scope) +520 yield from _traverse_subqueries(scope) +521 +522 +523def _traverse_union(scope): +524 yield from _traverse_ctes(scope) +525 +526 # The last scope to be yield should be the top most scope +527 left = None +528 for left in _traverse_scope(scope.branch(scope.expression.left, scope_type=ScopeType.UNION)): +529 yield left +530 +531 right = None +532 for right in _traverse_scope(scope.branch(scope.expression.right, scope_type=ScopeType.UNION)): +533 yield right +534 +535 scope.union_scopes = [left, right] 536 -537 # if the scope is a recursive cte, it must be in the form of -538 # base_case UNION recursive. thus the recursive scope is the first -539 # section of the union. -540 if is_cte and scope.expression.args["with"].recursive: -541 union = derived_table.this -542 -543 if isinstance(union, exp.Union): -544 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) -545 -546 for child_scope in _traverse_scope( -547 scope.branch( -548 derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, -549 chain_sources=sources if scope_type == ScopeType.CTE else None, -550 outer_column_list=derived_table.alias_column_names, -551 scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, -552 ) -553 ): -554 yield child_scope -555 -556 # Tables without aliases will be set as "" -557 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. -558 # Until then, this means that only a single, unaliased derived table is allowed (rather, -559 # the latest one wins. -560 alias = derived_table.alias -561 sources[alias] = child_scope +537 +538def _traverse_ctes(scope): +539 sources = {} +540 +541 for cte in scope.ctes: +542 recursive_scope = None +543 +544 # if the scope is a recursive cte, it must be in the form of +545 # base_case UNION recursive. thus the recursive scope is the first +546 # section of the union. +547 if scope.expression.args["with"].recursive: +548 union = cte.this +549 +550 if isinstance(union, exp.Union): +551 recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) +552 +553 for child_scope in _traverse_scope( +554 scope.branch( +555 cte.this, +556 chain_sources=sources, +557 outer_column_list=cte.alias_column_names, +558 scope_type=ScopeType.CTE, +559 ) +560 ): +561 yield child_scope 562 -563 if recursive_scope: -564 child_scope.add_source(alias, recursive_scope) +563 alias = cte.alias +564 sources[alias] = child_scope 565 -566 # append the final child_scope yielded -567 if is_cte: -568 scope.cte_scopes.append(child_scope) -569 else: -570 scope.derived_table_scopes.append(child_scope) +566 if recursive_scope: +567 child_scope.add_source(alias, recursive_scope) +568 +569 # append the final child_scope yielded +570 scope.cte_scopes.append(child_scope) 571 572 scope.sources.update(sources) 573 574 -575def _add_table_sources(scope): +575def _traverse_tables(scope): 576 sources = {} -577 for table in scope.tables: -578 table_name = table.name -579 -580 if table.alias: -581 source_name = table.alias -582 else: -583 source_name = table_name -584 -585 if table_name in scope.sources: -586 # This is a reference to a parent source (e.g. a CTE), not an actual table. -587 scope.sources[source_name] = scope.sources[table_name] -588 else: -589 sources[source_name] = table -590 -591 scope.sources.update(sources) -592 +577 +578 # Traverse FROMs, JOINs, and LATERALs in the order they are defined +579 expressions = [] +580 from_ = scope.expression.args.get("from") +581 if from_: +582 expressions.extend(from_.expressions) +583 +584 for join in scope.expression.args.get("joins") or []: +585 expressions.append(join.this) +586 +587 expressions.extend(scope.expression.args.get("laterals") or []) +588 +589 for expression in expressions: +590 if isinstance(expression, exp.Table): +591 table_name = expression.name +592 source_name = expression.alias_or_name 593 -594def _traverse_subqueries(scope): -595 for subquery in scope.subqueries: -596 top = None -597 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): -598 yield child_scope -599 top = child_scope -600 scope.subquery_scopes.append(top) -601 -602 -603def walk_in_scope(expression, bfs=True): -604 """ -605 Returns a generator object which visits all nodes in the syntrax tree, stopping at -606 nodes that start child scopes. -607 -608 Args: -609 expression (exp.Expression): -610 bfs (bool): if set to True the BFS traversal order will be applied, -611 otherwise the DFS traversal will be used instead. -612 -613 Yields: -614 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key -615 """ -616 # We'll use this variable to pass state into the dfs generator. -617 # Whenever we set it to True, we exclude a subtree from traversal. -618 prune = False +594 if table_name in scope.sources: +595 # This is a reference to a parent source (e.g. a CTE), not an actual table. +596 sources[source_name] = scope.sources[table_name] +597 else: +598 sources[source_name] = expression +599 continue +600 +601 if isinstance(expression, exp.UDTF): +602 lateral_sources = sources +603 scope_type = ScopeType.UDTF +604 scopes = scope.udtf_scopes +605 else: +606 lateral_sources = None +607 scope_type = ScopeType.DERIVED_TABLE +608 scopes = scope.derived_table_scopes +609 +610 for child_scope in _traverse_scope( +611 scope.branch( +612 expression, +613 lateral_sources=lateral_sources, +614 outer_column_list=expression.alias_column_names, +615 scope_type=scope_type, +616 ) +617 ): +618 yield child_scope 619 -620 for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): -621 prune = False -622 -623 yield node, parent, key -624 -625 if node is expression: -626 continue -627 elif isinstance(node, exp.CTE): -628 prune = True -629 elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): -630 prune = True -631 elif isinstance(node, exp.Subqueryable): -632 prune = True +620 # Tables without aliases will be set as "" +621 # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. +622 # Until then, this means that only a single, unaliased derived table is allowed (rather, +623 # the latest one wins. +624 alias = expression.alias +625 sources[alias] = child_scope +626 +627 # append the final child_scope yielded +628 scopes.append(child_scope) +629 scope.table_scopes.append(child_scope) +630 +631 scope.sources.update(sources) +632 +633 +634def _traverse_subqueries(scope): +635 for subquery in scope.subqueries: +636 top = None +637 for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): +638 yield child_scope +639 top = child_scope +640 scope.subquery_scopes.append(top) +641 +642 +643def walk_in_scope(expression, bfs=True): +644 """ +645 Returns a generator object which visits all nodes in the syntrax tree, stopping at +646 nodes that start child scopes. +647 +648 Args: +649 expression (exp.Expression): +650 bfs (bool): if set to True the BFS traversal order will be applied, +651 otherwise the DFS traversal will be used instead. +652 +653 Yields: +654 tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key +655 """ +656 # We'll use this variable to pass state into the dfs generator. +657 # Whenever we set it to True, we exclude a subtree from traversal. +658 prune = False +659 +660 for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune): +661 prune = False +662 +663 yield node, parent, key +664 +665 if node is expression: +666 continue +667 if ( +668 isinstance(node, exp.CTE) +669 or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) +670 or isinstance(node, exp.UDTF) +671 or isinstance(node, exp.Subqueryable) +672 ): +673 prune = True @@ -945,413 +989,438 @@ 27 SELECT * FROM x {"x": Table(this="x")} 28 SELECT * FROM x AS y {"y": Table(this="x")} 29 SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} - 30 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query - 31 defines a column list of it's alias of this scope, this is that list of columns. - 32 For example: - 33 SELECT * FROM (SELECT ...) AS y(col1, col2) - 34 The inner query would have `["col1", "col2"]` for its `outer_column_list` - 35 parent (Scope): Parent scope - 36 scope_type (ScopeType): Type of this scope, relative to it's parent - 37 subquery_scopes (list[Scope]): List of all child scopes for subqueries - 38 cte_scopes = (list[Scope]) List of all child scopes for CTEs - 39 derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables - 40 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be - 41 a list of the left and right child scopes. - 42 """ - 43 - 44 def __init__( - 45 self, - 46 expression, - 47 sources=None, - 48 outer_column_list=None, - 49 parent=None, - 50 scope_type=ScopeType.ROOT, - 51 ): - 52 self.expression = expression - 53 self.sources = sources or {} - 54 self.outer_column_list = outer_column_list or [] - 55 self.parent = parent - 56 self.scope_type = scope_type - 57 self.subquery_scopes = [] - 58 self.derived_table_scopes = [] - 59 self.cte_scopes = [] - 60 self.union_scopes = [] - 61 self.clear_cache() - 62 - 63 def clear_cache(self): - 64 self._collected = False - 65 self._raw_columns = None - 66 self._derived_tables = None - 67 self._tables = None - 68 self._ctes = None - 69 self._subqueries = None - 70 self._selected_sources = None - 71 self._columns = None - 72 self._external_columns = None - 73 self._join_hints = None - 74 - 75 def branch(self, expression, scope_type, chain_sources=None, **kwargs): - 76 """Branch from the current scope to a new, inner scope""" - 77 return Scope( - 78 expression=expression.unnest(), - 79 sources={**self.cte_sources, **(chain_sources or {})}, - 80 parent=self, - 81 scope_type=scope_type, - 82 **kwargs, - 83 ) - 84 - 85 def _collect(self): - 86 self._tables = [] - 87 self._ctes = [] - 88 self._subqueries = [] - 89 self._derived_tables = [] - 90 self._raw_columns = [] - 91 self._join_hints = [] - 92 - 93 for node, parent, _ in self.walk(bfs=False): - 94 if node is self.expression: - 95 continue - 96 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): - 97 self._raw_columns.append(node) - 98 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): - 99 self._tables.append(node) -100 elif isinstance(node, exp.JoinHint): -101 self._join_hints.append(node) -102 elif isinstance(node, exp.UDTF): -103 self._derived_tables.append(node) -104 elif isinstance(node, exp.CTE): -105 self._ctes.append(node) -106 elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): -107 self._derived_tables.append(node) -108 elif isinstance(node, exp.Subqueryable): -109 self._subqueries.append(node) -110 -111 self._collected = True -112 -113 def _ensure_collected(self): -114 if not self._collected: -115 self._collect() -116 -117 def walk(self, bfs=True): -118 return walk_in_scope(self.expression, bfs=bfs) -119 -120 def find(self, *expression_types, bfs=True): -121 """ -122 Returns the first node in this scope which matches at least one of the specified types. + 30 lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + 31 For example: + 32 SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + 33 The LATERAL VIEW EXPLODE gets x as a source. + 34 outer_column_list (list[str]): If this is a derived table or CTE, and the outer query + 35 defines a column list of it's alias of this scope, this is that list of columns. + 36 For example: + 37 SELECT * FROM (SELECT ...) AS y(col1, col2) + 38 The inner query would have `["col1", "col2"]` for its `outer_column_list` + 39 parent (Scope): Parent scope + 40 scope_type (ScopeType): Type of this scope, relative to it's parent + 41 subquery_scopes (list[Scope]): List of all child scopes for subqueries + 42 cte_scopes (list[Scope]): List of all child scopes for CTEs + 43 derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + 44 udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + 45 table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined + 46 union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be + 47 a list of the left and right child scopes. + 48 """ + 49 + 50 def __init__( + 51 self, + 52 expression, + 53 sources=None, + 54 outer_column_list=None, + 55 parent=None, + 56 scope_type=ScopeType.ROOT, + 57 lateral_sources=None, + 58 ): + 59 self.expression = expression + 60 self.sources = sources or {} + 61 self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + 62 self.sources.update(self.lateral_sources) + 63 self.outer_column_list = outer_column_list or [] + 64 self.parent = parent + 65 self.scope_type = scope_type + 66 self.subquery_scopes = [] + 67 self.derived_table_scopes = [] + 68 self.table_scopes = [] + 69 self.cte_scopes = [] + 70 self.union_scopes = [] + 71 self.udtf_scopes = [] + 72 self.clear_cache() + 73 + 74 def clear_cache(self): + 75 self._collected = False + 76 self._raw_columns = None + 77 self._derived_tables = None + 78 self._udtfs = None + 79 self._tables = None + 80 self._ctes = None + 81 self._subqueries = None + 82 self._selected_sources = None + 83 self._columns = None + 84 self._external_columns = None + 85 self._join_hints = None + 86 + 87 def branch(self, expression, scope_type, chain_sources=None, **kwargs): + 88 """Branch from the current scope to a new, inner scope""" + 89 return Scope( + 90 expression=expression.unnest(), + 91 sources={**self.cte_sources, **(chain_sources or {})}, + 92 parent=self, + 93 scope_type=scope_type, + 94 **kwargs, + 95 ) + 96 + 97 def _collect(self): + 98 self._tables = [] + 99 self._ctes = [] +100 self._subqueries = [] +101 self._derived_tables = [] +102 self._udtfs = [] +103 self._raw_columns = [] +104 self._join_hints = [] +105 +106 for node, parent, _ in self.walk(bfs=False): +107 if node is self.expression: +108 continue +109 elif isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): +110 self._raw_columns.append(node) +111 elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): +112 self._tables.append(node) +113 elif isinstance(node, exp.JoinHint): +114 self._join_hints.append(node) +115 elif isinstance(node, exp.UDTF): +116 self._udtfs.append(node) +117 elif isinstance(node, exp.CTE): +118 self._ctes.append(node) +119 elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): +120 self._derived_tables.append(node) +121 elif isinstance(node, exp.Subqueryable): +122 self._subqueries.append(node) 123 -124 This does NOT traverse into subscopes. +124 self._collected = True 125 -126 Args: -127 expression_types (type): the expression type(s) to match. -128 bfs (bool): True to use breadth-first search, False to use depth-first. +126 def _ensure_collected(self): +127 if not self._collected: +128 self._collect() 129 -130 Returns: -131 exp.Expression: the node which matches the criteria or None if no node matching -132 the criteria was found. -133 """ -134 return next(self.find_all(*expression_types, bfs=bfs), None) -135 -136 def find_all(self, *expression_types, bfs=True): -137 """ -138 Returns a generator object which visits all nodes in this scope and only yields those that -139 match at least one of the specified expression types. -140 -141 This does NOT traverse into subscopes. +130 def walk(self, bfs=True): +131 return walk_in_scope(self.expression, bfs=bfs) +132 +133 def find(self, *expression_types, bfs=True): +134 """ +135 Returns the first node in this scope which matches at least one of the specified types. +136 +137 This does NOT traverse into subscopes. +138 +139 Args: +140 expression_types (type): the expression type(s) to match. +141 bfs (bool): True to use breadth-first search, False to use depth-first. 142 -143 Args: -144 expression_types (type): the expression type(s) to match. -145 bfs (bool): True to use breadth-first search, False to use depth-first. -146 -147 Yields: -148 exp.Expression: nodes -149 """ -150 for expression, _, _ in self.walk(bfs=bfs): -151 if isinstance(expression, expression_types): -152 yield expression +143 Returns: +144 exp.Expression: the node which matches the criteria or None if no node matching +145 the criteria was found. +146 """ +147 return next(self.find_all(*expression_types, bfs=bfs), None) +148 +149 def find_all(self, *expression_types, bfs=True): +150 """ +151 Returns a generator object which visits all nodes in this scope and only yields those that +152 match at least one of the specified expression types. 153 -154 def replace(self, old, new): -155 """ -156 Replace `old` with `new`. -157 -158 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. +154 This does NOT traverse into subscopes. +155 +156 Args: +157 expression_types (type): the expression type(s) to match. +158 bfs (bool): True to use breadth-first search, False to use depth-first. 159 -160 Args: -161 old (exp.Expression): old node -162 new (exp.Expression): new node -163 """ -164 old.replace(new) -165 self.clear_cache() +160 Yields: +161 exp.Expression: nodes +162 """ +163 for expression, _, _ in self.walk(bfs=bfs): +164 if isinstance(expression, expression_types): +165 yield expression 166 -167 @property -168 def tables(self): -169 """ -170 List of tables in this scope. -171 -172 Returns: -173 list[exp.Table]: tables -174 """ -175 self._ensure_collected() -176 return self._tables -177 -178 @property -179 def ctes(self): -180 """ -181 List of CTEs in this scope. -182 -183 Returns: -184 list[exp.CTE]: ctes -185 """ -186 self._ensure_collected() -187 return self._ctes -188 -189 @property -190 def derived_tables(self): -191 """ -192 List of derived tables in this scope. -193 -194 For example: -195 SELECT * FROM (SELECT ...) <- that's a derived table -196 -197 Returns: -198 list[exp.Subquery]: derived tables -199 """ -200 self._ensure_collected() -201 return self._derived_tables -202 -203 @property -204 def subqueries(self): -205 """ -206 List of subqueries in this scope. -207 -208 For example: -209 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery -210 -211 Returns: -212 list[exp.Subqueryable]: subqueries -213 """ -214 self._ensure_collected() -215 return self._subqueries -216 -217 @property -218 def columns(self): -219 """ -220 List of columns in this scope. -221 -222 Returns: -223 list[exp.Column]: Column instances in this scope, plus any -224 Columns that reference this scope from correlated subqueries. -225 """ -226 if self._columns is None: -227 self._ensure_collected() -228 columns = self._raw_columns -229 -230 external_columns = [ -231 column for scope in self.subquery_scopes for column in scope.external_columns -232 ] -233 -234 named_selects = set(self.expression.named_selects) -235 -236 self._columns = [] -237 for column in columns + external_columns: -238 ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) -239 if ( -240 not ancestor -241 # Window functions can have an ORDER BY clause -242 or not isinstance(ancestor.parent, exp.Select) -243 or column.table -244 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) -245 ): -246 self._columns.append(column) -247 -248 return self._columns -249 -250 @property -251 def selected_sources(self): -252 """ -253 Mapping of nodes and sources that are actually selected from in this scope. -254 -255 That is, all tables in a schema are selectable at any point. But a -256 table only becomes a selected source if it's included in a FROM or JOIN clause. -257 -258 Returns: -259 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes -260 """ -261 if self._selected_sources is None: -262 referenced_names = [] -263 -264 for table in self.tables: -265 referenced_names.append((table.alias_or_name, table)) -266 for derived_table in self.derived_tables: -267 referenced_names.append((derived_table.alias, derived_table.unnest())) -268 -269 result = {} -270 -271 for name, node in referenced_names: -272 if name in self.sources: -273 result[name] = (node, self.sources[name]) -274 -275 self._selected_sources = result -276 return self._selected_sources -277 -278 @property -279 def cte_sources(self): -280 """ -281 Sources that are CTEs. -282 -283 Returns: -284 dict[str, Scope]: Mapping of source alias to Scope -285 """ -286 return { -287 alias: scope -288 for alias, scope in self.sources.items() -289 if isinstance(scope, Scope) and scope.is_cte -290 } -291 -292 @property -293 def selects(self): -294 """ -295 Select expressions of this scope. -296 -297 For example, for the following expression: -298 SELECT 1 as a, 2 as b FROM x +167 def replace(self, old, new): +168 """ +169 Replace `old` with `new`. +170 +171 This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. +172 +173 Args: +174 old (exp.Expression): old node +175 new (exp.Expression): new node +176 """ +177 old.replace(new) +178 self.clear_cache() +179 +180 @property +181 def tables(self): +182 """ +183 List of tables in this scope. +184 +185 Returns: +186 list[exp.Table]: tables +187 """ +188 self._ensure_collected() +189 return self._tables +190 +191 @property +192 def ctes(self): +193 """ +194 List of CTEs in this scope. +195 +196 Returns: +197 list[exp.CTE]: ctes +198 """ +199 self._ensure_collected() +200 return self._ctes +201 +202 @property +203 def derived_tables(self): +204 """ +205 List of derived tables in this scope. +206 +207 For example: +208 SELECT * FROM (SELECT ...) <- that's a derived table +209 +210 Returns: +211 list[exp.Subquery]: derived tables +212 """ +213 self._ensure_collected() +214 return self._derived_tables +215 +216 @property +217 def udtfs(self): +218 """ +219 List of "User Defined Tabular Functions" in this scope. +220 +221 Returns: +222 list[exp.UDTF]: UDTFs +223 """ +224 self._ensure_collected() +225 return self._udtfs +226 +227 @property +228 def subqueries(self): +229 """ +230 List of subqueries in this scope. +231 +232 For example: +233 SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery +234 +235 Returns: +236 list[exp.Subqueryable]: subqueries +237 """ +238 self._ensure_collected() +239 return self._subqueries +240 +241 @property +242 def columns(self): +243 """ +244 List of columns in this scope. +245 +246 Returns: +247 list[exp.Column]: Column instances in this scope, plus any +248 Columns that reference this scope from correlated subqueries. +249 """ +250 if self._columns is None: +251 self._ensure_collected() +252 columns = self._raw_columns +253 +254 external_columns = [ +255 column +256 for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) +257 for column in scope.external_columns +258 ] +259 +260 named_selects = set(self.expression.named_selects) +261 +262 self._columns = [] +263 for column in columns + external_columns: +264 ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) +265 if ( +266 not ancestor +267 # Window functions can have an ORDER BY clause +268 or not isinstance(ancestor.parent, exp.Select) +269 or column.table +270 or (column.name not in named_selects and not isinstance(ancestor, exp.Hint)) +271 ): +272 self._columns.append(column) +273 +274 return self._columns +275 +276 @property +277 def selected_sources(self): +278 """ +279 Mapping of nodes and sources that are actually selected from in this scope. +280 +281 That is, all tables in a schema are selectable at any point. But a +282 table only becomes a selected source if it's included in a FROM or JOIN clause. +283 +284 Returns: +285 dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes +286 """ +287 if self._selected_sources is None: +288 referenced_names = [] +289 +290 for table in self.tables: +291 referenced_names.append((table.alias_or_name, table)) +292 for expression in itertools.chain(self.derived_tables, self.udtfs): +293 referenced_names.append((expression.alias, expression.unnest())) +294 result = {} +295 +296 for name, node in referenced_names: +297 if name in self.sources: +298 result[name] = (node, self.sources[name]) 299 -300 The outputs are the "1 as a" and "2 as b" expressions. -301 -302 Returns: -303 list[exp.Expression]: expressions -304 """ -305 if isinstance(self.expression, exp.Union): -306 return self.expression.unnest().selects -307 return self.expression.selects -308 -309 @property -310 def external_columns(self): -311 """ -312 Columns that appear to reference sources in outer scopes. -313 -314 Returns: -315 list[exp.Column]: Column instances that don't reference -316 sources in the current scope. -317 """ -318 if self._external_columns is None: -319 self._external_columns = [ -320 c for c in self.columns if c.table not in self.selected_sources -321 ] -322 return self._external_columns -323 -324 @property -325 def unqualified_columns(self): -326 """ -327 Unqualified columns in the current scope. -328 -329 Returns: -330 list[exp.Column]: Unqualified columns -331 """ -332 return [c for c in self.columns if not c.table] +300 self._selected_sources = result +301 return self._selected_sources +302 +303 @property +304 def cte_sources(self): +305 """ +306 Sources that are CTEs. +307 +308 Returns: +309 dict[str, Scope]: Mapping of source alias to Scope +310 """ +311 return { +312 alias: scope +313 for alias, scope in self.sources.items() +314 if isinstance(scope, Scope) and scope.is_cte +315 } +316 +317 @property +318 def selects(self): +319 """ +320 Select expressions of this scope. +321 +322 For example, for the following expression: +323 SELECT 1 as a, 2 as b FROM x +324 +325 The outputs are the "1 as a" and "2 as b" expressions. +326 +327 Returns: +328 list[exp.Expression]: expressions +329 """ +330 if isinstance(self.expression, exp.Union): +331 return self.expression.unnest().selects +332 return self.expression.selects 333 334 @property -335 def join_hints(self): +335 def external_columns(self): 336 """ -337 Hints that exist in the scope that reference tables +337 Columns that appear to reference sources in outer scopes. 338 339 Returns: -340 list[exp.JoinHint]: Join hints that are referenced within the scope -341 """ -342 if self._join_hints is None: -343 return [] -344 return self._join_hints -345 -346 def source_columns(self, source_name): -347 """ -348 Get all columns in the current scope for a particular source. -349 -350 Args: -351 source_name (str): Name of the source -352 Returns: -353 list[exp.Column]: Column instances that reference `source_name` -354 """ -355 return [column for column in self.columns if column.table == source_name] -356 -357 @property -358 def is_subquery(self): -359 """Determine if this scope is a subquery""" -360 return self.scope_type == ScopeType.SUBQUERY -361 -362 @property -363 def is_derived_table(self): -364 """Determine if this scope is a derived table""" -365 return self.scope_type == ScopeType.DERIVED_TABLE -366 -367 @property -368 def is_union(self): -369 """Determine if this scope is a union""" -370 return self.scope_type == ScopeType.UNION -371 -372 @property -373 def is_cte(self): -374 """Determine if this scope is a common table expression""" -375 return self.scope_type == ScopeType.CTE -376 -377 @property -378 def is_root(self): -379 """Determine if this is the root scope""" -380 return self.scope_type == ScopeType.ROOT +340 list[exp.Column]: Column instances that don't reference +341 sources in the current scope. +342 """ +343 if self._external_columns is None: +344 self._external_columns = [ +345 c for c in self.columns if c.table not in self.selected_sources +346 ] +347 return self._external_columns +348 +349 @property +350 def unqualified_columns(self): +351 """ +352 Unqualified columns in the current scope. +353 +354 Returns: +355 list[exp.Column]: Unqualified columns +356 """ +357 return [c for c in self.columns if not c.table] +358 +359 @property +360 def join_hints(self): +361 """ +362 Hints that exist in the scope that reference tables +363 +364 Returns: +365 list[exp.JoinHint]: Join hints that are referenced within the scope +366 """ +367 if self._join_hints is None: +368 return [] +369 return self._join_hints +370 +371 def source_columns(self, source_name): +372 """ +373 Get all columns in the current scope for a particular source. +374 +375 Args: +376 source_name (str): Name of the source +377 Returns: +378 list[exp.Column]: Column instances that reference `source_name` +379 """ +380 return [column for column in self.columns if column.table == source_name] 381 382 @property -383 def is_udtf(self): -384 """Determine if this scope is a UDTF (User Defined Table Function)""" -385 return self.scope_type == ScopeType.UDTF +383 def is_subquery(self): +384 """Determine if this scope is a subquery""" +385 return self.scope_type == ScopeType.SUBQUERY 386 387 @property -388 def is_correlated_subquery(self): -389 """Determine if this scope is a correlated subquery""" -390 return bool(self.is_subquery and self.external_columns) +388 def is_derived_table(self): +389 """Determine if this scope is a derived table""" +390 return self.scope_type == ScopeType.DERIVED_TABLE 391 -392 def rename_source(self, old_name, new_name): -393 """Rename a source in this scope""" -394 columns = self.sources.pop(old_name or "", []) -395 self.sources[new_name] = columns +392 @property +393 def is_union(self): +394 """Determine if this scope is a union""" +395 return self.scope_type == ScopeType.UNION 396 -397 def add_source(self, name, source): -398 """Add a source to this scope""" -399 self.sources[name] = source -400 self.clear_cache() +397 @property +398 def is_cte(self): +399 """Determine if this scope is a common table expression""" +400 return self.scope_type == ScopeType.CTE 401 -402 def remove_source(self, name): -403 """Remove a source from this scope""" -404 self.sources.pop(name, None) -405 self.clear_cache() +402 @property +403 def is_root(self): +404 """Determine if this is the root scope""" +405 return self.scope_type == ScopeType.ROOT 406 -407 def __repr__(self): -408 return f"Scope<{self.expression.sql()}>" -409 -410 def traverse(self): -411 """ -412 Traverse the scope tree from this node. -413 -414 Yields: -415 Scope: scope instances in depth-first-search post-order -416 """ -417 for child_scope in itertools.chain( -418 self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes -419 ): -420 yield from child_scope.traverse() -421 yield self -422 -423 def ref_count(self): -424 """ -425 Count the number of times each scope in this tree is referenced. +407 @property +408 def is_udtf(self): +409 """Determine if this scope is a UDTF (User Defined Table Function)""" +410 return self.scope_type == ScopeType.UDTF +411 +412 @property +413 def is_correlated_subquery(self): +414 """Determine if this scope is a correlated subquery""" +415 return bool(self.is_subquery and self.external_columns) +416 +417 def rename_source(self, old_name, new_name): +418 """Rename a source in this scope""" +419 columns = self.sources.pop(old_name or "", []) +420 self.sources[new_name] = columns +421 +422 def add_source(self, name, source): +423 """Add a source to this scope""" +424 self.sources[name] = source +425 self.clear_cache() 426 -427 Returns: -428 dict[int, int]: Mapping of Scope instance ID to reference count -429 """ -430 scope_ref_count = defaultdict(lambda: 0) +427 def remove_source(self, name): +428 """Remove a source from this scope""" +429 self.sources.pop(name, None) +430 self.clear_cache() 431 -432 for scope in self.traverse(): -433 for _, source in scope.selected_sources.values(): -434 scope_ref_count[id(source)] += 1 -435 -436 return scope_ref_count +432 def __repr__(self): +433 return f"Scope<{self.expression.sql()}>" +434 +435 def traverse(self): +436 """ +437 Traverse the scope tree from this node. +438 +439 Yields: +440 Scope: scope instances in depth-first-search post-order +441 """ +442 for child_scope in itertools.chain( +443 self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes +444 ): +445 yield from child_scope.traverse() +446 yield self +447 +448 def ref_count(self): +449 """ +450 Count the number of times each scope in this tree is referenced. +451 +452 Returns: +453 dict[int, int]: Mapping of Scope instance ID to reference count +454 """ +455 scope_ref_count = defaultdict(lambda: 0) +456 +457 for scope in self.traverse(): +458 for _, source in scope.selected_sources.values(): +459 scope_ref_count[id(source)] += 1 +460 +461 return scope_ref_count @@ -1366,6 +1435,10 @@ a Table expression or another Scope instance. For example: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} +
  • lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals +For example: + SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; +The LATERAL VIEW EXPLODE gets x as a source.
  • outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -1374,8 +1447,10 @@ The inner query would have ["col1", "col2"] for its outer_col
  • parent (Scope): Parent scope
  • scope_type (ScopeType): Type of this scope, relative to it's parent
  • subquery_scopes (list[Scope]): List of all child scopes for subqueries
  • -
  • cte_scopes = (list[Scope]) List of all child scopes for CTEs
  • -
  • derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
  • +
  • cte_scopes (list[Scope]): List of all child scopes for CTEs
  • +
  • derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
  • +
  • udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
  • +
  • table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
  • union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes.
  • @@ -1386,30 +1461,35 @@ a list of the left and right child scopes.
    - Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>) + Scope( expression, sources=None, outer_column_list=None, parent=None, scope_type=<ScopeType.ROOT: 1>, lateral_sources=None)
    -
    44    def __init__(
    -45        self,
    -46        expression,
    -47        sources=None,
    -48        outer_column_list=None,
    -49        parent=None,
    -50        scope_type=ScopeType.ROOT,
    -51    ):
    -52        self.expression = expression
    -53        self.sources = sources or {}
    -54        self.outer_column_list = outer_column_list or []
    -55        self.parent = parent
    -56        self.scope_type = scope_type
    -57        self.subquery_scopes = []
    -58        self.derived_table_scopes = []
    -59        self.cte_scopes = []
    -60        self.union_scopes = []
    -61        self.clear_cache()
    +            
    50    def __init__(
    +51        self,
    +52        expression,
    +53        sources=None,
    +54        outer_column_list=None,
    +55        parent=None,
    +56        scope_type=ScopeType.ROOT,
    +57        lateral_sources=None,
    +58    ):
    +59        self.expression = expression
    +60        self.sources = sources or {}
    +61        self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
    +62        self.sources.update(self.lateral_sources)
    +63        self.outer_column_list = outer_column_list or []
    +64        self.parent = parent
    +65        self.scope_type = scope_type
    +66        self.subquery_scopes = []
    +67        self.derived_table_scopes = []
    +68        self.table_scopes = []
    +69        self.cte_scopes = []
    +70        self.union_scopes = []
    +71        self.udtf_scopes = []
    +72        self.clear_cache()
     
    @@ -1427,17 +1507,18 @@ a list of the left and right child scopes.
    -
    63    def clear_cache(self):
    -64        self._collected = False
    -65        self._raw_columns = None
    -66        self._derived_tables = None
    -67        self._tables = None
    -68        self._ctes = None
    -69        self._subqueries = None
    -70        self._selected_sources = None
    -71        self._columns = None
    -72        self._external_columns = None
    -73        self._join_hints = None
    +            
    74    def clear_cache(self):
    +75        self._collected = False
    +76        self._raw_columns = None
    +77        self._derived_tables = None
    +78        self._udtfs = None
    +79        self._tables = None
    +80        self._ctes = None
    +81        self._subqueries = None
    +82        self._selected_sources = None
    +83        self._columns = None
    +84        self._external_columns = None
    +85        self._join_hints = None
     
    @@ -1455,15 +1536,15 @@ a list of the left and right child scopes.
    -
    75    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    -76        """Branch from the current scope to a new, inner scope"""
    -77        return Scope(
    -78            expression=expression.unnest(),
    -79            sources={**self.cte_sources, **(chain_sources or {})},
    -80            parent=self,
    -81            scope_type=scope_type,
    -82            **kwargs,
    -83        )
    +            
    87    def branch(self, expression, scope_type, chain_sources=None, **kwargs):
    +88        """Branch from the current scope to a new, inner scope"""
    +89        return Scope(
    +90            expression=expression.unnest(),
    +91            sources={**self.cte_sources, **(chain_sources or {})},
    +92            parent=self,
    +93            scope_type=scope_type,
    +94            **kwargs,
    +95        )
     
    @@ -1483,8 +1564,8 @@ a list of the left and right child scopes.
    -
    117    def walk(self, bfs=True):
    -118        return walk_in_scope(self.expression, bfs=bfs)
    +            
    130    def walk(self, bfs=True):
    +131        return walk_in_scope(self.expression, bfs=bfs)
     
    @@ -1502,21 +1583,21 @@ a list of the left and right child scopes.
    -
    120    def find(self, *expression_types, bfs=True):
    -121        """
    -122        Returns the first node in this scope which matches at least one of the specified types.
    -123
    -124        This does NOT traverse into subscopes.
    -125
    -126        Args:
    -127            expression_types (type): the expression type(s) to match.
    -128            bfs (bool): True to use breadth-first search, False to use depth-first.
    -129
    -130        Returns:
    -131            exp.Expression: the node which matches the criteria or None if no node matching
    -132            the criteria was found.
    -133        """
    -134        return next(self.find_all(*expression_types, bfs=bfs), None)
    +            
    133    def find(self, *expression_types, bfs=True):
    +134        """
    +135        Returns the first node in this scope which matches at least one of the specified types.
    +136
    +137        This does NOT traverse into subscopes.
    +138
    +139        Args:
    +140            expression_types (type): the expression type(s) to match.
    +141            bfs (bool): True to use breadth-first search, False to use depth-first.
    +142
    +143        Returns:
    +144            exp.Expression: the node which matches the criteria or None if no node matching
    +145            the criteria was found.
    +146        """
    +147        return next(self.find_all(*expression_types, bfs=bfs), None)
     
    @@ -1552,23 +1633,23 @@ a list of the left and right child scopes.
    -
    136    def find_all(self, *expression_types, bfs=True):
    -137        """
    -138        Returns a generator object which visits all nodes in this scope and only yields those that
    -139        match at least one of the specified expression types.
    -140
    -141        This does NOT traverse into subscopes.
    -142
    -143        Args:
    -144            expression_types (type): the expression type(s) to match.
    -145            bfs (bool): True to use breadth-first search, False to use depth-first.
    -146
    -147        Yields:
    -148            exp.Expression: nodes
    -149        """
    -150        for expression, _, _ in self.walk(bfs=bfs):
    -151            if isinstance(expression, expression_types):
    -152                yield expression
    +            
    149    def find_all(self, *expression_types, bfs=True):
    +150        """
    +151        Returns a generator object which visits all nodes in this scope and only yields those that
    +152        match at least one of the specified expression types.
    +153
    +154        This does NOT traverse into subscopes.
    +155
    +156        Args:
    +157            expression_types (type): the expression type(s) to match.
    +158            bfs (bool): True to use breadth-first search, False to use depth-first.
    +159
    +160        Yields:
    +161            exp.Expression: nodes
    +162        """
    +163        for expression, _, _ in self.walk(bfs=bfs):
    +164            if isinstance(expression, expression_types):
    +165                yield expression
     
    @@ -1604,18 +1685,18 @@ match at least one of the specified expression types.

    -
    154    def replace(self, old, new):
    -155        """
    -156        Replace `old` with `new`.
    -157
    -158        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
    -159
    -160        Args:
    -161            old (exp.Expression): old node
    -162            new (exp.Expression): new node
    -163        """
    -164        old.replace(new)
    -165        self.clear_cache()
    +            
    167    def replace(self, old, new):
    +168        """
    +169        Replace `old` with `new`.
    +170
    +171        This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
    +172
    +173        Args:
    +174            old (exp.Expression): old node
    +175            new (exp.Expression): new node
    +176        """
    +177        old.replace(new)
    +178        self.clear_cache()
     
    @@ -1695,6 +1776,25 @@ match at least one of the specified expression types.

    + +
    +
    + udtfs + + +
    + + +

    List of "User Defined Tabular Functions" in this scope.

    + +
    Returns:
    + +
    +

    list[exp.UDTF]: UDTFs

    +
    +
    + +
    @@ -1875,16 +1975,16 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    346    def source_columns(self, source_name):
    -347        """
    -348        Get all columns in the current scope for a particular source.
    -349
    -350        Args:
    -351            source_name (str): Name of the source
    -352        Returns:
    -353            list[exp.Column]: Column instances that reference `source_name`
    -354        """
    -355        return [column for column in self.columns if column.table == source_name]
    +            
    371    def source_columns(self, source_name):
    +372        """
    +373        Get all columns in the current scope for a particular source.
    +374
    +375        Args:
    +376            source_name (str): Name of the source
    +377        Returns:
    +378            list[exp.Column]: Column instances that reference `source_name`
    +379        """
    +380        return [column for column in self.columns if column.table == source_name]
     
    @@ -2007,10 +2107,10 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    392    def rename_source(self, old_name, new_name):
    -393        """Rename a source in this scope"""
    -394        columns = self.sources.pop(old_name or "", [])
    -395        self.sources[new_name] = columns
    +            
    417    def rename_source(self, old_name, new_name):
    +418        """Rename a source in this scope"""
    +419        columns = self.sources.pop(old_name or "", [])
    +420        self.sources[new_name] = columns
     
    @@ -2030,10 +2130,10 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    397    def add_source(self, name, source):
    -398        """Add a source to this scope"""
    -399        self.sources[name] = source
    -400        self.clear_cache()
    +            
    422    def add_source(self, name, source):
    +423        """Add a source to this scope"""
    +424        self.sources[name] = source
    +425        self.clear_cache()
     
    @@ -2053,10 +2153,10 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    402    def remove_source(self, name):
    -403        """Remove a source from this scope"""
    -404        self.sources.pop(name, None)
    -405        self.clear_cache()
    +            
    427    def remove_source(self, name):
    +428        """Remove a source from this scope"""
    +429        self.sources.pop(name, None)
    +430        self.clear_cache()
     
    @@ -2076,18 +2176,18 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    410    def traverse(self):
    -411        """
    -412        Traverse the scope tree from this node.
    -413
    -414        Yields:
    -415            Scope: scope instances in depth-first-search post-order
    -416        """
    -417        for child_scope in itertools.chain(
    -418            self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
    -419        ):
    -420            yield from child_scope.traverse()
    -421        yield self
    +            
    435    def traverse(self):
    +436        """
    +437        Traverse the scope tree from this node.
    +438
    +439        Yields:
    +440            Scope: scope instances in depth-first-search post-order
    +441        """
    +442        for child_scope in itertools.chain(
    +443            self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
    +444        ):
    +445            yield from child_scope.traverse()
    +446        yield self
     
    @@ -2113,20 +2213,20 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    423    def ref_count(self):
    -424        """
    -425        Count the number of times each scope in this tree is referenced.
    -426
    -427        Returns:
    -428            dict[int, int]: Mapping of Scope instance ID to reference count
    -429        """
    -430        scope_ref_count = defaultdict(lambda: 0)
    -431
    -432        for scope in self.traverse():
    -433            for _, source in scope.selected_sources.values():
    -434                scope_ref_count[id(source)] += 1
    -435
    -436        return scope_ref_count
    +            
    448    def ref_count(self):
    +449        """
    +450        Count the number of times each scope in this tree is referenced.
    +451
    +452        Returns:
    +453            dict[int, int]: Mapping of Scope instance ID to reference count
    +454        """
    +455        scope_ref_count = defaultdict(lambda: 0)
    +456
    +457        for scope in self.traverse():
    +458            for _, source in scope.selected_sources.values():
    +459                scope_ref_count[id(source)] += 1
    +460
    +461        return scope_ref_count
     
    @@ -2153,32 +2253,32 @@ table only becomes a selected source if it's included in a FROM or JOIN clause.<
    -
    439def traverse_scope(expression):
    -440    """
    -441    Traverse an expression by it's "scopes".
    -442
    -443    "Scope" represents the current context of a Select statement.
    -444
    -445    This is helpful for optimizing queries, where we need more information than
    -446    the expression tree itself. For example, we might care about the source
    -447    names within a subquery. Returns a list because a generator could result in
    -448    incomplete properties which is confusing.
    -449
    -450    Examples:
    -451        >>> import sqlglot
    -452        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
    -453        >>> scopes = traverse_scope(expression)
    -454        >>> scopes[0].expression.sql(), list(scopes[0].sources)
    -455        ('SELECT a FROM x', ['x'])
    -456        >>> scopes[1].expression.sql(), list(scopes[1].sources)
    -457        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
    -458
    -459    Args:
    -460        expression (exp.Expression): expression to traverse
    -461    Returns:
    -462        list[Scope]: scope instances
    -463    """
    -464    return list(_traverse_scope(Scope(expression)))
    +            
    464def traverse_scope(expression):
    +465    """
    +466    Traverse an expression by it's "scopes".
    +467
    +468    "Scope" represents the current context of a Select statement.
    +469
    +470    This is helpful for optimizing queries, where we need more information than
    +471    the expression tree itself. For example, we might care about the source
    +472    names within a subquery. Returns a list because a generator could result in
    +473    incomplete properties which is confusing.
    +474
    +475    Examples:
    +476        >>> import sqlglot
    +477        >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
    +478        >>> scopes = traverse_scope(expression)
    +479        >>> scopes[0].expression.sql(), list(scopes[0].sources)
    +480        ('SELECT a FROM x', ['x'])
    +481        >>> scopes[1].expression.sql(), list(scopes[1].sources)
    +482        ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
    +483
    +484    Args:
    +485        expression (exp.Expression): expression to traverse
    +486    Returns:
    +487        list[Scope]: scope instances
    +488    """
    +489    return list(_traverse_scope(Scope(expression)))
     
    @@ -2232,16 +2332,16 @@ incomplete properties which is confusing.

    -
    467def build_scope(expression):
    -468    """
    -469    Build a scope tree.
    -470
    -471    Args:
    -472        expression (exp.Expression): expression to build the scope tree for
    -473    Returns:
    -474        Scope: root scope
    -475    """
    -476    return traverse_scope(expression)[-1]
    +            
    492def build_scope(expression):
    +493    """
    +494    Build a scope tree.
    +495
    +496    Args:
    +497        expression (exp.Expression): expression to build the scope tree for
    +498    Returns:
    +499        Scope: root scope
    +500    """
    +501    return traverse_scope(expression)[-1]
     
    @@ -2273,36 +2373,37 @@ incomplete properties which is confusing.

    -
    604def walk_in_scope(expression, bfs=True):
    -605    """
    -606    Returns a generator object which visits all nodes in the syntrax tree, stopping at
    -607    nodes that start child scopes.
    -608
    -609    Args:
    -610        expression (exp.Expression):
    -611        bfs (bool): if set to True the BFS traversal order will be applied,
    -612            otherwise the DFS traversal will be used instead.
    -613
    -614    Yields:
    -615        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    -616    """
    -617    # We'll use this variable to pass state into the dfs generator.
    -618    # Whenever we set it to True, we exclude a subtree from traversal.
    -619    prune = False
    -620
    -621    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
    -622        prune = False
    -623
    -624        yield node, parent, key
    -625
    -626        if node is expression:
    -627            continue
    -628        elif isinstance(node, exp.CTE):
    -629            prune = True
    -630        elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
    -631            prune = True
    -632        elif isinstance(node, exp.Subqueryable):
    -633            prune = True
    +            
    644def walk_in_scope(expression, bfs=True):
    +645    """
    +646    Returns a generator object which visits all nodes in the syntrax tree, stopping at
    +647    nodes that start child scopes.
    +648
    +649    Args:
    +650        expression (exp.Expression):
    +651        bfs (bool): if set to True the BFS traversal order will be applied,
    +652            otherwise the DFS traversal will be used instead.
    +653
    +654    Yields:
    +655        tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key
    +656    """
    +657    # We'll use this variable to pass state into the dfs generator.
    +658    # Whenever we set it to True, we exclude a subtree from traversal.
    +659    prune = False
    +660
    +661    for node, parent, key in expression.walk(bfs=bfs, prune=lambda *_: prune):
    +662        prune = False
    +663
    +664        yield node, parent, key
    +665
    +666        if node is expression:
    +667            continue
    +668        if (
    +669            isinstance(node, exp.CTE)
    +670            or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
    +671            or isinstance(node, exp.UDTF)
    +672            or isinstance(node, exp.Subqueryable)
    +673        ):
    +674            prune = True
     
    -- cgit v1.2.3