diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-19 13:45:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-02-19 13:45:09 +0000 |
commit | 639a208fa57ea674d165c4837e96f3ae4d7e3e61 (patch) | |
tree | f4d66da146c396d407cecefb5b405e609af1109e /sqlglot/optimizer | |
parent | Releasing debian version 11.0.1-1. (diff) | |
download | sqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.tar.xz sqlglot-639a208fa57ea674d165c4837e96f3ae4d7e3e61.zip |
Merging upstream version 11.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 3 | ||||
-rw-r--r-- | sqlglot/optimizer/eliminate_subqueries.py | 6 | ||||
-rw-r--r-- | sqlglot/optimizer/optimizer.py | 31 | ||||
-rw-r--r-- | sqlglot/optimizer/pushdown_projections.py | 76 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_columns.py | 48 | ||||
-rw-r--r-- | sqlglot/optimizer/qualify_tables.py | 2 | ||||
-rw-r--r-- | sqlglot/optimizer/scope.py | 169 |
7 files changed, 199 insertions, 136 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 66f97a9..be65ab9 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -280,6 +280,9 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: + if not col.table: + continue + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index c6bea5a..6f9db82 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -81,9 +81,7 @@ def eliminate_subqueries(expression): new_ctes.append(cte_scope.expression.parent) # Now append the rest - for scope in itertools.chain( - root.union_scopes, root.subquery_scopes, root.derived_table_scopes - ): + for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes): for child_scope in scope.traverse(): new_cte = _eliminate(child_scope, existing_ctes, taken) if new_cte: @@ -99,7 +97,7 @@ def _eliminate(scope, existing_ctes, taken): if scope.is_union: return _eliminate_union(scope, existing_ctes, taken) - if scope.is_derived_table and not isinstance(scope.expression, exp.UDTF): + if scope.is_derived_table: return _eliminate_derived_table(scope, existing_ctes, taken) if scope.is_cte: diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 96fd56b..d9d04be 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -1,4 +1,10 @@ +from __future__ import annotations + +import typing as t + import sqlglot +from sqlglot import Schema, exp +from sqlglot.dialects.dialect import DialectType from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.canonicalize import canonicalize from sqlglot.optimizer.eliminate_ctes import eliminate_ctes @@ -24,8 +30,8 @@ RULES = ( isolate_table_selects, qualify_columns, expand_laterals, - validate_qualify_columns, pushdown_projections, + validate_qualify_columns, normalize, unnest_subqueries, expand_multi_table_selects, @@ -40,22 +46,31 @@ RULES = ( ) -def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwargs): +def optimize( + expression: str | exp.Expression, + schema: t.Optional[dict | Schema] = None, + db: t.Optional[str] = None, + catalog: t.Optional[str] = None, + dialect: DialectType = None, + rules: t.Sequence[t.Callable] = RULES, + **kwargs, +): """ Rewrite a sqlglot AST into an optimized form. Args: - expression (sqlglot.Expression): expression to optimize - schema (dict|sqlglot.optimizer.Schema): database schema. + expression: expression to optimize + schema: database schema. This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of the following forms: 1. {table: {col: type}} 2. {db: {table: {col: type}}} 3. {catalog: {db: {table: {col: type}}}} If no schema is provided then the default schema defined at `sqlgot.schema` will be used - db (str): specify the default database, as might be set by a `USE DATABASE db` statement - catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (sequence): sequence of optimizer rules to use. + db: specify the default database, as might be set by a `USE DATABASE db` statement + catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement + dialect: The dialect to parse the sql string. + rules: sequence of optimizer rules to use. Many of the rules require tables and columns to be qualified. Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know what you're doing! @@ -65,7 +80,7 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar """ schema = ensure_schema(schema or sqlglot.schema) possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} - expression = expression.copy() + expression = exp.maybe_parse(expression, dialect=dialect, copy=True) for rule in rules: # Find any additional rule parameters, beyond `expression` rule_params = rule.__code__.co_varnames diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 54c5021..3f360f9 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -1,7 +1,10 @@ from collections import defaultdict from sqlglot import alias, exp +from sqlglot.helper import flatten +from sqlglot.optimizer.qualify_columns import Resolver from sqlglot.optimizer.scope import Scope, traverse_scope +from sqlglot.schema import ensure_schema # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() @@ -10,7 +13,7 @@ SELECT_ALL = object() DEFAULT_SELECTION = lambda: alias("1", "_") -def pushdown_projections(expression): +def pushdown_projections(expression, schema=None): """ Rewrite sqlglot AST to remove unused columns projections. @@ -27,9 +30,9 @@ def pushdown_projections(expression): sqlglot.Expression: optimized expression """ # Map of Scope to all columns being selected by outer queries. + schema = ensure_schema(schema) referenced_columns = defaultdict(set) - left_union = None - right_union = None + # We build the scope tree (which is traversed in DFS postorder), then iterate # over the result in reverse order. This should ensure that the set of selected # columns for a particular scope are completely build by the time we get to it. @@ -41,16 +44,20 @@ def pushdown_projections(expression): parent_selections = {SELECT_ALL} if isinstance(scope.expression, exp.Union): - left_union, right_union = scope.union_scopes - referenced_columns[left_union] = parent_selections - referenced_columns[right_union] = parent_selections + left, right = scope.union_scopes + referenced_columns[left] = parent_selections + + if any(select.is_star for select in right.selects): + referenced_columns[right] = parent_selections + elif not any(select.is_star for select in left.selects): + referenced_columns[right] = [ + right.selects[i].alias_or_name + for i, select in enumerate(left.selects) + if SELECT_ALL in parent_selections or select.alias_or_name in parent_selections + ] - if isinstance(scope.expression, exp.Select) and scope != right_union: - removed_indexes = _remove_unused_selections(scope, parent_selections) - # The left union is used for column names to select and if we remove columns from the left - # we need to also remove those same columns in the right that were at the same position - if scope is left_union: - _remove_indexed_selections(right_union, removed_indexes) + if isinstance(scope.expression, exp.Select): + _remove_unused_selections(scope, parent_selections, schema) # Group columns by source name selects = defaultdict(set) @@ -68,8 +75,7 @@ def pushdown_projections(expression): return expression -def _remove_unused_selections(scope, parent_selections): - removed_indexes = [] +def _remove_unused_selections(scope, parent_selections, schema): order = scope.expression.args.get("order") if order: @@ -78,33 +84,33 @@ def _remove_unused_selections(scope, parent_selections): else: order_refs = set() - new_selections = [] + new_selections = defaultdict(list) removed = False - for i, selection in enumerate(scope.selects): - if ( - SELECT_ALL in parent_selections - or selection.alias_or_name in parent_selections - or selection.alias_or_name in order_refs - ): - new_selections.append(selection) + star = False + for selection in scope.selects: + name = selection.alias_or_name + + if SELECT_ALL in parent_selections or name in parent_selections or name in order_refs: + new_selections[name].append(selection) else: - removed_indexes.append(i) + if selection.is_star: + star = True removed = True + if star: + resolver = Resolver(scope, schema) + + for name in sorted(parent_selections): + if name not in new_selections: + new_selections[name].append( + alias(exp.column(name, table=resolver.get_table(name)), name) + ) + # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION()) + new_selections[""].append(DEFAULT_SELECTION()) + + scope.expression.select(*flatten(new_selections.values()), append=False, copy=False) - scope.expression.set("expressions", new_selections) if removed: scope.clear_cache() - return removed_indexes - - -def _remove_indexed_selections(scope, indexes_to_remove): - new_selections = [ - selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove - ] - if not new_selections: - new_selections.append(DEFAULT_SELECTION()) - scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index ab13d01..a7bd9b5 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -27,17 +27,16 @@ def qualify_columns(expression, schema): schema = ensure_schema(schema) for scope in traverse_scope(expression): - resolver = _Resolver(scope, schema) + resolver = Resolver(scope, schema) _pop_table_column_aliases(scope.ctes) _pop_table_column_aliases(scope.derived_tables) _expand_using(scope, resolver) - _expand_group_by(scope, resolver) _qualify_columns(scope, resolver) - _expand_order_by(scope) if not isinstance(scope.expression, exp.UDTF): _expand_stars(scope, resolver) _qualify_outputs(scope) - + _expand_group_by(scope, resolver) + _expand_order_by(scope) return expression @@ -48,7 +47,8 @@ def validate_qualify_columns(expression): if isinstance(scope.expression, exp.Select): unqualified_columns.extend(scope.unqualified_columns) if scope.external_columns and not scope.is_correlated_subquery: - raise OptimizeError(f"Unknown table: {scope.external_columns[0].table}") + column = scope.external_columns[0] + raise OptimizeError(f"Unknown table: '{column.table}' for column '{column}'") if unqualified_columns: raise OptimizeError(f"Ambiguous columns: {unqualified_columns}") @@ -62,8 +62,6 @@ def _pop_table_column_aliases(derived_tables): (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) """ for derived_table in derived_tables: - if isinstance(derived_table.unnest(), exp.UDTF): - continue table_alias = derived_table.args.get("alias") if table_alias: table_alias.args.pop("columns", None) @@ -206,7 +204,7 @@ def _qualify_columns(scope, resolver): if column_table and column_table in scope.sources: source_columns = resolver.get_source_columns(column_table) - if source_columns and column_name not in source_columns: + if source_columns and column_name not in source_columns and "*" not in source_columns: raise OptimizeError(f"Unknown column: {column_name}") if not column_table: @@ -256,7 +254,7 @@ def _expand_stars(scope, resolver): tables = list(scope.selected_sources) _add_except_columns(expression, tables, except_columns) _add_replace_columns(expression, tables, replace_columns) - elif isinstance(expression, exp.Column) and isinstance(expression.this, exp.Star): + elif expression.is_star: tables = [expression.table] _add_except_columns(expression.this, tables, except_columns) _add_replace_columns(expression.this, tables, replace_columns) @@ -268,17 +266,16 @@ def _expand_stars(scope, resolver): if table not in scope.sources: raise OptimizeError(f"Unknown table: {table}") columns = resolver.get_source_columns(table, only_visible=True) - if not columns: - raise OptimizeError( - f"Table has no schema/columns. Cannot expand star for table: {table}." - ) - table_id = id(table) - for name in columns: - if name not in except_columns.get(table_id, set()): - alias_ = replace_columns.get(table_id, {}).get(name, name) - column = exp.column(name, table) - new_selections.append(alias(column, alias_) if alias_ != name else column) + if columns and "*" not in columns: + table_id = id(table) + for name in columns: + if name not in except_columns.get(table_id, set()): + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table) + new_selections.append(alias(column, alias_) if alias_ != name else column) + else: + return scope.expression.set("expressions", new_selections) @@ -316,7 +313,7 @@ def _qualify_outputs(scope): if isinstance(selection, exp.Subquery): if not selection.output_name: selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) - elif not isinstance(selection, exp.Alias): + elif not isinstance(selection, exp.Alias) and not selection.is_star: alias_ = alias(exp.column(""), alias=selection.output_name or f"_col_{i}") alias_.set("this", selection) selection = alias_ @@ -329,7 +326,7 @@ def _qualify_outputs(scope): scope.expression.set("expressions", new_selections) -class _Resolver: +class Resolver: """ Helper for resolving columns. @@ -361,7 +358,9 @@ class _Resolver: if not table: sources_without_schema = tuple( - source for source, columns in self._get_all_source_columns().items() if not columns + source + for source, columns in self._get_all_source_columns().items() + if not columns or "*" in columns ) if len(sources_without_schema) == 1: return sources_without_schema[0] @@ -397,7 +396,8 @@ class _Resolver: def _get_all_source_columns(self): if self._source_columns is None: self._source_columns = { - k: self.get_source_columns(k) for k in self.scope.selected_sources + k: self.get_source_columns(k) + for k in itertools.chain(self.scope.selected_sources, self.scope.lateral_sources) } return self._source_columns @@ -436,7 +436,7 @@ class _Resolver: Find the unique columns in a list of columns. Example: - >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"])) + >>> sorted(Resolver._find_unique_columns(["a", "b", "b", "c"])) ['a', 'c'] This is necessary because duplicate column names are ambiguous. diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py index 65593bd..6e50182 100644 --- a/sqlglot/optimizer/qualify_tables.py +++ b/sqlglot/optimizer/qualify_tables.py @@ -28,7 +28,7 @@ def qualify_tables(expression, db=None, catalog=None, schema=None): next_name = lambda: f"_q_{next(sequence)}" for scope in traverse_scope(expression): - for derived_table in scope.ctes + scope.derived_tables: + for derived_table in itertools.chain(scope.ctes, scope.derived_tables): if not derived_table.args.get("alias"): alias_ = f"_q_{next(sequence)}" derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 8565c64..335ff3e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -26,6 +26,10 @@ class Scope: SELECT * FROM x {"x": Table(this="x")} SELECT * FROM x AS y {"y": Table(this="x")} SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals + For example: + SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; + The LATERAL VIEW EXPLODE gets x as a source. outer_column_list (list[str]): If this is a derived table or CTE, and the outer query defines a column list of it's alias of this scope, this is that list of columns. For example: @@ -34,8 +38,10 @@ class Scope: parent (Scope): Parent scope scope_type (ScopeType): Type of this scope, relative to it's parent subquery_scopes (list[Scope]): List of all child scopes for subqueries - cte_scopes = (list[Scope]) List of all child scopes for CTEs - derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables + cte_scopes (list[Scope]): List of all child scopes for CTEs + derived_table_scopes (list[Scope]): List of all child scopes for derived_tables + udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions + table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be a list of the left and right child scopes. """ @@ -47,22 +53,28 @@ class Scope: outer_column_list=None, parent=None, scope_type=ScopeType.ROOT, + lateral_sources=None, ): self.expression = expression self.sources = sources or {} + self.lateral_sources = lateral_sources.copy() if lateral_sources else {} + self.sources.update(self.lateral_sources) self.outer_column_list = outer_column_list or [] self.parent = parent self.scope_type = scope_type self.subquery_scopes = [] self.derived_table_scopes = [] + self.table_scopes = [] self.cte_scopes = [] self.union_scopes = [] + self.udtf_scopes = [] self.clear_cache() def clear_cache(self): self._collected = False self._raw_columns = None self._derived_tables = None + self._udtfs = None self._tables = None self._ctes = None self._subqueries = None @@ -86,6 +98,7 @@ class Scope: self._ctes = [] self._subqueries = [] self._derived_tables = [] + self._udtfs = [] self._raw_columns = [] self._join_hints = [] @@ -99,7 +112,7 @@ class Scope: elif isinstance(node, exp.JoinHint): self._join_hints.append(node) elif isinstance(node, exp.UDTF): - self._derived_tables.append(node) + self._udtfs.append(node) elif isinstance(node, exp.CTE): self._ctes.append(node) elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): @@ -200,6 +213,17 @@ class Scope: return self._derived_tables @property + def udtfs(self): + """ + List of "User Defined Tabular Functions" in this scope. + + Returns: + list[exp.UDTF]: UDTFs + """ + self._ensure_collected() + return self._udtfs + + @property def subqueries(self): """ List of subqueries in this scope. @@ -227,7 +251,9 @@ class Scope: columns = self._raw_columns external_columns = [ - column for scope in self.subquery_scopes for column in scope.external_columns + column + for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes) + for column in scope.external_columns ] named_selects = set(self.expression.named_selects) @@ -262,9 +288,8 @@ class Scope: for table in self.tables: referenced_names.append((table.alias_or_name, table)) - for derived_table in self.derived_tables: - referenced_names.append((derived_table.alias, derived_table.unnest())) - + for expression in itertools.chain(self.derived_tables, self.udtfs): + referenced_names.append((expression.alias, expression.unnest())) result = {} for name, node in referenced_names: @@ -414,7 +439,7 @@ class Scope: Scope: scope instances in depth-first-search post-order """ for child_scope in itertools.chain( - self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes + self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes ): yield from child_scope.traverse() yield self @@ -480,24 +505,23 @@ def _traverse_scope(scope): yield from _traverse_select(scope) elif isinstance(scope.expression, exp.Union): yield from _traverse_union(scope) - elif isinstance(scope.expression, exp.UDTF): - _set_udtf_scope(scope) elif isinstance(scope.expression, exp.Subquery): yield from _traverse_subqueries(scope) + elif isinstance(scope.expression, exp.UDTF): + pass else: raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") yield scope def _traverse_select(scope): - yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) - yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE) + yield from _traverse_ctes(scope) + yield from _traverse_tables(scope) yield from _traverse_subqueries(scope) - _add_table_sources(scope) def _traverse_union(scope): - yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) + yield from _traverse_ctes(scope) # The last scope to be yield should be the top most scope left = None @@ -511,82 +535,98 @@ def _traverse_union(scope): scope.union_scopes = [left, right] -def _set_udtf_scope(scope): - parent = scope.expression.parent - from_ = parent.args.get("from") - - if not from_: - return - - for table in from_.expressions: - if isinstance(table, exp.Table): - scope.tables.append(table) - elif isinstance(table, exp.Subquery): - scope.subqueries.append(table) - _add_table_sources(scope) - _traverse_subqueries(scope) - - -def _traverse_derived_tables(derived_tables, scope, scope_type): +def _traverse_ctes(scope): sources = {} - is_cte = scope_type == ScopeType.CTE - for derived_table in derived_tables: + for cte in scope.ctes: recursive_scope = None # if the scope is a recursive cte, it must be in the form of # base_case UNION recursive. thus the recursive scope is the first # section of the union. - if is_cte and scope.expression.args["with"].recursive: - union = derived_table.this + if scope.expression.args["with"].recursive: + union = cte.this if isinstance(union, exp.Union): recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) for child_scope in _traverse_scope( scope.branch( - derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, - chain_sources=sources if scope_type == ScopeType.CTE else None, - outer_column_list=derived_table.alias_column_names, - scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type, + cte.this, + chain_sources=sources, + outer_column_list=cte.alias_column_names, + scope_type=ScopeType.CTE, ) ): yield child_scope - # Tables without aliases will be set as "" - # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. - # Until then, this means that only a single, unaliased derived table is allowed (rather, - # the latest one wins. - alias = derived_table.alias + alias = cte.alias sources[alias] = child_scope if recursive_scope: child_scope.add_source(alias, recursive_scope) # append the final child_scope yielded - if is_cte: - scope.cte_scopes.append(child_scope) - else: - scope.derived_table_scopes.append(child_scope) + scope.cte_scopes.append(child_scope) scope.sources.update(sources) -def _add_table_sources(scope): +def _traverse_tables(scope): sources = {} - for table in scope.tables: - table_name = table.name - if table.alias: - source_name = table.alias - else: - source_name = table_name + # Traverse FROMs, JOINs, and LATERALs in the order they are defined + expressions = [] + from_ = scope.expression.args.get("from") + if from_: + expressions.extend(from_.expressions) - if table_name in scope.sources: - # This is a reference to a parent source (e.g. a CTE), not an actual table. - scope.sources[source_name] = scope.sources[table_name] + for join in scope.expression.args.get("joins") or []: + expressions.append(join.this) + + expressions.extend(scope.expression.args.get("laterals") or []) + + for expression in expressions: + if isinstance(expression, exp.Table): + table_name = expression.name + source_name = expression.alias_or_name + + if table_name in scope.sources: + # This is a reference to a parent source (e.g. a CTE), not an actual table. + sources[source_name] = scope.sources[table_name] + else: + sources[source_name] = expression + continue + + if isinstance(expression, exp.UDTF): + lateral_sources = sources + scope_type = ScopeType.UDTF + scopes = scope.udtf_scopes else: - sources[source_name] = table + lateral_sources = None + scope_type = ScopeType.DERIVED_TABLE + scopes = scope.derived_table_scopes + + for child_scope in _traverse_scope( + scope.branch( + expression, + lateral_sources=lateral_sources, + outer_column_list=expression.alias_column_names, + scope_type=scope_type, + ) + ): + yield child_scope + + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins. + alias = expression.alias + sources[alias] = child_scope + + # append the final child_scope yielded + scopes.append(child_scope) + scope.table_scopes.append(child_scope) scope.sources.update(sources) @@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True): if node is expression: continue - elif isinstance(node, exp.CTE): - prune = True - elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)): - prune = True - elif isinstance(node, exp.Subqueryable): + if ( + isinstance(node, exp.CTE) + or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join))) + or isinstance(node, exp.UDTF) + or isinstance(node, exp.Subqueryable) + ): prune = True |