diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 169 |
1 files changed, 105 insertions, 64 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 8565c64..335ff3e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -26,6 +26,10 @@ class Scope: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + For example: + SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + The LATERAL VIEW EXPLODE gets x as a source. outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -34,8 +38,10 @@ class Scope: parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent subquery_scopes (list[Scope]): List of all child scopes for subqueries - cte_scopes = (list[Scope]) List of all child scopes for CTEs - derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + cte_scopes (list[Scope]): List of all child scopes for CTEs + derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes. """ @@ -47,22 +53,28 @@ class Scope: outer_column_list=None, parent=None, scope_type=ScopeType.ROOT, + lateral_sources=None, ): self.expression = expression self.sources = sources or {} + self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.sources.update(self.lateral_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] self.derived_table_scopes = [] + self.table_scopes = [] self.cte_scopes = [] self.union_scopes = [] + self.udtf_scopes = [] self.clear_cache() def clear_cache(self): self._collected = False self._raw_columns = None self._derived_tables = None + self._udtfs = None self._tables = None self._ctes = None self._subqueries = None @@ -86,6 +98,7 @@ class Scope: self._ctes = [] self._subqueries = [] self._derived_tables = [] + self._udtfs = [] self._raw_columns = [] self._join_hints = [] @@ -99,7 +112,7 @@ class Scope: elif isinstance(node, exp.JoinHint): self._join_hints.append(node) elif isinstance(node, exp.UDTF): - self._derived_tables.append(node) + self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): @@ -200,6 +213,17 @@ class Scope: return self._derived_tables @property + def udtfs(self): + """ + List of "User Defined Tabular Functions" in this scope. + + Returns: + list[exp.UDTF]: UDTFs + """ + self._ensure_collected() + return self._udtfs + + @property def subqueries(self): """ List of subqueries in this scope. @@ -227,7 +251,9 @@ class Scope: columns = self._raw_columns external_columns = [ - column for scope in self.subquery_scopes for column in scope.external_columns + column + for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) + for column in scope.external_columns ] named_selects = set(self.expression.named_selects) @@ -262,9 +288,8 @@ class Scope: for table in self.tables: referenced_names.append((table.alias_or_name, table)) - for derived_table in self.derived_tables: - referenced_names.append((derived_table.alias, derived_table.unnest())) - + for expression in itertools.chain(self.derived_tables, self.udtfs): + referenced_names.append((expression.alias, expression.unnest())) result = {} for name, node in referenced_names: @@ -414,7 +439,7 @@ class Scope: Scope: scope instances in depth-first-search post-order """ for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes + self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes ): yield from child_scope.traverse() yield self @@ -480,24 +505,23 @@ def _traverse_scope(scope): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) - elif isinstance(scope.expression, exp.UDTF): - _set_udtf_scope(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.UDTF): + pass else: raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") yield scope def _traverse_select(scope): - yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) - yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) + yield from _traverse_ctes(scope) + yield from _traverse_tables(scope) yield from _traverse_subqueries(scope) - _add_table_sources(scope) def _traverse_union(scope): - yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) + yield from _traverse_ctes(scope) # The last scope to be yield should be the top most scope left = None @@ -511,82 +535,98 @@ def _traverse_union(scope): scope.union_scopes = [left, right] -def _set_udtf_scope(scope): - parent = scope.expression.parent - from_ = parent.args.get("from") - - if not from_: - return - - for table in from_.expressions: - if isinstance(table, exp.Table): - scope.tables.append(table) - elif isinstance(table, exp.Subquery): - scope.subqueries.append(table) - _add_table_sources(scope) - _traverse_subqueries(scope) - - -def _traverse_derived_tables(derived_tables, scope, scope_type): +def _traverse_ctes(scope): sources = {} - is_cte = scope_type == ScopeType.CTE - for derived_table in derived_tables: + for cte in scope.ctes: recursive_scope = None # if the scope is a recursive cte, it must be in the form of # base_case UNION recursive. thus the recursive scope is the first # section of the union. - if is_cte and scope.expression.args["with"].recursive: - union = derived_table.this + if scope.expression.args["with"].recursive: + union = cte.this if isinstance(union, exp.Union): recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) for child_scope in _traverse_scope( scope.branch( - derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, - chain_sources=sources if scope_type == ScopeType.CTE else None, - outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, + cte.this, + chain_sources=sources, + outer_column_list=cte.alias_column_names, + scope_type=ScopeType.CTE, ) ): yield child_scope - # Tables without aliases will be set as "" - # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. - # Until then, this means that only a single, unaliased derived table is allowed (rather, - # the latest one wins. - alias = derived_table.alias + alias = cte.alias sources[alias] = child_scope if recursive_scope: child_scope.add_source(alias, recursive_scope) # append the final child_scope yielded - if is_cte: - scope.cte_scopes.append(child_scope) - else: - scope.derived_table_scopes.append(child_scope) + scope.cte_scopes.append(child_scope) scope.sources.update(sources) -def _add_table_sources(scope): +def _traverse_tables(scope): sources = {} - for table in scope.tables: - table_name = table.name - if table.alias: - source_name = table.alias - else: - source_name = table_name + # Traverse FROMs, JOINs, and LATERALs in the order they are defined + expressions = [] + from_ = scope.expression.args.get("from") + if from_: + expressions.extend(from_.expressions) - if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - scope.sources[source_name] = scope.sources[table_name] + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) + + expressions.extend(scope.expression.args.get("laterals") or []) + + for expression in expressions: + if isinstance(expression, exp.Table): + table_name = expression.name + source_name = expression.alias_or_name + + if table_name in scope.sources: + # This is a reference to a parent source (e.g. a CTE), not an actual table. + sources[source_name] = scope.sources[table_name] + else: + sources[source_name] = expression + continue + + if isinstance(expression, exp.UDTF): + lateral_sources = sources + scope_type = ScopeType.UDTF + scopes = scope.udtf_scopes else: - sources[source_name] = table + lateral_sources = None + scope_type = ScopeType.DERIVED_TABLE + scopes = scope.derived_table_scopes + + for child_scope in _traverse_scope( + scope.branch( + expression, + lateral_sources=lateral_sources, + outer_column_list=expression.alias_column_names, + scope_type=scope_type, + ) + ): + yield child_scope + + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins. + alias = expression.alias + sources[alias] = child_scope + + # append the final child_scope yielded + scopes.append(child_scope) + scope.table_scopes.append(child_scope) scope.sources.update(sources) @@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True): if node is expression: continue - elif isinstance(node, exp.CTE): - prune = True - elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): - prune = True - elif isinstance(node, exp.Subqueryable): + if ( + isinstance(node, exp.CTE) + or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) + or isinstance(node, exp.UDTF) + or isinstance(node, exp.Subqueryable) + ): prune = True |