diff options
Diffstat (limited to 'sqlglot/optimizer/scope.py')
-rw-r--r-- | sqlglot/optimizer/scope.py | 32 |
1 files changed, 21 insertions, 11 deletions
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b2b4230..a7dab35 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -124,8 +124,8 @@ class Scope: self._ctes.append(node) elif ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ): self._derived_tables.append(node) elif isinstance(node, exp.Subqueryable): @@ -610,13 +610,13 @@ def _traverse_ctes(scope): scope.sources.update(sources) -def _is_subquery_scope(expression: exp.Subquery) -> bool: +def _is_derived_table(expression: exp.Subquery) -> bool: """ - We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a new scope. - If an alias is present, it shadows all names under the Subquery, so that's an - exception to this rule. + We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", + as it doesn't introduce a new scope. If an alias is present, it shadows all names + under the Subquery, so that's one exception to this rule. """ - return bool(not isinstance(expression.unnest(), exp.Table) or expression.alias) + return bool(expression.alias or isinstance(expression.this, exp.Subqueryable)) def _traverse_tables(scope): @@ -654,7 +654,10 @@ def _traverse_tables(scope): else: sources[source_name] = expression - expressions.extend(join.this for join in expression.args.get("joins") or []) + # Make sure to not include the joins twice + if expression is not scope.expression: + expressions.extend(join.this for join in expression.args.get("joins") or []) + continue if not isinstance(expression, exp.DerivedTable): @@ -664,10 +667,11 @@ def _traverse_tables(scope): lateral_sources = sources scope_type = ScopeType.UDTF scopes = scope.udtf_scopes - elif _is_subquery_scope(expression): + elif _is_derived_table(expression): lateral_sources = None scope_type = ScopeType.DERIVED_TABLE scopes = scope.derived_table_scopes + expressions.extend(join.this for join in expression.args.get("joins") or []) else: # Makes sure we check for possible sources in nested table constructs expressions.append(expression.this) @@ -735,10 +739,16 @@ def walk_in_scope(expression, bfs=True): isinstance(node, exp.CTE) or ( isinstance(node, exp.Subquery) - and isinstance(parent, (exp.From, exp.Join)) - and _is_subquery_scope(node) + and isinstance(parent, (exp.From, exp.Join, exp.Subquery)) + and _is_derived_table(node) ) or isinstance(node, exp.UDTF) or isinstance(node, exp.Subqueryable) ): prune = True + + if isinstance(node, (exp.Subquery, exp.UDTF)): + # The following args are not actually in the inner scope, so we should visit them + for key in ("joins", "laterals", "pivots"): + for arg in node.args.get(key) or []: + yield from walk_in_scope(arg, bfs=bfs) |