diff options
Diffstat (limited to 'sqlglot/optimizer/unnest_subqueries.py')
-rw-r--r-- | sqlglot/optimizer/unnest_subqueries.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index a515489..09e3f2a 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -1,6 +1,5 @@ -import itertools - from sqlglot import exp +from sqlglot.helper import name_sequence from sqlglot.optimizer.scope import ScopeType, traverse_scope @@ -22,7 +21,7 @@ def unnest_subqueries(expression): Returns: sqlglot.Expression: unnested expression """ - sequence = itertools.count() + next_alias_name = name_sequence("_u_") for scope in traverse_scope(expression): select = scope.expression @@ -30,19 +29,19 @@ def unnest_subqueries(expression): if not parent: continue if scope.external_columns: - decorrelate(select, parent, scope.external_columns, sequence) + decorrelate(select, parent, scope.external_columns, next_alias_name) elif scope.scope_type == ScopeType.SUBQUERY: - unnest(select, parent, sequence) + unnest(select, parent, next_alias_name) return expression -def unnest(select, parent_select, sequence): +def unnest(select, parent_select, next_alias_name): if len(select.selects) > 1: return predicate = select.find_ancestor(exp.Condition) - alias = _alias(sequence) + alias = next_alias_name() if not predicate or parent_select is not predicate.parent_select: return @@ -87,13 +86,13 @@ def unnest(select, parent_select, sequence): ) -def decorrelate(select, parent_select, external_columns, sequence): +def decorrelate(select, parent_select, external_columns, next_alias_name): where = select.args.get("where") if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): return - table_alias = _alias(sequence) + table_alias = next_alias_name() keys = [] # for all external columns in the where statement, find the relevant predicate @@ -136,7 +135,7 @@ def decorrelate(select, parent_select, external_columns, sequence): group_by.append(key) else: if key not in key_aliases: - key_aliases[key] = _alias(sequence) + key_aliases[key] = next_alias_name() # all predicates that are equalities must also be in the unique # so that we don't do a many to many join if isinstance(predicate, exp.EQ) and key not in group_by: @@ -244,10 +243,6 @@ def decorrelate(select, parent_select, external_columns, sequence): ) -def _alias(sequence): - return f"_u_{next(sequence)}" - - def _replace(expression, condition): return expression.replace(exp.condition(condition)) |