summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r--sqlglot/optimizer/scope.py169
1 files changed, 105 insertions, 64 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 8565c64..335ff3e 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -26,6 +26,10 @@ class Scope:
SELECT * FROM x {"x": Table(this="x")}
SELECT * FROM x AS y {"y": Table(this="x")}
SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
+ lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals
+ For example:
+ SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
+ The LATERAL VIEW EXPLODE gets x as a source.
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
For example:
@@ -34,8 +38,10 @@ class Scope:
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to it's parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
- cte_scopes = (list[Scope]) List of all child scopes for CTEs
- derived_table_scopes = (list[Scope]) List of all child scopes for derived_tables
+ cte_scopes (list[Scope]): List of all child scopes for CTEs
+ derived_table_scopes (list[Scope]): List of all child scopes for derived_tables
+ udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions
+ table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined
union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be
a list of the left and right child scopes.
"""
@@ -47,22 +53,28 @@ class Scope:
outer_column_list=None,
parent=None,
scope_type=ScopeType.ROOT,
+ lateral_sources=None,
):
self.expression = expression
self.sources = sources or {}
+ self.lateral_sources = lateral_sources.copy() if lateral_sources else {}
+ self.sources.update(self.lateral_sources)
self.outer_column_list = outer_column_list or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
self.derived_table_scopes = []
+ self.table_scopes = []
self.cte_scopes = []
self.union_scopes = []
+ self.udtf_scopes = []
self.clear_cache()
def clear_cache(self):
self._collected = False
self._raw_columns = None
self._derived_tables = None
+ self._udtfs = None
self._tables = None
self._ctes = None
self._subqueries = None
@@ -86,6 +98,7 @@ class Scope:
self._ctes = []
self._subqueries = []
self._derived_tables = []
+ self._udtfs = []
self._raw_columns = []
self._join_hints = []
@@ -99,7 +112,7 @@ class Scope:
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
elif isinstance(node, exp.UDTF):
- self._derived_tables.append(node)
+ self._udtfs.append(node)
elif isinstance(node, exp.CTE):
self._ctes.append(node)
elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
@@ -200,6 +213,17 @@ class Scope:
return self._derived_tables
@property
+ def udtfs(self):
+ """
+ List of "User Defined Tabular Functions" in this scope.
+
+ Returns:
+ list[exp.UDTF]: UDTFs
+ """
+ self._ensure_collected()
+ return self._udtfs
+
+ @property
def subqueries(self):
"""
List of subqueries in this scope.
@@ -227,7 +251,9 @@ class Scope:
columns = self._raw_columns
external_columns = [
- column for scope in self.subquery_scopes for column in scope.external_columns
+ column
+ for scope in itertools.chain(self.subquery_scopes, self.udtf_scopes)
+ for column in scope.external_columns
]
named_selects = set(self.expression.named_selects)
@@ -262,9 +288,8 @@ class Scope:
for table in self.tables:
referenced_names.append((table.alias_or_name, table))
- for derived_table in self.derived_tables:
- referenced_names.append((derived_table.alias, derived_table.unnest()))
-
+ for expression in itertools.chain(self.derived_tables, self.udtfs):
+ referenced_names.append((expression.alias, expression.unnest()))
result = {}
for name, node in referenced_names:
@@ -414,7 +439,7 @@ class Scope:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
- self.cte_scopes, self.union_scopes, self.derived_table_scopes, self.subquery_scopes
+ self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
@@ -480,24 +505,23 @@ def _traverse_scope(scope):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_union(scope)
- elif isinstance(scope.expression, exp.UDTF):
- _set_udtf_scope(scope)
elif isinstance(scope.expression, exp.Subquery):
yield from _traverse_subqueries(scope)
+ elif isinstance(scope.expression, exp.UDTF):
+ pass
else:
raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
yield scope
def _traverse_select(scope):
- yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
- yield from _traverse_derived_tables(scope.derived_tables, scope, ScopeType.DERIVED_TABLE)
+ yield from _traverse_ctes(scope)
+ yield from _traverse_tables(scope)
yield from _traverse_subqueries(scope)
- _add_table_sources(scope)
def _traverse_union(scope):
- yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+ yield from _traverse_ctes(scope)
# The last scope to be yield should be the top most scope
left = None
@@ -511,82 +535,98 @@ def _traverse_union(scope):
scope.union_scopes = [left, right]
-def _set_udtf_scope(scope):
- parent = scope.expression.parent
- from_ = parent.args.get("from")
-
- if not from_:
- return
-
- for table in from_.expressions:
- if isinstance(table, exp.Table):
- scope.tables.append(table)
- elif isinstance(table, exp.Subquery):
- scope.subqueries.append(table)
- _add_table_sources(scope)
- _traverse_subqueries(scope)
-
-
-def _traverse_derived_tables(derived_tables, scope, scope_type):
+def _traverse_ctes(scope):
sources = {}
- is_cte = scope_type == ScopeType.CTE
- for derived_table in derived_tables:
+ for cte in scope.ctes:
recursive_scope = None
# if the scope is a recursive cte, it must be in the form of
# base_case UNION recursive. thus the recursive scope is the first
# section of the union.
- if is_cte and scope.expression.args["with"].recursive:
- union = derived_table.this
+ if scope.expression.args["with"].recursive:
+ union = cte.this
if isinstance(union, exp.Union):
recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE)
for child_scope in _traverse_scope(
scope.branch(
- derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this,
- chain_sources=sources if scope_type == ScopeType.CTE else None,
- outer_column_list=derived_table.alias_column_names,
- scope_type=ScopeType.UDTF if isinstance(derived_table, exp.UDTF) else scope_type,
+ cte.this,
+ chain_sources=sources,
+ outer_column_list=cte.alias_column_names,
+ scope_type=ScopeType.CTE,
)
):
yield child_scope
- # Tables without aliases will be set as ""
- # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
- # Until then, this means that only a single, unaliased derived table is allowed (rather,
- # the latest one wins.
- alias = derived_table.alias
+ alias = cte.alias
sources[alias] = child_scope
if recursive_scope:
child_scope.add_source(alias, recursive_scope)
# append the final child_scope yielded
- if is_cte:
- scope.cte_scopes.append(child_scope)
- else:
- scope.derived_table_scopes.append(child_scope)
+ scope.cte_scopes.append(child_scope)
scope.sources.update(sources)
-def _add_table_sources(scope):
+def _traverse_tables(scope):
sources = {}
- for table in scope.tables:
- table_name = table.name
- if table.alias:
- source_name = table.alias
- else:
- source_name = table_name
+ # Traverse FROMs, JOINs, and LATERALs in the order they are defined
+ expressions = []
+ from_ = scope.expression.args.get("from")
+ if from_:
+ expressions.extend(from_.expressions)
- if table_name in scope.sources:
- # This is a reference to a parent source (e.g. a CTE), not an actual table.
- scope.sources[source_name] = scope.sources[table_name]
+ for join in scope.expression.args.get("joins") or []:
+ expressions.append(join.this)
+
+ expressions.extend(scope.expression.args.get("laterals") or [])
+
+ for expression in expressions:
+ if isinstance(expression, exp.Table):
+ table_name = expression.name
+ source_name = expression.alias_or_name
+
+ if table_name in scope.sources:
+ # This is a reference to a parent source (e.g. a CTE), not an actual table.
+ sources[source_name] = scope.sources[table_name]
+ else:
+ sources[source_name] = expression
+ continue
+
+ if isinstance(expression, exp.UDTF):
+ lateral_sources = sources
+ scope_type = ScopeType.UDTF
+ scopes = scope.udtf_scopes
else:
- sources[source_name] = table
+ lateral_sources = None
+ scope_type = ScopeType.DERIVED_TABLE
+ scopes = scope.derived_table_scopes
+
+ for child_scope in _traverse_scope(
+ scope.branch(
+ expression,
+ lateral_sources=lateral_sources,
+ outer_column_list=expression.alias_column_names,
+ scope_type=scope_type,
+ )
+ ):
+ yield child_scope
+
+ # Tables without aliases will be set as ""
+ # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything.
+ # Until then, this means that only a single, unaliased derived table is allowed (rather,
+ # the latest one wins.
+ alias = expression.alias
+ sources[alias] = child_scope
+
+ # append the final child_scope yielded
+ scopes.append(child_scope)
+ scope.table_scopes.append(child_scope)
scope.sources.update(sources)
@@ -624,9 +664,10 @@ def walk_in_scope(expression, bfs=True):
if node is expression:
continue
- elif isinstance(node, exp.CTE):
- prune = True
- elif isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)):
- prune = True
- elif isinstance(node, exp.Subqueryable):
+ if (
+ isinstance(node, exp.CTE)
+ or (isinstance(node, exp.Subquery) and isinstance(parent, (exp.From, exp.Join)))
+ or isinstance(node, exp.UDTF)
+ or isinstance(node, exp.Subqueryable)
+ ):
prune = True