diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:46 +0000 |
commit | 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch) | |
tree | c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/transforms.py | |
parent | Releasing debian version 12.2.0-1. (diff) | |
download | sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip |
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- | sqlglot/transforms.py | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 3643cd7..a1ec1bd 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing as t from sqlglot import expressions as exp -from sqlglot.helper import find_new_name +from sqlglot.helper import find_new_name, name_sequence if t.TYPE_CHECKING: from sqlglot.generator import Generator @@ -63,16 +63,17 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: distinct_cols = expression.args["distinct"].pop().args["on"].expressions outer_selects = expression.selects row_number = find_new_name(expression.named_selects, "_row_number") - window = exp.Window( - this=exp.RowNumber(), - partition_by=distinct_cols, - ) + window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) order = expression.args.get("order") + if order: window.set("order", order.pop().copy()) + window = exp.alias_(window, row_number) expression.select(window, copy=False) + return exp.select(*outer_selects).from_(expression.subquery()).where(f'"{row_number}" = 1') + return expression @@ -93,7 +94,7 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: for select in expression.selects: if not select.alias_or_name: alias = find_new_name(taken, "_c") - select.replace(exp.alias_(select.copy(), alias)) + select.replace(exp.alias_(select, alias)) taken.add(alias) outer_selects = exp.select(*[select.alias_or_name for select in expression.selects]) @@ -102,8 +103,9 @@ def eliminate_qualify(expression: exp.Expression) -> exp.Expression: for expr in qualify_filters.find_all((exp.Window, exp.Column)): if isinstance(expr, exp.Window): alias = find_new_name(expression.named_selects, "_w") - expression.select(exp.alias_(expr.copy(), alias), copy=False) + expression.select(exp.alias_(expr, alias), copy=False) column = exp.column(alias) + if isinstance(expr.parent, exp.Qualify): qualify_filters = column else: @@ -123,6 +125,7 @@ def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expr """ for node in expression.find_all(exp.DataType): node.set("expressions", [e for e in node.expressions if isinstance(e, exp.DataType)]) + return expression @@ -147,6 +150,7 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore ), ) + return expression @@ -156,7 +160,10 @@ def explode_to_unnest(expression: exp.Expression) -> exp.Expression: from sqlglot.optimizer.scope import build_scope taken_select_names = set(expression.named_selects) - taken_source_names = set(build_scope(expression).selected_sources) + scope = build_scope(expression) + if not scope: + return expression + taken_source_names = set(scope.selected_sources) for select in expression.selects: to_replace = select @@ -226,6 +233,7 @@ def remove_target_from_merge(expression: exp.Expression) -> exp.Expression: else node, copy=False, ) + return expression @@ -242,12 +250,20 @@ def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expre return expression -def unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Pivot): - expression.args["field"].transform( - lambda node: exp.column(node.output_name) if isinstance(node, exp.Column) else node, - copy=False, - ) +def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.With) and expression.recursive: + next_name = name_sequence("_c_") + + for cte in expression.expressions: + if not cte.args["alias"].columns: + query = cte.this + if isinstance(query, exp.Union): + query = query.this + + cte.args["alias"].set( + "columns", + [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], + ) return expression |