Diffstat (limited to 'tests/dataframe/unit')
-rw-r--r--  tests/dataframe/unit/dataframe_sql_validator.py   12
-rw-r--r--  tests/dataframe/unit/test_column.py               14
-rw-r--r--  tests/dataframe/unit/test_dataframe.py             4
-rw-r--r--  tests/dataframe/unit/test_functions.py            55
-rw-r--r--  tests/dataframe/unit/test_session.py              11
-rw-r--r--  tests/dataframe/unit/test_types.py                 5
-rw-r--r--  tests/dataframe/unit/test_window.py               32
7 files changed, 95 insertions, 38 deletions
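Note: the changes below are a pure reformatting pass over the DataFrame unit tests; call
expressions that no longer fit on one line are wrapped across multiple lines, with no
behavioral changes. The wrapping style is consistent with running a formatter such as
black with a roughly 100-character line limit (the exact tool and setting are an
assumption, not stated in the diff). For example, from test_column.py:

    # before: one long assertion line
    self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())

    # after: the call is wrapped, arguments kept together since they fit within the limit
    self.assertEqual(
        "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
    )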
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index fc56553..32ff8f2 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+ self.df_employee = self.spark.createDataFrame(
+ data=employee_data, schema=self.employee_schema
+ )
- def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+ def compare_sql(
+ self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+ ):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ expected_statements = (
+ [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ )
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py
index 977971e..da18502 100644
--- a/tests/dataframe/unit/test_column.py
+++ b/tests/dataframe/unit/test_column.py
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
def test_and(self):
self.assertEqual(
- "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb AND colc = cold",
+ ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
- "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb OR colc = cold",
+ ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
- self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+ self.assertEqual(
+ "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+ )
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
- F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+ F.col("cola")
+ .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+ .sql(),
)
def test_over(self):
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
index c222cac..e36667b 100644
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
- self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+ self.assertEqual(
+ ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+ )
def test_cache(self):
df = self.df_employee.select("fname").cache()
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index eadbb93..8e5e5cd 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
col = SF.window(SF.col("cola"), "10 minutes")
self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+ )
col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+ )
col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+ "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+ col_no_slide.sql(),
)
def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
def test_from_json(self):
col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_json("cola", "cola INT")
self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
def test_schema_of_json(self):
col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
- "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+ "cola",
+ lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+ SF.length(y) - SF.length(x)
+ ),
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
def test_from_csv(self):
col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_csv("cola", "cola INT")
self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
- self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+ )
def test_exists(self):
col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
- col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+ col_custom_names = SF.filter(
+ "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+ )
self.assertEqual(
- "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+ "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+ col_custom_names.sql(),
)
def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
- self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
+ self.assertEqual(
+ "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
+ )
def test_transform_keys(self):
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
- self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
+ )
def test_map_filter(self):
col_str = SF.map_filter("cola", lambda k, v: k > v)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 158dcec..7e8bfad 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = (
- "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
- )
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
- "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+ "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+ df.sql(pretty=False),
)
@mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
- expected = (
- "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
- )
+ expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py
index 1f6c5dc..52f5d72 100644
--- a/tests/dataframe/unit/test_types.py
+++ b/tests/dataframe/unit/test_types.py
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
- self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+ self.assertEqual(
+ "map<int, string>",
+ types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+ )
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py
index eea4582..70a868a 100644
--- a/tests/dataframe/unit/test_window.py
+++ b/tests/dataframe/unit/test_window.py
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
- self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ rows_between_unbounded_start.sql(),
+ )
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
- rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+ "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_end.sql(),
+ )
+ rows_between_unbounded_both = Window.rowsBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
- range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+ "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_end.sql(),
+ )
+ range_between_unbounded_both = Window.rangeBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_both.sql(),
)