From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- tests/dataframe/integration/dataframe_validator.py | 52 +++- tests/dataframe/integration/test_dataframe.py | 295 ++++++++++++++------- tests/dataframe/integration/test_session.py | 12 +- 3 files changed, 245 insertions(+), 114 deletions(-) (limited to 'tests/dataframe/integration') diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 4a89c78..16f8922 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,9 +1,9 @@ -import sys import typing as t import unittest import warnings import sqlglot +from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: @@ -11,7 +11,8 @@ if t.TYPE_CHECKING: @unittest.skipIf( - SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" + SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), + "Skipping Integration Tests since `SKIP_INTEGRATION` is set", ) class DataFrameValidator(unittest.TestCase): spark = None @@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase): # This is for test `test_branching_root_dataframes` config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark = ( + SparkSession.builder.master("local[*]") + .appName("Unit-tests") + .config(conf=config) + .getOrCreate() + ) cls.spark.sparkContext.setLogLevel("ERROR") cls.sqlglot = SqlglotSparkSession() cls.spark_employee_schema = types.StructType( @@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "employee_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), @@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) - cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee = cls.spark.createDataFrame( + data=employee_data, schema=cls.spark_employee_schema + ) + cls.dfs_employee = cls.sqlglot.createDataFrame( + data=employee_data, schema=cls.sqlglot_employee_schema + ) cls.df_employee.createOrReplaceTempView("employee") cls.spark_store_schema = types.StructType( @@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase): [ sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), ] ) @@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase): (2, "Arrow", 2, 2000), ] cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame( + data=store_data, schema=cls.sqlglot_store_schema + ) cls.df_store.createOrReplaceTempView("store") cls.spark_district_schema = types.StructType( @@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), + sqlglotSparkTypes.StructField( + "district_name", sqlglotSparkTypes.StringType(), False + ), + sqlglotSparkTypes.StructField( + "manager_name", sqlglotSparkTypes.StringType(), False + ), ] ) district_data = [ (1, "Temple", "Dogen"), (2, "Lighthouse", "Jacob"), ] - cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) - cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district = cls.spark.createDataFrame( + data=district_data, schema=cls.spark_district_schema + ) + cls.dfs_district = cls.sqlglot.createDataFrame( + data=district_data, schema=cls.sqlglot_district_schema + ) cls.df_district.createOrReplaceTempView("district") sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) sqlglot.schema.add_table("store", cls.sqlglot_store_schema) diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index c740bec..19e3b89 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator): def test_alias_with_select(self): df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname + self.df_spark_employee["employee_id"], + F.col("df_employee.fname"), + self.df_spark_employee.lname, ) dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname + self.df_sqlglot_employee["employee_id"], + SF.col("dfs_employee.fname"), + self.df_sqlglot_employee.lname, ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_case_when_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ) .when(F.col("age") < F.lit(40), "less than 40") .otherwise("greater than 60") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ) .when(SF.col("age") < SF.lit(40), "less than 40") .otherwise("greater than 60") ) @@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator): def test_case_when_no_otherwise(self): df = self.df_spark_employee.select( - F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when( - F.col("age") < F.lit(40), "less than 40" - ) + F.when( + (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), + F.lit("between 40 and 60"), + ).when(F.col("age") < F.lit(40), "less than 40") ) dfs = self.df_sqlglot_employee.select( - SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when( - SF.col("age") < SF.lit(40), "less than 40" - ) + SF.when( + (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), + SF.lit("between 40 and 60"), + ).when(SF.col("age") < SF.lit(40), "less than 40") ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) ) @@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))) + df_employee = self.df_spark_employee.where( + (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) + ) dfs_employee = self.df_sqlglot_employee.where( (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) ) @@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator): dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] % F.lit(5) == F.lit(0) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] + F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)) + df_employee = self.df_spark_employee.where( + self.df_spark_employee["age"] - F.lit(5) > F.lit(28) + ) + dfs_employee = self.df_sqlglot_employee.where( + self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) + ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) df_employee = self.df_spark_employee.where( self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) ) dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2) + self.df_sqlglot_employee["age"] * SF.lit(0.5) + == self.df_sqlglot_employee["age"] / SF.lit(2) ) self.compare_spark_with_sqlglot(df_employee, dfs_employee) def test_join_inner(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="inner" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="inner" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner" + df_joined = self.df_spark_employee.select( + F.col("store_id"), F.col("fname"), F.col("lname") + ).join( + self.df_spark_store.select(F.col("store_id"), F.col("store_name")), + on=["store_id"], + how="inner", ) - dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner" + dfs_joined = self.df_sqlglot_employee.select( + SF.col("store_id"), SF.col("fname"), SF.col("lname") + ).join( + self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), + on=["store_id"], + how="inner", ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_inner_equality_single(self): df_joined = self.df_spark_employee.join( - self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner" + self.df_spark_store, + on=self.df_spark_employee.store_id == self.df_spark_store.store_id, + how="inner", ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], @@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store["num_sales"], ) dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner" + self.df_sqlglot_store, + on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, + how="inner", ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], @@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df_joined, dfs_joined) def test_join_full_outer(self): - df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select( + df_joined = self.df_spark_employee.join( + self.df_spark_store, on=["store_id"], how="full_outer" + ).select( self.df_spark_employee.employee_id, self.df_spark_employee["fname"], F.col("lname"), @@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_store.store_name, self.df_spark_store["num_sales"], ) - dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select( + dfs_joined = self.df_sqlglot_employee.join( + self.df_sqlglot_store, on=["store_id"], how="full_outer" + ).select( self.df_sqlglot_employee.employee_id, self.df_sqlglot_employee["fname"], SF.col("lname"), @@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator): def test_triple_join(self): df = ( - self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id) + self.df_employee.join( + self.df_store, on=self.df_employee.employee_id == self.df_store.store_id + ) .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) .select( self.df_employee.employee_id, @@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) dfs = ( - self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id) + self.dfs_employee.join( + self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id + ) .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) .select( self.dfs_employee.employee_id, @@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_join_select_and_select_start(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join( - self.df_spark_store, "store_id", "inner" - ) + df = self.df_spark_employee.select( + F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") + ).join(self.df_spark_store, "store_id", "inner") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join( - self.df_sqlglot_store, "store_id", "inner" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") + ).join(self.df_sqlglot_store, "store_id", "inner") self.compare_spark_with_sqlglot(df, dfs) @@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator): dfs_unioned = ( self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))) + .unionAll( + self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) + ) ) self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) def test_union_by_name(self): - df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName( + df = self.df_spark_employee.select( + F.col("employee_id"), F.col("fname"), F.col("lname") + ).unionByName( self.df_spark_store.select( F.col("store_name").alias("lname"), F.col("store_id").alias("employee_id"), @@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator): ) ) - dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName( + dfs = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("fname"), SF.col("lname") + ).unionByName( self.df_sqlglot_store.select( SF.col("store_name").alias("lname"), SF.col("store_id").alias("employee_id"), @@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_order_by_default(self): - df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id")) + df = ( + self.df_spark_store.groupBy(F.col("district_id")) + .agg(F.min("num_sales")) + .orderBy(F.col("district_id")) + ) dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id")) + self.df_sqlglot_store.groupBy(SF.col("district_id")) + .agg(SF.min("num_sales")) + .orderBy(SF.col("district_id")) ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()) + .orderBy( + F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()) + .orderBy( + SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() + ) ) self.compare_spark_with_sqlglot(df, dfs) @@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator): df = ( self.df_spark_store.groupBy(F.col("district_id")) .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()) + .orderBy( + F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() + ) ) dfs = ( self.df_sqlglot_store.groupBy(SF.col("district_id")) .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first()) + .orderBy( + SF.when( + SF.col("district_id") == SF.lit(1), SF.col("district_id") + ).desc_nulls_first() + ) ) self.compare_spark_with_sqlglot(df, dfs) def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersect(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.intersectAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) self.compare_spark_with_sqlglot(df, dfs) def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union( - self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")) - ) + df_employee_duplicate = self.df_spark_employee.select( + F.col("employee_id"), F.col("store_id") + ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union( - self.df_spark_store.select(F.col("store_id"), F.col("district_id")) - ) + df_store_duplicate = self.df_spark_store.select( + F.col("store_id"), F.col("district_id") + ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) df = df_employee_duplicate.exceptAll(df_store_duplicate) - dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")) - ) + dfs_employee_duplicate = self.df_sqlglot_employee.select( + SF.col("employee_id"), SF.col("store_id") + ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")) - ) + dfs_store_duplicate = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("district_id") + ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) @@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_drop_na_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).dropna() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(how="any", thresh=2) dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(how="any", thresh=2) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator): ).dropna(thresh=1, subset="the_age") dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") + SF.lit(None), + SF.lit(1), + SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), ).dropna(thresh=1, subset="the_age") self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_dropna_na_function(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop() + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.drop() dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_fillna_default(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).fillna(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_fillna_na_func(self): - df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100) + df = self.df_spark_employee.select( + F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") + ).na.fill(100) dfs = self.df_sqlglot_employee.select( SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") @@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs) def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + to_replace=37, value=100 + ) dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( to_replace=37, value=100 @@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator): self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100}) + df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( + {37: 100} + ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100}) + dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( + {37: 100} + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator): to_replace=37, value=100 ) - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) + dfs = self.df_sqlglot_employee.select( + SF.col("age"), SF.lit(37).alias("test_col") + ).na.replace(to_replace=37, value=100) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) @@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator): "first_name", "first_name_again" ) - dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) + dfs = self.df_sqlglot_employee.select( + SF.col("fname").alias("first_name") + ).withColumnRenamed("first_name", "first_name_again") self.compare_spark_with_sqlglot(df, dfs) def test_drop_column_single(self): df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age") + dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( + "age" + ) self.compare_spark_with_sqlglot(df, dfs) @@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator): df_sqlglot_employee_cols = self.df_sqlglot_employee.select( SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") ) - df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")) + df_sqlglot_store_cols = self.df_sqlglot_store.select( + SF.col("store_id"), SF.col("store_name") + ) dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( df_sqlglot_employee_cols.age, ) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py index ff1477b..ec50034 100644 --- a/tests/dataframe/integration/test_session.py +++ b/tests/dataframe/integration/test_session.py @@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator): ON e.store_id = s.store_id """ - df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id"))) - dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id"))) + df = ( + self.spark.sql(query) + .groupBy(F.col("store_id")) + .agg(F.countDistinct(F.col("employee_id"))) + ) + dfs = ( + self.sqlglot.sql(query) + .groupBy(SF.col("store_id")) + .agg(SF.countDistinct(SF.col("employee_id"))) + ) self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) -- cgit v1.2.3