# summary | refs | log | tree | commit | diff | stats
# path: root/tests/dataframe/unit/test_column.py
# blob: 977971e03fe237f8bf024bb9aa6cf054794f57f6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import datetime
import unittest

from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.window import Window


class TestDataframeColumn(unittest.TestCase):
    """Check the SQL text rendered for column expressions built through ``F``.

    Each test constructs an expression (operators, string predicates,
    ordering, CASE/WHEN, casts, BETWEEN, window functions) and compares the
    result of its ``sql()`` call against a literal expected string.
    """

    def _assert_sql(self, expected, column):
        """Assert that ``column.sql()`` is exactly the ``expected`` string."""
        self.assertEqual(expected, column.sql())

    # --- comparison operators -------------------------------------------

    def test_eq(self):
        self._assert_sql("cola = 1", F.col("cola") == 1)

    def test_neq(self):
        self._assert_sql("cola <> 1", F.col("cola") != 1)

    def test_gt(self):
        self._assert_sql("cola > 1", F.col("cola") > 1)

    def test_lt(self):
        self._assert_sql("cola < 1", F.col("cola") < 1)

    def test_le(self):
        self._assert_sql("cola <= 1", F.col("cola") <= 1)

    def test_ge(self):
        self._assert_sql("cola >= 1", F.col("cola") >= 1)

    # --- boolean connectives --------------------------------------------

    def test_and(self):
        left = F.col("cola") == F.col("colb")
        right = F.col("colc") == F.col("cold")
        self._assert_sql("cola = colb AND colc = cold", left & right)

    def test_or(self):
        left = F.col("cola") == F.col("colb")
        right = F.col("colc") == F.col("cold")
        self._assert_sql("cola = colb OR colc = cold", left | right)

    # --- arithmetic, column on the left ---------------------------------

    def test_mod(self):
        self._assert_sql("cola % 2", F.col("cola") % 2)

    def test_add(self):
        self._assert_sql("cola + 1", F.col("cola") + 1)

    def test_sub(self):
        self._assert_sql("cola - 1", F.col("cola") - 1)

    def test_mul(self):
        self._assert_sql("cola * 2", F.col("cola") * 2)

    def test_div(self):
        self._assert_sql("cola / 2", F.col("cola") / 2)

    # --- arithmetic, column on the right (reflected operators) ----------

    def test_radd(self):
        self._assert_sql("1 + cola", 1 + F.col("cola"))

    def test_rsub(self):
        self._assert_sql("1 - cola", 1 - F.col("cola"))

    def test_rmul(self):
        self._assert_sql("1 * cola", 1 * F.col("cola"))

    def test_rdiv(self):
        self._assert_sql("1 / cola", 1 / F.col("cola"))

    def test_pow(self):
        # ``**`` is expected to render as a POWER() call, not an operator.
        self._assert_sql("POWER(cola, 2)", F.col("cola") ** 2)

    def test_rpow(self):
        self._assert_sql("POWER(2, cola)", 2 ** F.col("cola"))

    def test_invert(self):
        # ``~`` is expected to render as a logical NOT.
        self._assert_sql("NOT cola", ~F.col("cola"))

    # --- string predicates and helpers ----------------------------------

    def test_startswith(self):
        self._assert_sql("STARTSWITH(cola, 'test')", F.col("cola").startswith("test"))

    def test_endswith(self):
        self._assert_sql("ENDSWITH(cola, 'test')", F.col("cola").endswith("test"))

    def test_rlike(self):
        self._assert_sql("cola RLIKE 'foo'", F.col("cola").rlike("foo"))

    def test_like(self):
        self._assert_sql("cola LIKE 'foo%'", F.col("cola").like("foo%"))

    def test_ilike(self):
        self._assert_sql("cola ILIKE 'foo%'", F.col("cola").ilike("foo%"))

    def test_substring(self):
        self._assert_sql("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3))

    def test_isin(self):
        # Both calling conventions must yield the same IN list:
        # a single list argument and separate positional values.
        expected = "cola IN (1, 2, 3)"
        self._assert_sql(expected, F.col("cola").isin([1, 2, 3]))
        self._assert_sql(expected, F.col("cola").isin(1, 2, 3))

    # --- ordering -------------------------------------------------------

    def test_asc(self):
        # Ascending order is expected to render without an ASC keyword.
        self._assert_sql("cola", F.col("cola").asc())

    def test_desc(self):
        self._assert_sql("cola DESC", F.col("cola").desc())

    def test_asc_nulls_first(self):
        # Expected output carries no NULLS FIRST clause for ascending order.
        self._assert_sql("cola", F.col("cola").asc_nulls_first())

    def test_asc_nulls_last(self):
        self._assert_sql("cola NULLS LAST", F.col("cola").asc_nulls_last())

    def test_desc_nulls_first(self):
        self._assert_sql("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first())

    def test_desc_nulls_last(self):
        # Expected output carries no NULLS LAST clause for descending order.
        self._assert_sql("cola DESC", F.col("cola").desc_nulls_last())

    # --- CASE / WHEN ----------------------------------------------------

    def test_when_otherwise(self):
        single_branch = "CASE WHEN cola = 1 THEN 2 END"
        two_branches = "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END"
        with_default = "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END"

        # One branch: started from the function and from a column.
        self._assert_sql(single_branch, F.when(F.col("cola") == 1, 2))
        self._assert_sql(single_branch, F.col("cola").when(F.col("cola") == 1, 2))

        # Two chained branches: again from the function and from a column.
        self._assert_sql(
            two_branches,
            F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3),
        )
        self._assert_sql(
            two_branches,
            F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3),
        )

        # ``otherwise`` supplies the ELSE value.
        self._assert_sql(
            with_default,
            F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4),
        )

    # --- null checks, cast, alias ---------------------------------------

    def test_is_null(self):
        self._assert_sql("cola IS NULL", F.col("cola").isNull())

    def test_is_not_null(self):
        # Expected form is a negated IS NULL rather than IS NOT NULL.
        self._assert_sql("NOT cola IS NULL", F.col("cola").isNotNull())

    def test_cast(self):
        self._assert_sql("CAST(cola AS INT)", F.col("cola").cast("INT"))

    def test_alias(self):
        self._assert_sql("cola AS new_name", F.col("cola").alias("new_name"))

    # --- BETWEEN --------------------------------------------------------

    def test_between(self):
        # Integer and float bounds are rendered as plain literals.
        self._assert_sql("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3))
        self._assert_sql("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1))

        # ``date`` bounds are expected to be wrapped in TO_DATE().
        date_column = F.col("cola")
        first_day = datetime.date(2022, 1, 1)
        last_day = datetime.date(2022, 3, 1)
        self._assert_sql(
            "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')",
            date_column.between(first_day, last_day),
        )

        # ``datetime`` bounds are expected to be cast to TIMESTAMP.
        timestamp_column = F.col("cola")
        first_moment = datetime.datetime(2022, 1, 1, 1, 1, 1)
        last_moment = datetime.datetime(2022, 3, 1, 1, 1, 1)
        expected_timestamps = (
            "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
            "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)"
        )
        self._assert_sql(expected_timestamps, timestamp_column.between(first_moment, last_moment))

    # --- window functions -----------------------------------------------

    def test_over(self):
        # ROWS frame.
        rows_total = F.sum("cola")
        rows_spec = Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing)
        self._assert_sql(
            "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            rows_total.over(rows_spec),
        )

        # RANGE frame with the same partitioning, ordering and bounds.
        range_total = F.sum("cola")
        range_spec = Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing)
        self._assert_sql(
            "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
            range_total.over(range_spec),
        )