diff options
Diffstat (limited to 'tests/dataframe')
-rw-r--r-- | tests/dataframe/integration/dataframe_validator.py | 52 | ||||
-rw-r--r-- | tests/dataframe/integration/test_dataframe.py | 295 | ||||
-rw-r--r-- | tests/dataframe/integration/test_session.py | 12 | ||||
-rw-r--r-- | tests/dataframe/unit/dataframe_sql_validator.py | 12 | ||||
-rw-r--r-- | tests/dataframe/unit/test_column.py | 14 | ||||
-rw-r--r-- | tests/dataframe/unit/test_dataframe.py | 4 | ||||
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 55 | ||||
-rw-r--r-- | tests/dataframe/unit/test_session.py | 11 | ||||
-rw-r--r-- | tests/dataframe/unit/test_types.py | 5 | ||||
-rw-r--r-- | tests/dataframe/unit/test_window.py | 32 |
10 files changed, 340 insertions, 152 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 4a89c78..16f8922 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,9 +1,9 @@ -import sys import typing as t import unittest import warnings import sqlglot +from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: @@ -11,7 +11,8 @@ if t.TYPE_CHECKING: @unittest.skipIf( - SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" + SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), + "Skipping Integration Tests since `SKIP_INTEGRATION` is set", ) class DataFrameValidator(unittest.TestCase): spark = None @@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase): # This is for test `test_branching_root_dataframes` config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark = ( + SparkSession.builder.master("local[*]") + .appName("Unit-tests") + .config(conf=config) + .getOrCreate() + ) cls.spark.sparkContext.setLogLevel("ERROR") cls.sqlglot = SqlglotSparkSession() cls.spark_employee_schema = types.StructType( @@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "employee_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), @@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) - cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee = cls.spark.createDataFrame( + data=employee_data, schema=cls.spark_employee_schema + ) + cls.dfs_employee = cls.sqlglot.createDataFrame( + data=employee_data, schema=cls.sqlglot_employee_schema + ) cls.df_employee.createOrReplaceTempView("employee") cls.spark_store_schema = types.StructType( @@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase): [ sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), ] ) @@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase): (2, "Arrow", 2, 2000), ] cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame( + data=store_data, schema=cls.sqlglot_store_schema + ) cls.df_store.createOrReplaceTempView("store") cls.spark_district_schema = types.StructType( @@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), + sqlglotSparkTypes.StructField( + "district_name", sqlglotSparkTypes.StringType(), False + ), + sqlglotSparkTypes.StructField( + "manager_name", sqlglotSparkTypes.StringType(), False + ), ] ) district_data = [ (1, "Temple", "Dogen"), (2, "Lighthouse", "Jacob"), ] - cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) - cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district = cls.spark.createDataFrame( + data=district_data, schema=cls.spark_district_schema + ) + cls.dfs_district = cls.sqlglot.createDataFrame( + data=district_data, schema=cls.sqlglot_district_schema + ) cls.df_district.createOrReplaceTempView("district") sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) sqlglot.schema.add_table("store", cls.sqlglot_store_schema) diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index c740bec..19e3b89 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator): def test_alias_with_select(self): df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + self.df_spark_employee["employee_id"], + F.col("df_employee.fname"), + self.df_spark_employee.lname, ) dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + self.df_sqlglot_employee["employee_id"], + SF.col("dfs_employee.fname"), + self.df_sqlglot_employee.lname, ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_case_when_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ) .when(F.col("age") < F.lit(40), "less than 40") .otherwise("greater than 60") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ) .when(SF.col("age") < SF.lit(40), "less than 40") .otherwise("greater than 60") ) @@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator): def test_case_when_no_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( - F.col("age") < F.lit(40), "less than 40" - ) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ).when(F.col("age") < F.lit(40), "less than 40") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( - SF.col("age") < SF.lit(40), "less than 40" - ) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ).when(SF.col("age") < SF.lit(40), "less than 40") ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) ) @@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) ) @@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator): dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] % F.lit(5) == F.lit(0) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] + F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] - F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) df_employee = self.df_spark_employee.where( self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) ) dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + self.df_sqlglot_employee["age"] * SF.lit(0.5) + == self.df_sqlglot_employee["age"] / SF.lit(2) ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_join_inner(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="inner" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="inner" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + df_joined = self.df_spark_employee.select( + F.col("store_id"), F.col("fname"), F.col("lname") + ).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), + on=["store_id"], + how="inner", ) - dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + dfs_joined = self.df_sqlglot_employee.select( + SF.col("store_id"), SF.col("fname"), SF.col("lname") + ).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), + on=["store_id"], + how="inner", ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_equality_single(self): df_joined = self.df_spark_employee.join( - self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + self.df_spark_store, + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], @@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store["num_sales"], ) dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + self.df_sqlglot_store, + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], @@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_full_outer(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="full_outer" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="full_outer" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator): def test_triple_join(self): df = ( - self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + self.df_employee.join( + self.df_store, on=self.df_employee.employee_id == self.df_store.store_id + ) .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) .select( self.df_employee.employee_id, @@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) dfs = ( - self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + self.dfs_employee.join( + self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id + ) .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) .select( self.dfs_employee.employee_id, @@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_join_select_and_select_start(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( - self.df_spark_store, "store_id", "inner" - ) + df = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ).join(self.df_spark_store, "store_id", "inner") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( - self.df_sqlglot_store, "store_id", "inner" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ).join(self.df_sqlglot_store, "store_id", "inner") self.compare_spark_with_sqlglot(df, dfs) @@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator): dfs_unioned = ( self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + .unionAll( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) ) self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) def test_union_by_name(self): - df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + df = self.df_spark_employee.select( + F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( self.df_spark_store.select( F.col("store_name").alias("lname"), F.col("store_id").alias("employee_id"), @@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) - dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + dfs = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( self.df_sqlglot_store.select( SF.col("store_name").alias("lname"), SF.col("store_id").alias("employee_id"), @@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_order_by_default(self): - df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales")) + .orderBy(F.col("district_id")) + ) dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales")) + .orderBy(SF.col("district_id")) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + .orderBy( + F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + .orderBy( + SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() + ) ) self.compare_spark_with_sqlglot(df, dfs) @@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + .orderBy( + F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + .orderBy( + SF.when( + SF.col("district_id") == SF.lit(1), SF.col("district_id") + ).desc_nulls_first() + ) ) self.compare_spark_with_sqlglot(df, dfs) def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersect(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersectAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.exceptAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) @@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_drop_na_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(how="any", thresh=2) dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(how="any", thresh=2) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(thresh=1, subset="the_age") dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(thresh=1, subset="the_age") self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_dropna_na_function(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.drop() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_fillna_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).fillna(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_fillna_na_func(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.fill(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( to_replace=37, value=100 @@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + {37: 100} + ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + {37: 100} + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator): to_replace=37, value=100 ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col") + ).na.replace(to_replace=37, value=100) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator): "first_name", "first_name_again" ) - dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname").alias("first_name") + ).withColumnRenamed("first_name", "first_name_again") self.compare_spark_with_sqlglot(df, dfs) def test_drop_column_single(self): df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( + "age" + ) self.compare_spark_with_sqlglot(df, dfs) @@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator): df_sqlglot_employee_cols = self.df_sqlglot_employee.select( SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") ) - df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + df_sqlglot_store_cols = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("store_name") + ) dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( df_sqlglot_employee_cols.age, ) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py index ff1477b..ec50034 100644 --- a/tests/dataframe/integration/test_session.py +++ b/tests/dataframe/integration/test_session.py @@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator): ON e.store_id = s.store_id """ - df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) - dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + df = ( + self.spark.sql(query) + .groupBy(F.col("store_id")) + .agg(F.countDistinct(F.col("employee_id"))) + ) + dfs = ( + self.sqlglot.sql(query) + .groupBy(SF.col("store_id")) + .agg(SF.countDistinct(SF.col("employee_id"))) + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index fc56553..32ff8f2 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + self.df_employee = self.spark.createDataFrame( + data=employee_data, schema=self.employee_schema + ) - def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + def compare_sql( + self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False + ): actual_sqls = df.sql(pretty=pretty) - expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + expected_statements = ( + [expected_statements] if isinstance(expected_statements, str) else expected_statements + ) self.assertEqual(len(expected_statements), len(actual_sqls)) for expected, actual in zip(expected_statements, actual_sqls): self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 977971e..da18502 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase): def test_and(self): self.assertEqual( - "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql() + "cola = colb AND colc = cold", + ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(), ) def test_or(self): self.assertEqual( - "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql() + "cola = colb OR colc = cold", + ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(), ) def test_mod(self): @@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase): def test_when_otherwise(self): self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql()) - self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()) + self.assertEqual( + "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql() + ) self.assertEqual( "CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END", (F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(), @@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase): self.assertEqual( "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", - F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), + F.col("cola") + .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) + .sql(), ) def test_over(self): diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py index c222cac..e36667b 100644 --- a/tests/dataframe/unit/test_dataframe.py +++ b/tests/dataframe/unit/test_dataframe.py @@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator): self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) def test_columns(self): - self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns) + self.assertEqual( + ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns + ) def test_cache(self): df = self.df_employee.select("fname").cache() diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index eadbb93..8e5e5cd 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase): col = SF.window(SF.col("cola"), "10 minutes") self.assertEqual("WINDOW(cola, '10 minutes')", col.sql()) col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql() + ) col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds") - self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()) + self.assertEqual( + "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql() + ) col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds") self.assertEqual( - "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql() + "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", + col_no_slide.sql(), ) def test_session_window(self): @@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase): def test_from_json(self): col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_json("cola", "cola INT") self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql()) @@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase): def test_schema_of_json(self): col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy")) self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) col_no_option = SF.schema_of_json("cola") @@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase): col = SF.array_sort(SF.col("cola")) self.assertEqual("ARRAY_SORT(cola)", col.sql()) col_comparator = SF.array_sort( - "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x)) + "cola", + lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise( + SF.length(y) - SF.length(x) + ), ) self.assertEqual( "ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)", @@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase): def test_from_csv(self): col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql() + ) col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy")) - self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()) + self.assertEqual( + "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql() + ) col_no_option = SF.from_csv("cola", "cola INT") self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql()) @@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase): self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql()) col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count) - self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql() + ) def test_exists(self): col_str = SF.exists("cola", lambda x: x % 2 == 0) @@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase): self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql()) col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i)) self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql()) - col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)) + col_custom_names = SF.filter( + "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count) + ) self.assertEqual( - "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql() + "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", + col_custom_names.sql(), ) def test_zip_with(self): @@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase): col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y)) self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql()) col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r)) - self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()) + self.assertEqual( + "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql() + ) def test_transform_keys(self): col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k)) @@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase): col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v)) self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql()) col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value)) - self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()) + self.assertEqual( + "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql() + ) def test_map_filter(self): col_str = SF.map_filter("cola", lambda k, v: k > v) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 158dcec..7e8bfad 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator): def test_cdf_no_schema(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) - expected = ( - "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" - ) + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" self.compare_sql(df, expected) def test_cdf_row_mixed_primitives(self): @@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator): sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) self.assertIn( - "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False) + "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", + df.sql(pretty=False), ) @mock.patch("sqlglot.schema", MappingSchema()) @@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}) df = self.spark.sql(query) - expected = ( - "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" - ) + expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) def test_session_create_builder_patterns(self): diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py index 1f6c5dc..52f5d72 100644 --- a/tests/dataframe/unit/test_types.py +++ b/tests/dataframe/unit/test_types.py @@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase): self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString()) def test_map(self): - self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString()) + self.assertEqual( + "map<int, string>", + types.MapType(types.IntegerType(), types.StringType()).simpleString(), + ) def test_struct_field(self): self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString()) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index eea4582..70a868a 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase): def test_window_rows_unbounded(self): rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) - self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql()) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + rows_between_unbounded_start.sql(), + ) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql()) - rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql() + "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_end.sql(), + ) + rows_between_unbounded_both = Window.rowsBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + rows_between_unbounded_both.sql(), ) def test_window_range_unbounded(self): range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql() + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + range_between_unbounded_start.sql(), ) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) - self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql()) - range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql() + "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_end.sql(), + ) + range_between_unbounded_both = Window.rangeBetween( + Window.unboundedPreceding, Window.unboundedFollowing + ) + self.assertEqual( + "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + range_between_unbounded_both.sql(), ) |