1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
|
from __future__ import annotations
import sys
import typing as t
from sqlglot import expressions as exp
from sqlglot.dataframe.sql import functions as F
from sqlglot.helper import flatten
if t.TYPE_CHECKING:
from sqlglot.dataframe.sql._typing import ColumnOrName
class Window:
_JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808
_JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807
_PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG)
_FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG)
unboundedPreceding: int = _JAVA_MIN_LONG
unboundedFollowing: int = _JAVA_MAX_LONG
currentRow: int = 0
@classmethod
def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
return WindowSpec().partitionBy(*cols)
@classmethod
def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
return WindowSpec().orderBy(*cols)
@classmethod
def rowsBetween(cls, start: int, end: int) -> WindowSpec:
return WindowSpec().rowsBetween(start, end)
@classmethod
def rangeBetween(cls, start: int, end: int) -> WindowSpec:
return WindowSpec().rangeBetween(start, end)
class WindowSpec:
def __init__(self, expression: exp.Expression = exp.Window()):
self.expression = expression
def copy(self):
return WindowSpec(self.expression.copy())
def sql(self, **kwargs) -> str:
return self.expression.sql(dialect="spark", **kwargs)
def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
expressions = [Column.ensure_col(x).expression for x in cols]
window_spec = self.copy()
partition_by_expressions = window_spec.expression.args.get("partition_by", [])
partition_by_expressions.extend(expressions)
window_spec.expression.set("partition_by", partition_by_expressions)
return window_spec
def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec:
from sqlglot.dataframe.sql.column import Column
cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore
expressions = [Column.ensure_col(x).expression for x in cols]
window_spec = self.copy()
if window_spec.expression.args.get("order") is None:
window_spec.expression.set("order", exp.Order(expressions=[]))
order_by = window_spec.expression.args["order"].expressions
order_by.extend(expressions)
window_spec.expression.args["order"].set("expressions", order_by)
return window_spec
def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
if start == Window.currentRow:
kwargs["start"] = "CURRENT ROW"
else:
kwargs = {
**kwargs,
**{
"start_side": "PRECEDING",
"start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
},
}
if end == Window.currentRow:
kwargs["end"] = "CURRENT ROW"
else:
kwargs = {
**kwargs,
**{
"end_side": "FOLLOWING",
"end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
},
}
return kwargs
def rowsBetween(self, start: int, end: int) -> WindowSpec:
window_spec = self.copy()
spec = self._calc_start_end(start, end)
spec["kind"] = "ROWS"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
)
return window_spec
def rangeBetween(self, start: int, end: int) -> WindowSpec:
window_spec = self.copy()
spec = self._calc_start_end(start, end)
spec["kind"] = "RANGE"
window_spec.expression.set(
"spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
)
return window_spec
|