summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/transforms.py')
-rw-r--r-- sqlglot/transforms.py 40
1 files changed, 40 insertions, 0 deletions
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 412b881..99949a1 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -2,6 +2,8 @@ from __future__ import annotations
import typing as t
+from sqlglot.helper import find_new_name
+
if t.TYPE_CHECKING:
from sqlglot.generator import Generator
@@ -43,6 +45,43 @@ def unalias_group(expression: exp.Expression) -> exp.Expression:
return expression
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Rewrite a SELECT DISTINCT ON query as a filtered subquery that uses ROW_NUMBER().

    This targets dialects that lack SELECT DISTINCT ON but do offer window functions.
    A copy of the query, with its DISTINCT clause and ORDER BY removed, gains an extra
    ROW_NUMBER() projection partitioned by the DISTINCT ON expressions (and ordered by
    the original ORDER BY, when there was one). An outer query then selects the original
    projections from that subquery and keeps only the rows numbered 1.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The rewritten query when `expression` is a Select whose DISTINCT carries an ON
        tuple; otherwise `expression` itself.
    """
    if not isinstance(expression, exp.Select):
        return expression

    distinct = expression.args.get("distinct")
    if not distinct:
        return expression

    on = distinct.args.get("on")
    if not (on and isinstance(on, exp.Tuple)):
        return expression

    # Edit a copy rather than the expression that was passed in.
    inner = expression.copy()
    inner.args["distinct"].pop()

    # Pick a projection name that does not clash with the existing named selects.
    alias = find_new_name(expression.named_selects, "_row_number")

    numbered = exp.Window(
        this=exp.RowNumber(),
        partition_by=[col.copy() for col in on.expressions],
    )

    # Any ORDER BY on the query is moved into the window: the window receives a copy
    # and the clause is detached from the inner query.
    ordering = inner.args.get("order")
    if ordering:
        numbered.set("order", ordering.copy())
        ordering.pop()

    inner.select(exp.alias_(numbered, alias), copy=False)

    projections = [projection.copy() for projection in expression.expressions]
    return exp.select(*projections).from_(inner.subquery()).where(f'"{alias}" = 1')
+
+
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@@ -81,3 +120,4 @@ def delegate(attr: str) -> t.Callable:
# Maps exp.Group to the callable built by preprocess() from [unalias_group] and
# delegate("group_sql"). Presumably the transform runs before the generator's
# group_sql; confirm against preprocess().
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
# Maps exp.Select to the callable built by preprocess() from [eliminate_distinct_on]
# and delegate("select_sql"). Presumably the transform runs before the generator's
# select_sql; confirm against preprocess().
+ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}