summaryrefslogtreecommitdiffstats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r--sqlglot/transforms.py33
1 files changed, 15 insertions, 18 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 35ff75a..aa7d240 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -27,20 +27,18 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
- e.alias: (i, e.this)
+ e.alias: i
for i, e in enumerate(expression.parent.expressions, start=1)
if isinstance(e, exp.Alias)
}
- expression = expression.copy()
-
- top_level_expression = None
- for item, parent, _ in expression.walk(bfs=False):
- top_level_expression = item if isinstance(parent, exp.Group) else top_level_expression
- if isinstance(item, exp.Column) and not item.table:
- alias_index, col_expression = aliased_selects.get(item.name, (None, None))
- if alias_index and top_level_expression != col_expression:
- item.replace(exp.Literal.number(alias_index))
+ for group_by in expression.expressions:
+ if (
+ isinstance(group_by, exp.Column)
+ and not group_by.table
+ and group_by.name in aliased_selects
+ ):
+ group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))
return expression
@@ -63,22 +61,21 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
and expression.args["distinct"].args.get("on")
and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
):
- distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
- outer_selects = [e.copy() for e in expression.expressions]
- nested = expression.copy()
- nested.args["distinct"].pop()
+ distinct_cols = expression.args["distinct"].args["on"].expressions
+ expression.args["distinct"].pop()
+ outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
window = exp.Window(
this=exp.RowNumber(),
partition_by=distinct_cols,
)
- order = nested.args.get("order")
+ order = expression.args.get("order")
if order:
window.set("order", order.copy())
order.pop()
window = exp.alias_(window, row_number)
- nested.select(window, copy=False)
- return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
+ expression.select(window, copy=False)
+ return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
return expression
@@ -120,7 +117,7 @@ def preprocess(
"""
def _to_sql(self, expression):
- expression = transforms[0](expression)
+ expression = transforms[0](expression.copy())
for t in transforms[1:]:
expression = t(expression)
return to_sql(self, expression)