1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
from __future__ import annotations
import sys
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.helper import flatten
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrName
class Window:
    """Entry points mirroring ``pyspark.sql.Window``.

    Every classmethod starts from a brand-new, empty ``WindowSpec`` and applies
    exactly one builder call to it. The public integer attributes are the
    frame-boundary sentinels callers pass to ``rowsBetween``/``rangeBetween``.
    """

    # Signed 64-bit limits (the Java ``long`` range).
    _JAVA_MIN_LONG = -(2**63)  # -9223372036854775808
    _JAVA_MAX_LONG = 2**63 - 1  # 9223372036854775807

    # The Java limits clamped by ``sys.maxsize``. These are presumably the
    # cut-offs for treating an offset as unbounded (as in PySpark); they are
    # not referenced inside this class.
    _PRECEDING_THRESHOLD = max(_JAVA_MIN_LONG, -sys.maxsize)
    _FOLLOWING_THRESHOLD = min(_JAVA_MAX_LONG, sys.maxsize)

    # Frame-boundary sentinels.
    unboundedPreceding: int = _JAVA_MIN_LONG
    unboundedFollowing: int = _JAVA_MAX_LONG
    currentRow: int = 0

    @classmethod
    def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new ``WindowSpec`` partitioned by ``cols``."""
        empty = WindowSpec()
        return empty.partitionBy(*cols)

    @classmethod
    def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a new ``WindowSpec`` ordered by ``cols``."""
        empty = WindowSpec()
        return empty.orderBy(*cols)

    @classmethod
    def rowsBetween(cls, start: int, end: int) -> WindowSpec:
        """Return a new ``WindowSpec`` with a ROWS frame from ``start`` to ``end``."""
        empty = WindowSpec()
        return empty.rowsBetween(start, end)

    @classmethod
    def rangeBetween(cls, start: int, end: int) -> WindowSpec:
        """Return a new ``WindowSpec`` with a RANGE frame from ``start`` to ``end``."""
        empty = WindowSpec()
        return empty.rangeBetween(start, end)
class WindowSpec:
    """Builder for a Spark ``OVER (...)`` window clause.

    Each builder method works on ``self.copy()`` and returns that new
    ``WindowSpec``, so calls can be chained fluently.
    """

    def __init__(self, expression: t.Optional[exp.Expression] = None):
        # Build a fresh ``exp.Window`` per instance. A default evaluated once at
        # definition time would be a single object shared by every WindowSpec().
        self.expression = exp.Window() if expression is None else expression

    def copy(self) -> WindowSpec:
        """Return a new ``WindowSpec`` wrapping a copy of this window expression."""
        return WindowSpec(self.expression.copy())

    def sql(self, **kwargs) -> str:
        """Render the window expression with the Spark dialect.

        ``kwargs`` are forwarded unchanged to ``Expression.sql``.
        """
        return self.expression.sql(dialect="spark", **kwargs)

    @staticmethod
    def _expressions_for(cols) -> t.List[exp.Expression]:
        """Convert column names/Columns into expressions; a single list/set/tuple is flattened."""
        from sqlglot.dataframe.sql.column import Column

        # Accept both f("a", "b") and f(["a", "b"]). Check ``cols`` first so a call
        # with no columns yields an empty list instead of an IndexError.
        if cols and isinstance(cols[0], (list, set, tuple)):
            cols = flatten(cols)  # type: ignore
        return [Column.ensure_col(col).expression for col in cols]

    def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a copy of this spec with ``cols`` appended to its PARTITION BY list."""
        expressions = self._expressions_for(cols)
        window_spec = self.copy()
        # ``or []`` also covers a key that is present but set to None.
        partition_by = window_spec.expression.args.get("partition_by") or []
        partition_by.extend(expressions)
        window_spec.expression.set("partition_by", partition_by)
        return window_spec

    def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
        """Return a copy of this spec with ``cols`` appended to its ORDER BY list."""
        expressions = self._expressions_for(cols)
        window_spec = self.copy()
        if window_spec.expression.args.get("order") is None:
            window_spec.expression.set("order", exp.Order(expressions=[]))
        order = window_spec.expression.args["order"]
        order_by = order.expressions
        order_by.extend(expressions)
        order.set("expressions", order_by)
        return window_spec

    def _calc_start_end(
        self, start: int, end: int
    ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
        """Translate integer frame bounds into ``exp.WindowSpec`` keyword arguments.

        ``Window.currentRow`` maps to ``CURRENT ROW``. A start at or below
        ``Window._PRECEDING_THRESHOLD`` maps to ``UNBOUNDED PRECEDING``, and an
        end at or above ``Window._FOLLOWING_THRESHOLD`` maps to
        ``UNBOUNDED FOLLOWING``. The thresholds make the ``-sys.maxsize`` /
        ``sys.maxsize`` idiom count as unbounded, as well as the exact sentinels.
        Any other value is emitted as a literal with the PRECEDING (start) or
        FOLLOWING (end) side.
        """
        # NOTE(review): non-sentinel offsets are emitted verbatim, so start=-3
        # renders as "-3 PRECEDING", whereas PySpark treats a negative start as
        # "3 PRECEDING". Left unchanged for backward compatibility; confirm intent.
        kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
            "start_side": None,
            "end_side": None,
        }
        if start == Window.currentRow:
            kwargs["start"] = "CURRENT ROW"
        else:
            kwargs["start_side"] = "PRECEDING"
            kwargs["start"] = (
                "UNBOUNDED"
                if start <= Window._PRECEDING_THRESHOLD
                else F.lit(start).expression
            )
        if end == Window.currentRow:
            kwargs["end"] = "CURRENT ROW"
        else:
            kwargs["end_side"] = "FOLLOWING"
            kwargs["end"] = (
                "UNBOUNDED"
                if end >= Window._FOLLOWING_THRESHOLD
                else F.lit(end).expression
            )
        return kwargs

    def _with_frame(self, kind: str, start: int, end: int) -> WindowSpec:
        """Return a copy of this spec whose frame is ``kind`` BETWEEN ``start`` AND ``end``.

        Frame keys computed here override the same keys of any existing spec;
        other existing spec args are carried over.
        """
        window_spec = self.copy()
        frame = self._calc_start_end(start, end)
        frame["kind"] = kind
        # ``or`` also covers a "spec" key that is present but None.
        existing = window_spec.expression.args.get("spec") or exp.WindowSpec()
        window_spec.expression.set("spec", exp.WindowSpec(**{**existing.args, **frame}))
        return window_spec

    def rowsBetween(self, start: int, end: int) -> WindowSpec:
        """Return a copy of this spec with a ROWS frame from ``start`` to ``end``."""
        return self._with_frame("ROWS", start, end)

    def rangeBetween(self, start: int, end: int) -> WindowSpec:
        """Return a copy of this spec with a RANGE frame from ``start`` to ``end``."""
        return self._with_frame("RANGE", start, end)
|