diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 54 |
1 files changed, 31 insertions, 23 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 10f3b57..97753bd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel class TestFunctions(unittest.TestCase): - @unittest.skip("not yet fixed.") def test_invoke_anonymous(self): for name, func in inspect.getmembers(SF, inspect.isfunction): with self.subTest(f"{name} should not invoke anonymous_function"): @@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase): def test_pow(self): col_str = SF.pow("cola", "colb") - self.assertEqual("POW(cola, colb)", col_str.sql()) + self.assertEqual("POWER(cola, colb)", col_str.sql()) col = SF.pow(SF.col("cola"), SF.col("colb")) - self.assertEqual("POW(cola, colb)", col.sql()) + self.assertEqual("POWER(cola, colb)", col.sql()) col_float = SF.pow(10.10, "colb") - self.assertEqual("POW(10.1, colb)", col_float.sql()) + self.assertEqual("POWER(10.1, colb)", col_float.sql()) col_float2 = SF.pow("cola", 10.10) - self.assertEqual("POW(cola, 10.1)", col_float2.sql()) + self.assertEqual("POWER(cola, 10.1)", col_float2.sql()) def test_row_number(self): col_str = SF.row_number() @@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql()) col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc")) self.assertEqual("COALESCE(cola, colb, colc)", col.sql()) + col_single = SF.coalesce("cola") + self.assertEqual("COALESCE(cola)", col_single.sql()) def test_corr(self): col_str = SF.corr("cola", "colb") @@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TO_DATE(cola)", col_str.sql()) col = SF.to_date(SF.col("cola")) self.assertEqual("TO_DATE(cola)", col.sql()) - col_with_format = SF.to_date("cola", "yyyy-MM-dd") - self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql()) + col_with_format = SF.to_date("cola", "yy-MM-dd") + self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql()) def test_to_timestamp(self): col_str = SF.to_timestamp("cola") @@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql()) col = SF.from_unixtime(SF.col("cola")) self.assertEqual("FROM_UNIXTIME(cola)", col.sql()) - col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) def test_unix_timestamp(self): col_str = SF.unix_timestamp("cola") self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql()) col = SF.unix_timestamp(SF.col("cola")) self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql()) - col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss") - self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql()) + col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm") + self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql()) col_current = SF.unix_timestamp() self.assertEqual("UNIX_TIMESTAMP()", col_current.sql()) @@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ARRAY_SORT(cola)", col_str.sql()) col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) + col_comparator = SF.array_sort( + "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + ) + self.assertEqual( + "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", + col_comparator.sql(), + ) def test_reverse(self): col_str = SF.reverse("cola") @@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase): SF.lit(0), lambda accumulator, target: accumulator + target, lambda accumulator: accumulator * 2, - "accumulator", - "target", ) self.assertEqual( "AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)", @@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql()) col = SF.transform(SF.col("cola"), lambda x, i: x * i) self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) - col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count") + col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) @@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql()) col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0) self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql()) - col_custom_name = SF.exists("cola", lambda target: target > 0, "target") + col_custom_name = SF.exists("cola", lambda target: target > 0) self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql()) def test_forall(self): @@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql()) col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo")) self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql()) - col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target") + col_custom_name = SF.forall("cola", lambda target: target.rlike("foo")) self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql()) def test_filter(self): @@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter( - "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count" - ) + col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) self.assertEqual( "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() @@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql()) col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) - col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r") + col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) def test_transform_keys(self): @@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql()) col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k)) self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql()) - col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_") + col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key)) self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql()) def test_transform_values(self): @@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql()) col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) - col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value") + col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) def test_map_filter(self): @@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql()) col = SF.map_filter(SF.col("cola"), lambda k, v: k > v) self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql()) - col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value") + col_custom_names = SF.map_filter("cola", lambda key, value: key > value) self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql()) + + def test_map_zip_with(self): + col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2)) + self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql()) |