summaryrefslogtreecommitdiffstats
path: root/sqlglot/dataframe/sql/window.py
blob: c54c07e21c410ae265a069d72e5675cdd914e835 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations

import sys
import typing as t

from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.helper import flatten

if t.TYPE_CHECKING:
    from sqlglot.dataframe.sql._typing import ColumnOrName


class Window:
    """PySpark-compatible entry point for building window specifications.

    Every classmethod starts from a brand-new, empty ``WindowSpec`` and applies a
    single builder call to it; e.g. ``Window.orderBy("x")`` is shorthand for
    ``WindowSpec().orderBy("x")``.
    """

    # Range of a signed 64-bit integer (a Java ``long``).
    _JAVA_MIN_LONG = -(2**63)  # -9223372036854775808
    _JAVA_MAX_LONG = 2**63 - 1  # 9223372036854775807

    # The Java bounds clamped to the platform's ``sys.maxsize``.
    # NOTE(review): presumably meant for treating ``±sys.maxsize`` as unbounded
    # frame boundaries (as PySpark does) — not referenced within this class.
    _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
    _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)

    # Sentinel frame-boundary values accepted by rowsBetween/rangeBetween.
    unboundedPreceding: int = _JAVA_MIN_LONG
    unboundedFollowing: int = _JAVA_MAX_LONG
    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new ``WindowSpec`` partitioned by ``cols``."""
        empty_spec = WindowSpec()
        return empty_spec.partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new ``WindowSpec`` ordered by ``cols``."""
        empty_spec = WindowSpec()
        return empty_spec.orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        """Return a new ``WindowSpec`` with a ROWS frame from ``start`` to ``end``."""
        empty_spec = WindowSpec()
        return empty_spec.rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        """Return a new ``WindowSpec`` with a RANGE frame from ``start`` to ``end``."""
        empty_spec = WindowSpec()
        return empty_spec.rangeBetween(start, end)


class WindowSpec:
    """Builder around a sqlglot ``exp.Window`` node, mirroring PySpark's ``WindowSpec``.

    Every builder method works on ``self.copy()`` and returns that new spec, so
    calls can be chained without touching the receiver.
    """

    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Build a fresh ``exp.Window`` per instance. A default of ``exp.Window()`` in
        # the signature would be evaluated once and shared by every default spec.
        self.expression = expression if expression is not None else exp.Window()

    def copy(self) -> WindowSpec:
        """Return a new ``WindowSpec`` wrapping ``self.expression.copy()``."""
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        """Render the window expression as SQL, using the ``"spark"`` dialect by default.

        Keyword arguments are forwarded to ``Expression.sql``. An explicit
        ``dialect=...`` overrides the default instead of raising ``TypeError``.
        """
        return self.expression.sql(**{"dialect": "spark", **kwargs})

    @staticmethod
    def _to_expressions(
        cols: t.Tuple[t.Union[ColumnOrName, t.List[ColumnOrName]], ...]
    ) -> t.List[exp.Expression]:
        """Normalize ``*cols`` (names/Columns, or a leading list/set/tuple of them) to expressions."""
        from sqlglot.dataframe.sql.column import Column

        # Accept both ``f("a", "b")`` and ``f(["a", "b"])``. An empty call yields no
        # expressions instead of raising IndexError on ``cols[0]``.
        if cols and isinstance(cols[0], (list, set, tuple)):
            cols = flatten(cols)  # type: ignore
        return [Column.ensure_col(x).expression for x in cols]

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new spec whose PARTITION BY list is extended with ``cols``."""
        window_spec = self.copy()
        # ``or []`` also covers the key being present with a ``None`` value.
        existing = window_spec.expression.args.get("partition_by") or []
        window_spec.expression.set("partition_by", [*existing, *self._to_expressions(cols)])
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new spec whose ORDER BY list is extended with ``cols``."""
        window_spec = self.copy()
        order = window_spec.expression.args.get("order")
        if order is None:
            order = exp.Order(expressions=[])
            window_spec.expression.set("order", order)
        order.set("expressions", [*order.expressions, *self._to_expressions(cols)])
        return window_spec

    @staticmethod
    def _frame_bound(value: int) -> t.Tuple[t.Union[str, exp.Expression], t.Optional[str]]:
        """Translate a PySpark-style integer frame boundary into ``(bound, side)``.

        ``0`` (``Window.currentRow``) is the current row. Negative offsets precede it
        and positive offsets follow it. Values at or beyond ``Window``'s preceding or
        following thresholds (clamped to ``±sys.maxsize``) are unbounded.
        """
        if value == Window.currentRow:
            return "CURRENT ROW", None
        if value <= Window._PRECEDING_THRESHOLD:
            return "UNBOUNDED", "PRECEDING"
        if value >= Window._FOLLOWING_THRESHOLD:
            return "UNBOUNDED", "FOLLOWING"
        if value < 0:
            # SQL states the offset as a positive count: -2 -> "2 PRECEDING".
            return F.lit(-value).expression, "PRECEDING"
        return F.lit(value).expression, "FOLLOWING"

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        """Return the ``start``/``start_side``/``end``/``end_side`` args for a frame."""
        start_bound, start_side = self._frame_bound(start)
        end_bound, end_side = self._frame_bound(end)
        # Both sides are always present (possibly None). Merging into an existing
        # spec therefore clears a side left over from a previous frame.
        return {
            "start": start_bound,
            "start_side": start_side,
            "end": end_bound,
            "end_side": end_side,
        }

    def _with_frame(self, kind: str, start: int, end: int) -> WindowSpec:
        """Return a copy of this spec with a ``kind`` frame from ``start`` to ``end``."""
        window_spec = self.copy()
        frame: t.Dict[str, t.Any] = self._calc_start_end(start, end)
        frame["kind"] = kind
        current = window_spec.expression.args.get("spec")
        # Keep any other args of an existing spec; the new frame args take precedence.
        merged = {**(current.args if current is not None else {}), **frame}
        window_spec.expression.set("spec", exp.WindowSpec(**merged))
        return window_spec

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        """Return a new spec with a ``ROWS BETWEEN`` frame from ``start`` to ``end``."""
        return self._with_frame("ROWS", start, end)

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        """Return a new spec with a ``RANGE BETWEEN`` frame from ``start`` to ``end``."""
        return self._with_frame("RANGE", start, end)