diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/transforms.py | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 35ff75a..aa7d240 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -27,20 +27,18 @@ def unalias_group(expression: exp.Expression) -> exp.Expression: """ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): aliased_selects = { - e.alias: (i, e.this) + e.alias: i for i, e in enumerate(expression.parent.expressions, start=1) if isinstance(e, exp.Alias) } - expression = expression.copy() - - top_level_expression = None - for item, parent, _ in expression.walk(bfs=False): - top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression - if isinstance(item, exp.Column) and not item.table: - alias_index, col_expression = aliased_selects.get(item.name, (None, None)) - if alias_index and top_level_expression != col_expression: - item.replace(exp.Literal.number(alias_index)) + for group_by in expression.expressions: + if ( + isinstance(group_by, exp.Column) + and not group_by.table + and group_by.name in aliased_selects + ): + group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) return expression @@ -63,22 +61,21 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: and expression.args["distinct"].args.get("on") and isinstance(expression.args["distinct"].args["on"], exp.Tuple) ): - distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions] - outer_selects = [e.copy() for e in expression.expressions] - nested = expression.copy() - nested.args["distinct"].pop() + distinct_cols = expression.args["distinct"].args["on"].expressions + expression.args["distinct"].pop() + outer_selects = expression.selects row_number = find_new_name(expression.named_selects, "_row_number") window = exp.Window( this=exp.RowNumber(), partition_by=distinct_cols, ) - order = nested.args.get("order") + order = expression.args.get("order") if order: window.set("order", order.copy()) order.pop() window = exp.alias_(window, row_number) - nested.select(window, copy=False) - return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1') + expression.select(window, copy=False) + return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') return expression @@ -120,7 +117,7 @@ def preprocess( """ def _to_sql(self, expression): - expression = transforms[0](expression) + expression = transforms[0](expression.copy()) for t in transforms[1:]: expression = t(expression) return to_sql(self, expression) |