summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/test_window.py
blob: 9c4c8975866b1c15f7b8a445d6619773d70ca2e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window, WindowSpec
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase


class TestDataframeWindow(DataFrameTestBase):
    """Unit tests for the SQL rendered by ``WindowSpec`` and the ``Window`` helpers.

    Every test builds a window specification and checks that its ``.sql()`` output is
    exactly the expected ``OVER (...)`` clause. The ``test_window_spec_*`` tests start
    from a fresh ``WindowSpec()`` instance; the ``test_window_*`` tests go through the
    ``Window`` entry points directly.
    """

    def _assert_renders(self, spec, expected):
        """Check that ``spec.sql()`` is exactly ``expected``."""
        self.assertEqual(expected, spec.sql())

    def test_window_spec_partition_by(self):
        spec = WindowSpec().partitionBy(F.col("cola"), F.col("colb"))
        self._assert_renders(spec, "OVER (PARTITION BY cola, colb)")

    def test_window_spec_order_by(self):
        spec = WindowSpec().orderBy("cola", "colb")
        self._assert_renders(spec, "OVER (ORDER BY cola, colb)")

    def test_window_spec_rows_between(self):
        # A positive start offset is rendered as PRECEDING here (3 -> "3 PRECEDING").
        # NOTE(review): PySpark expresses preceding offsets as negative numbers in
        # rowsBetween/rangeBetween — confirm this mapping is intentional.
        spec = WindowSpec().rowsBetween(3, 5)
        self._assert_renders(spec, "OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)")

    def test_window_spec_range_between(self):
        spec = WindowSpec().rangeBetween(3, 5)
        self._assert_renders(spec, "OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)")

    def test_window_partition_by(self):
        spec = Window.partitionBy(F.col("cola"), F.col("colb"))
        self._assert_renders(spec, "OVER (PARTITION BY cola, colb)")

    def test_window_order_by(self):
        spec = Window.orderBy("cola", "colb")
        self._assert_renders(spec, "OVER (ORDER BY cola, colb)")

    def test_window_rows_between(self):
        spec = Window.rowsBetween(3, 5)
        self._assert_renders(spec, "OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)")

    def test_window_range_between(self):
        spec = Window.rangeBetween(3, 5)
        self._assert_renders(spec, "OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)")

    def test_window_rows_unbounded(self):
        # Short aliases for the unbounded frame markers keep each case on one line.
        lo, hi = Window.unboundedPreceding, Window.unboundedFollowing
        # (start, end, expected clause): unbounded start, unbounded end, then both.
        cases = [
            (lo, 2, "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)"),
            (1, hi, "OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)"),
            (lo, hi, "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"),
        ]
        for start, end, expected in cases:
            self._assert_renders(Window.rowsBetween(start, end), expected)

    def test_window_range_unbounded(self):
        # Short aliases for the unbounded frame markers keep each case on one line.
        lo, hi = Window.unboundedPreceding, Window.unboundedFollowing
        # (start, end, expected clause): unbounded start, unbounded end, then both.
        cases = [
            (lo, 2, "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)"),
            (1, hi, "OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)"),
            (lo, hi, "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"),
        ]
        for start, end, expected in cases:
            self._assert_renders(Window.rangeBetween(start, end), expected)