From 8b4272814fb4585be120f183eb7c26bb8acde974 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 21 Oct 2022 11:29:26 +0200 Subject: Merging upstream version 9.0.1. Signed-off-by: Daniel Baumann --- tests/dataframe/unit/test_column.py | 167 ++++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/dataframe/unit/test_column.py (limited to 'tests/dataframe/unit/test_column.py') diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py new file mode 100644 index 0000000..df0ebff --- /dev/null +++ b/tests/dataframe/unit/test_column.py @@ -0,0 +1,167 @@ +import datetime +import unittest + +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.window import Window + + +class TestDataframeColumn(unittest.TestCase): + def test_eq(self): + self.assertEqual("cola = 1", (F.col("cola") == 1).sql()) + + def test_neq(self): + self.assertEqual("cola <> 1", (F.col("cola") != 1).sql()) + + def test_gt(self): + self.assertEqual("cola > 1", (F.col("cola") > 1).sql()) + + def test_lt(self): + self.assertEqual("cola < 1", (F.col("cola") < 1).sql()) + + def test_le(self): + self.assertEqual("cola <= 1", (F.col("cola") <= 1).sql()) + + def test_ge(self): + self.assertEqual("cola >= 1", (F.col("cola") >= 1).sql()) + + def test_and(self): + self.assertEqual( + "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + ) + + def test_or(self): + self.assertEqual( + "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + ) + + def test_mod(self): + self.assertEqual("cola % 2", (F.col("cola") % 2).sql()) + + def test_add(self): + self.assertEqual("cola + 1", (F.col("cola") + 1).sql()) + + def test_sub(self): + self.assertEqual("cola - 1", (F.col("cola") - 1).sql()) + + def test_mul(self): + self.assertEqual("cola * 2", (F.col("cola") * 2).sql()) + + def test_div(self): + self.assertEqual("cola / 2", (F.col("cola") / 2).sql()) + + def test_radd(self): + self.assertEqual("1 + cola", (1 + F.col("cola")).sql()) + + def test_rsub(self): + self.assertEqual("1 - cola", (1 - F.col("cola")).sql()) + + def test_rmul(self): + self.assertEqual("1 * cola", (1 * F.col("cola")).sql()) + + def test_rdiv(self): + self.assertEqual("1 / cola", (1 / F.col("cola")).sql()) + + def test_pow(self): + self.assertEqual("POWER(cola, 2)", (F.col("cola") ** 2).sql()) + + def test_rpow(self): + self.assertEqual("POWER(2, cola)", (2 ** F.col("cola")).sql()) + + def test_invert(self): + self.assertEqual("NOT cola", (~F.col("cola")).sql()) + + def test_startswith(self): + self.assertEqual("STARTSWITH(cola, 'test')", F.col("cola").startswith("test").sql()) + + def test_endswith(self): + self.assertEqual("ENDSWITH(cola, 'test')", F.col("cola").endswith("test").sql()) + + def test_rlike(self): + self.assertEqual("cola RLIKE 'foo'", F.col("cola").rlike("foo").sql()) + + def test_like(self): + self.assertEqual("cola LIKE 'foo%'", F.col("cola").like("foo%").sql()) + + def test_ilike(self): + self.assertEqual("cola ILIKE 'foo%'", F.col("cola").ilike("foo%").sql()) + + def test_substring(self): + self.assertEqual("SUBSTRING(cola, 2, 3)", F.col("cola").substr(2, 3).sql()) + + def test_isin(self): + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin([1, 2, 3]).sql()) + self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql()) + + def test_asc(self): + self.assertEqual("cola", F.col("cola").asc().sql()) + + def test_desc(self): + self.assertEqual("cola DESC", F.col("cola").desc().sql()) + + def test_asc_nulls_first(self): + self.assertEqual("cola", F.col("cola").asc_nulls_first().sql()) + + def test_asc_nulls_last(self): + self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql()) + + def test_desc_nulls_first(self): + self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql()) + + def test_desc_nulls_last(self): + self.assertEqual("cola DESC", F.col("cola").desc_nulls_last().sql()) + + def test_when_otherwise(self): + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) + self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", + F.col("cola").when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).sql(), + ) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 ELSE 4 END", + F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3).otherwise(4).sql(), + ) + + def test_is_null(self): + self.assertEqual("cola IS NULL", F.col("cola").isNull().sql()) + + def test_is_not_null(self): + self.assertEqual("NOT cola IS NULL", F.col("cola").isNotNull().sql()) + + def test_cast(self): + self.assertEqual("CAST(cola AS INT)", F.col("cola").cast("INT").sql()) + + def test_alias(self): + self.assertEqual("cola AS new_name", F.col("cola").alias("new_name").sql()) + + def test_between(self): + self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql()) + self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql()) + self.assertEqual( + "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')", + F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), + ) + self.assertEqual( + "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)", + F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + ) + + def test_over(self): + over_rows = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rowsBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_rows.sql(), + ) + over_range = F.sum("cola").over( + Window.partitionBy("colb").orderBy("colc").rangeBetween(1, Window.unboundedFollowing) + ) + self.assertEqual( + "SUM(cola) OVER (PARTITION BY colb ORDER BY colc RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + over_range.sql(), + ) -- cgit v1.2.3