summaryrefslogtreecommitdiffstats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r--sqlglot/transforms.py44
1 files changed, 30 insertions, 14 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 3643cd7..a1ec1bd 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -3,7 +3,7 @@ from __future__ import annotations
import typing as t
from sqlglot import expressions as exp
-from sqlglot.helper import find_new_name
+from sqlglot.helper import find_new_name, name_sequence
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
distinct_cols = expression.args["distinct"].pop().args["on"].expressions
outer_selects = expression.selects
row_number = find_new_name(expression.named_selects, "_row_number")
- window = exp.Window(
- this=exp.RowNumber(),
- partition_by=distinct_cols,
- )
+ window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
order = expression.args.get("order")
+
if order:
window.set("order", order.pop().copy())
+
window = exp.alias_(window, row_number)
expression.select(window, copy=False)
+
return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1')
+
return expression
@@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for select in expression.selects:
if not select.alias_or_name:
alias = find_new_name(taken, "_c")
- select.replace(exp.alias_(select.copy(), alias))
+ select.replace(exp.alias_(select, alias))
taken.add(alias)
outer_selects = exp.select(*[select.alias_or_name for select in expression.selects])
@@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
for expr in qualify_filters.find_all((exp.Window, exp.Column)):
if isinstance(expr, exp.Window):
alias = find_new_name(expression.named_selects, "_w")
- expression.select(exp.alias_(expr.copy(), alias), copy=False)
+ expression.select(exp.alias_(expr, alias), copy=False)
column = exp.column(alias)
+
if isinstance(expr.parent, exp.Qualify):
qualify_filters = column
else:
@@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr
"""
for node in expression.find_all(exp.DataType):
node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)])
+
return expression
@@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
),
)
+
return expression
@@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import build_scope
taken_select_names = set(expression.named_selects)
- taken_source_names = set(build_scope(expression).selected_sources)
+ scope = build_scope(expression)
+ if not scope:
+ return expression
+ taken_source_names = set(scope.selected_sources)
for select in expression.selects:
to_replace = select
@@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
else node,
copy=False,
)
+
return expression
@@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre
return expression
-def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression:
- if isinstance(expression, exp.Pivot):
- expression.args["field"].transform(
- lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node,
- copy=False,
- )
+def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
+ if isinstance(expression, exp.With) and expression.recursive:
+ next_name = name_sequence("_c_")
+
+ for cte in expression.expressions:
+ if not cte.args["alias"].columns:
+ query = cte.this
+ if isinstance(query, exp.Union):
+ query = query.this
+
+ cte.args["alias"].set(
+ "columns",
+ [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
+ )
return expression