summaryrefslogtreecommitdiffstats
path: root/tests/dataframe
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dataframe/unit/test_functions.py54
1 files changed, 31 insertions, 23 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index 10f3b57..97753bd 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -9,7 +9,6 @@ from sqlglot.errors import ErrorLevel
class TestFunctions(unittest.TestCase):
- @unittest.skip("not yet fixed.")
def test_invoke_anonymous(self):
for name, func in inspect.getmembers(SF, inspect.isfunction):
with self.subTest(f"{name} should not invoke anonymous_function"):
@@ -438,13 +437,13 @@ class TestFunctions(unittest.TestCase):
def test_pow(self):
col_str = SF.pow("cola", "colb")
- self.assertEqual("POW(cola, colb)", col_str.sql())
+ self.assertEqual("POWER(cola, colb)", col_str.sql())
col = SF.pow(SF.col("cola"), SF.col("colb"))
- self.assertEqual("POW(cola, colb)", col.sql())
+ self.assertEqual("POWER(cola, colb)", col.sql())
col_float = SF.pow(10.10, "colb")
- self.assertEqual("POW(10.1, colb)", col_float.sql())
+ self.assertEqual("POWER(10.1, colb)", col_float.sql())
col_float2 = SF.pow("cola", 10.10)
- self.assertEqual("POW(cola, 10.1)", col_float2.sql())
+ self.assertEqual("POWER(cola, 10.1)", col_float2.sql())
def test_row_number(self):
col_str = SF.row_number()
@@ -493,6 +492,8 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("COALESCE(cola, colb, colc)", col_str.sql())
col = SF.coalesce(SF.col("cola"), "colb", SF.col("colc"))
self.assertEqual("COALESCE(cola, colb, colc)", col.sql())
+ col_single = SF.coalesce("cola")
+ self.assertEqual("COALESCE(cola)", col_single.sql())
def test_corr(self):
col_str = SF.corr("cola", "colb")
@@ -843,8 +844,8 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TO_DATE(cola)", col_str.sql())
col = SF.to_date(SF.col("cola"))
self.assertEqual("TO_DATE(cola)", col.sql())
- col_with_format = SF.to_date("cola", "yyyy-MM-dd")
- self.assertEqual("TO_DATE(cola, 'yyyy-MM-dd')", col_with_format.sql())
+ col_with_format = SF.to_date("cola", "yy-MM-dd")
+ self.assertEqual("TO_DATE(cola, 'yy-MM-dd')", col_with_format.sql())
def test_to_timestamp(self):
col_str = SF.to_timestamp("cola")
@@ -883,16 +884,16 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FROM_UNIXTIME(cola)", col_str.sql())
col = SF.from_unixtime(SF.col("cola"))
self.assertEqual("FROM_UNIXTIME(cola)", col.sql())
- col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm:ss")
- self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql())
+ col_format = SF.from_unixtime("cola", "yyyy-MM-dd HH:mm")
+ self.assertEqual("FROM_UNIXTIME(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
def test_unix_timestamp(self):
col_str = SF.unix_timestamp("cola")
self.assertEqual("UNIX_TIMESTAMP(cola)", col_str.sql())
col = SF.unix_timestamp(SF.col("cola"))
self.assertEqual("UNIX_TIMESTAMP(cola)", col.sql())
- col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm:ss")
- self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm:ss')", col_format.sql())
+ col_format = SF.unix_timestamp("cola", "yyyy-MM-dd HH:mm")
+ self.assertEqual("UNIX_TIMESTAMP(cola, 'yyyy-MM-dd HH:mm')", col_format.sql())
col_current = SF.unix_timestamp()
self.assertEqual("UNIX_TIMESTAMP()", col_current.sql())
@@ -1427,6 +1428,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("ARRAY_SORT(cola)", col_str.sql())
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
+ col_comparator = SF.array_sort(
+ "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+ )
+ self.assertEqual(
+ "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
+ col_comparator.sql(),
+ )
def test_reverse(self):
col_str = SF.reverse("cola")
@@ -1514,8 +1522,6 @@ class TestFunctions(unittest.TestCase):
SF.lit(0),
lambda accumulator, target: accumulator + target,
lambda accumulator: accumulator * 2,
- "accumulator",
- "target",
)
self.assertEqual(
"AGGREGATE(cola, 0, (accumulator, target) -> accumulator + target, accumulator -> accumulator * 2)",
@@ -1527,7 +1533,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, x -> x * 2)", col_str.sql())
col = SF.transform(SF.col("cola"), lambda x, i: x * i)
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
- col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count, "target", "row_count")
+ col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
@@ -1536,7 +1542,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col_str.sql())
col = SF.exists(SF.col("cola"), lambda x: x % 2 == 0)
self.assertEqual("EXISTS(cola, x -> x % 2 = 0)", col.sql())
- col_custom_name = SF.exists("cola", lambda target: target > 0, "target")
+ col_custom_name = SF.exists("cola", lambda target: target > 0)
self.assertEqual("EXISTS(cola, target -> target > 0)", col_custom_name.sql())
def test_forall(self):
@@ -1544,7 +1550,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col_str.sql())
col = SF.forall(SF.col("cola"), lambda x: x.rlike("foo"))
self.assertEqual("FORALL(cola, x -> x RLIKE 'foo')", col.sql())
- col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"), "target")
+ col_custom_name = SF.forall("cola", lambda target: target.rlike("foo"))
self.assertEqual("FORALL(cola, target -> target RLIKE 'foo')", col_custom_name.sql())
def test_filter(self):
@@ -1552,9 +1558,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
- col_custom_names = SF.filter(
- "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count), "target", "row_count"
- )
+ col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
self.assertEqual(
"FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
@@ -1565,7 +1569,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col_str.sql())
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
- col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r), "l", "r")
+ col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
def test_transform_keys(self):
@@ -1573,7 +1577,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col_str.sql())
col = SF.transform_keys(SF.col("cola"), lambda k, v: SF.upper(k))
self.assertEqual("TRANSFORM_KEYS(cola, (k, v) -> UPPER(k))", col.sql())
- col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key), "key", "_")
+ col_custom_names = SF.transform_keys("cola", lambda key, _: SF.upper(key))
self.assertEqual("TRANSFORM_KEYS(cola, (key, _) -> UPPER(key))", col_custom_names.sql())
def test_transform_values(self):
@@ -1581,7 +1585,7 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col_str.sql())
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
- col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value), "_", "value")
+ col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
def test_map_filter(self):
@@ -1589,5 +1593,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col_str.sql())
col = SF.map_filter(SF.col("cola"), lambda k, v: k > v)
self.assertEqual("MAP_FILTER(cola, (k, v) -> k > v)", col.sql())
- col_custom_names = SF.map_filter("cola", lambda key, value: key > value, "key", "value")
+ col_custom_names = SF.map_filter("cola", lambda key, value: key > value)
self.assertEqual("MAP_FILTER(cola, (key, value) -> key > value)", col_custom_names.sql())
+
+ def test_map_zip_with(self):
+ col = SF.map_zip_with("base", "ratio", lambda k, v1, v2: SF.round(v1 * v2, 2))
+ self.assertEqual("MAP_ZIP_WITH(base, ratio, (k, v1, v2) -> ROUND(v1 * v2, 2))", col.sql())