diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-11-11 08:54:30 +0000 |
commit | 9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1 (patch) | |
tree | 7ab2f39fbb6fd832aeea5cef45b54bfd59ba5ba5 /tests/dataframe/integration/test_dataframe.py | |
parent | Adding upstream version 9.0.6. (diff) | |
download | sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.tar.xz sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.zip |
Adding upstream version 10.0.1.upstream/10.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/integration/test_dataframe.py')
-rw-r--r-- | tests/dataframe/integration/test_dataframe.py | 295 |
1 files changed, 196 insertions, 99 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index c740bec..19e3b89 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator): def test_alias_with_select(self): df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + self.df_spark_employee["employee_id"], + F.col("df_employee.fname"), + self.df_spark_employee.lname, ) dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + self.df_sqlglot_employee["employee_id"], + SF.col("dfs_employee.fname"), + self.df_sqlglot_employee.lname, ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_case_when_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ) .when(F.col("age") < F.lit(40), "less than 40") .otherwise("greater than 60") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ) .when(SF.col("age") < SF.lit(40), "less than 40") .otherwise("greater than 60") ) @@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator): def test_case_when_no_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( - F.col("age") < F.lit(40), "less than 40" - ) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ).when(F.col("age") < F.lit(40), "less than 40") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( - SF.col("age") < SF.lit(40), "less than 40" - ) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ).when(SF.col("age") < SF.lit(40), "less than 40") ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) ) @@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) ) @@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator): dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] % F.lit(5) == F.lit(0) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] + F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] - F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) df_employee = self.df_spark_employee.where( self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) ) dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + self.df_sqlglot_employee["age"] * SF.lit(0.5) + == self.df_sqlglot_employee["age"] / SF.lit(2) ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_join_inner(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="inner" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="inner" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + df_joined = self.df_spark_employee.select( + F.col("store_id"), F.col("fname"), F.col("lname") + ).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), + on=["store_id"], + how="inner", ) - dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + dfs_joined = self.df_sqlglot_employee.select( + SF.col("store_id"), SF.col("fname"), SF.col("lname") + ).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), + on=["store_id"], + how="inner", ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_equality_single(self): df_joined = self.df_spark_employee.join( - self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + self.df_spark_store, + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], @@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store["num_sales"], ) dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + self.df_sqlglot_store, + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], @@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_full_outer(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="full_outer" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="full_outer" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator): def test_triple_join(self): df = ( - self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + self.df_employee.join( + self.df_store, on=self.df_employee.employee_id == self.df_store.store_id + ) .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) .select( self.df_employee.employee_id, @@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) dfs = ( - self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + self.dfs_employee.join( + self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id + ) .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) .select( self.dfs_employee.employee_id, @@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_join_select_and_select_start(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( - self.df_spark_store, "store_id", "inner" - ) + df = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ).join(self.df_spark_store, "store_id", "inner") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( - self.df_sqlglot_store, "store_id", "inner" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ).join(self.df_sqlglot_store, "store_id", "inner") self.compare_spark_with_sqlglot(df, dfs) @@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator): dfs_unioned = ( self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + .unionAll( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) ) self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) def test_union_by_name(self): - df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + df = self.df_spark_employee.select( + F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( self.df_spark_store.select( F.col("store_name").alias("lname"), F.col("store_id").alias("employee_id"), @@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) - dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + dfs = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( self.df_sqlglot_store.select( SF.col("store_name").alias("lname"), SF.col("store_id").alias("employee_id"), @@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_order_by_default(self): - df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales")) + .orderBy(F.col("district_id")) + ) dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales")) + .orderBy(SF.col("district_id")) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + .orderBy( + F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + .orderBy( + SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() + ) ) self.compare_spark_with_sqlglot(df, dfs) @@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + .orderBy( + F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + .orderBy( + SF.when( + SF.col("district_id") == SF.lit(1), SF.col("district_id") + ).desc_nulls_first() + ) ) self.compare_spark_with_sqlglot(df, dfs) def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersect(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersectAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.exceptAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) @@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_drop_na_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(how="any", thresh=2) dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(how="any", thresh=2) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(thresh=1, subset="the_age") dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(thresh=1, subset="the_age") self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_dropna_na_function(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.drop() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_fillna_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).fillna(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_fillna_na_func(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.fill(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( to_replace=37, value=100 @@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + {37: 100} + ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + {37: 100} + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator): to_replace=37, value=100 ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col") + ).na.replace(to_replace=37, value=100) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator): "first_name", "first_name_again" ) - dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname").alias("first_name") + ).withColumnRenamed("first_name", "first_name_again") self.compare_spark_with_sqlglot(df, dfs) def test_drop_column_single(self): df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( + "age" + ) self.compare_spark_with_sqlglot(df, dfs) @@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator): df_sqlglot_employee_cols = self.df_sqlglot_employee.select( SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") ) - df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + df_sqlglot_store_cols = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("store_name") + ) dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( df_sqlglot_employee_cols.age, ) |