diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/transforms.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 412b881..99949a1 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import typing as t
 
+from sqlglot.helper import find_new_name
+
 if t.TYPE_CHECKING:
     from sqlglot.generator import Generator
 
@@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
+def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
+    """
+    Convert SELECT DISTINCT ON statements to a subquery with a window function.
+
+    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
+
+    Args:
+        expression: the expression that will be transformed.
+
+    Returns:
+        The transformed expression.
+    """
+    if (
+        isinstance(expression, exp.Select)
+        and expression.args.get("distinct")
+        and expression.args["distinct"].args.get("on")
+        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
+    ):
+        distinct_cols = [e.copy() for e in expression.args["distinct"].args["on"].expressions]
+        outer_selects = [e.copy() for e in expression.expressions]
+        nested = expression.copy()
+        nested.args["distinct"].pop()
+        row_number = find_new_name(expression.named_selects, "_row_number")
+        window = exp.Window(
+            this=exp.RowNumber(),
+            partition_by=distinct_cols,
+        )
+        order = nested.args.get("order")
+        if order:
+            window.set("order", order.copy())
+            order.pop()
+        window = exp.alias_(window, row_number)
+        nested.select(window, copy=False)
+        return exp.select(*outer_selects).from_(nested.subquery()).where(f'"{row_number}" = 1')
+    return expression
+
+
 def preprocess(
     transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
     to_sql: t.Callable[[Generator, exp.Expression], str],
@@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
 
 
 UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}