diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-23 07:22:20 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-23 07:22:20 +0000 |
commit | 41e67f6ce6b4b732d02e421d6825c18b8d15a59d (patch) | |
tree | 30fb0000d3e6ff11b366567bc35564842e7dbb50 /tests/dataframe/integration | |
parent | Adding upstream version 23.16.0. (diff) | |
download | sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.tar.xz sqlglot-41e67f6ce6b4b732d02e421d6825c18b8d15a59d.zip |
Adding upstream version 24.0.0.upstream/24.0.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/integration')
-rw-r--r-- | tests/dataframe/integration/__init__.py | 0 | ||||
-rw-r--r-- | tests/dataframe/integration/dataframe_validator.py | 174 | ||||
-rw-r--r-- | tests/dataframe/integration/test_dataframe.py | 1281 | ||||
-rw-r--r-- | tests/dataframe/integration/test_grouped_data.py | 71 | ||||
-rw-r--r-- | tests/dataframe/integration/test_session.py | 43 |
5 files changed, 0 insertions, 1569 deletions
diff --git a/tests/dataframe/integration/__init__.py b/tests/dataframe/integration/__init__.py deleted file mode 100644 index e69de29..0000000 --- a/tests/dataframe/integration/__init__.py +++ /dev/null diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py deleted file mode 100644 index 22d4982..0000000 --- a/tests/dataframe/integration/dataframe_validator.py +++ /dev/null @@ -1,174 +0,0 @@ -import typing as t -import unittest -import warnings - -import sqlglot -from tests.helpers import SKIP_INTEGRATION - -if t.TYPE_CHECKING: - from pyspark.sql import DataFrame as SparkDataFrame - - -@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") -class DataFrameValidator(unittest.TestCase): - spark = None - sqlglot = None - df_employee = None - df_store = None - df_district = None - spark_employee_schema = None - sqlglot_employee_schema = None - spark_store_schema = None - sqlglot_store_schema = None - spark_district_schema = None - sqlglot_district_schema = None - - @classmethod - def setUpClass(cls): - from pyspark import SparkConf - from pyspark.sql import SparkSession, types - - from sqlglot.dataframe.sql import types as sqlglotSparkTypes - from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession - - # This is for test `test_branching_root_dataframes` - config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = ( - SparkSession.builder.master("local[*]") - .appName("Unit-tests") - .config(conf=config) - .getOrCreate() - ) - cls.spark.sparkContext.setLogLevel("ERROR") - cls.sqlglot = SqlglotSparkSession() - cls.spark_employee_schema = types.StructType( - [ - types.StructField("employee_id", types.IntegerType(), False), - types.StructField("fname", types.StringType(), False), - types.StructField("lname", types.StringType(), False), - types.StructField("age", types.IntegerType(), False), - types.StructField("store_id", types.IntegerType(), False), - ] - ) - cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( - [ - sqlglotSparkTypes.StructField( - "employee_id", sqlglotSparkTypes.IntegerType(), False - ), - sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), - ] - ) - employee_data = [ - (1, "Jack", "Shephard", 37, 1), - (2, "John", "Locke", 65, 1), - (3, "Kate", "Austen", 37, 2), - (4, "Claire", "Littleton", 27, 2), - (5, "Hugo", "Reyes", 29, 100), - ] - cls.df_employee = cls.spark.createDataFrame( - data=employee_data, schema=cls.spark_employee_schema - ) - cls.dfs_employee = cls.sqlglot.createDataFrame( - data=employee_data, schema=cls.sqlglot_employee_schema - ) - cls.df_employee.createOrReplaceTempView("employee") - - cls.spark_store_schema = types.StructType( - [ - types.StructField("store_id", types.IntegerType(), False), - types.StructField("store_name", types.StringType(), False), - types.StructField("district_id", types.IntegerType(), False), - types.StructField("num_sales", types.IntegerType(), False), - ] - ) - cls.sqlglot_store_schema = sqlglotSparkTypes.StructType( - [ - sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField( - "district_id", sqlglotSparkTypes.IntegerType(), False - ), - sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), - ] - ) - store_data = [ - (1, "Hydra", 1, 37), - (2, "Arrow", 2, 2000), - ] - cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame( - data=store_data, schema=cls.sqlglot_store_schema - ) - cls.df_store.createOrReplaceTempView("store") - - cls.spark_district_schema = types.StructType( - [ - types.StructField("district_id", types.IntegerType(), False), - types.StructField("district_name", types.StringType(), False), - types.StructField("manager_name", types.StringType(), False), - ] - ) - cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( - [ - sqlglotSparkTypes.StructField( - "district_id", sqlglotSparkTypes.IntegerType(), False - ), - sqlglotSparkTypes.StructField( - "district_name", sqlglotSparkTypes.StringType(), False - ), - sqlglotSparkTypes.StructField( - "manager_name", sqlglotSparkTypes.StringType(), False - ), - ] - ) - district_data = [ - (1, "Temple", "Dogen"), - (2, "Lighthouse", "Jacob"), - ] - cls.df_district = cls.spark.createDataFrame( - data=district_data, schema=cls.spark_district_schema - ) - cls.dfs_district = cls.sqlglot.createDataFrame( - data=district_data, schema=cls.sqlglot_district_schema - ) - cls.df_district.createOrReplaceTempView("district") - sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark") - sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark") - sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark") - - def setUp(self) -> None: - warnings.filterwarnings("ignore", category=ResourceWarning) - self.df_spark_store = self.df_store.alias("df_store") # type: ignore - self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore - self.df_spark_district = self.df_district.alias("df_district") # type: ignore - self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore - self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore - self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore - - def compare_spark_with_sqlglot( - self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False - ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]: - def compare_schemas(schema_1, schema_2): - for schema in [schema_1, schema_2]: - for struct_field in schema.fields: - struct_field.metadata = {} - self.assertEqual(schema_1, schema_2) - - for statement in df_sqlglot.sql(): - actual_df_sqlglot = self.spark.sql(statement) # type: ignore - df_sqlglot_results = actual_df_sqlglot.collect() - df_spark_results = df_spark.collect() - if not skip_schema_compare: - compare_schemas(df_spark.schema, actual_df_sqlglot.schema) - self.assertEqual(df_spark_results, df_sqlglot_results) - if no_empty: - self.assertNotEqual(len(df_spark_results), 0) - self.assertNotEqual(len(df_sqlglot_results), 0) - return df_spark, actual_df_sqlglot - - @classmethod - def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str: - return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py deleted file mode 100644 index 702c6ee..0000000 --- a/tests/dataframe/integration/test_dataframe.py +++ /dev/null @@ -1,1281 +0,0 @@ -from pyspark.sql import functions as F - -from sqlglot.dataframe.sql import functions as SF -from tests.dataframe.integration.dataframe_validator import DataFrameValidator - - -class TestDataframeFunc(DataFrameValidator): - def test_simple_select(self): - df_employee = self.df_spark_employee.select(F.col("employee_id")) - dfs_employee = self.df_sqlglot_employee.select(SF.col("employee_id")) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_simple_select_from_table(self): - df = self.df_spark_employee - dfs = self.sqlglot.read.table("employee") - self.compare_spark_with_sqlglot(df, dfs) - - def test_simple_select_df_attribute(self): - df_employee = self.df_spark_employee.select(self.df_spark_employee.employee_id) - dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee.employee_id) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_simple_select_df_dict(self): - df_employee = self.df_spark_employee.select(self.df_spark_employee["employee_id"]) - dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee["employee_id"]) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_multiple_selects(self): - df_employee = self.df_spark_employee.select( - self.df_spark_employee["employee_id"], F.col("fname"), self.df_spark_employee.lname - ) - dfs_employee = self.df_sqlglot_employee.select( - self.df_sqlglot_employee["employee_id"], SF.col("fname"), self.df_sqlglot_employee.lname - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_alias_no_op(self): - df_employee = self.df_spark_employee.alias("df_employee") - dfs_employee = self.df_sqlglot_employee.alias("dfs_employee") - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_alias_with_select(self): - df_employee = self.df_spark_employee.alias("df_employee").select( - self.df_spark_employee["employee_id"], - F.col("df_employee.fname"), - self.df_spark_employee.lname, - ) - dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select( - self.df_sqlglot_employee["employee_id"], - SF.col("dfs_employee.fname"), - self.df_sqlglot_employee.lname, - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_case_when_otherwise(self): - df = self.df_spark_employee.select( - F.when( - (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), - F.lit("between 40 and 60"), - ) - .when(F.col("age") < F.lit(40), "less than 40") - .otherwise("greater than 60") - ) - - dfs = self.df_sqlglot_employee.select( - SF.when( - (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), - SF.lit("between 40 and 60"), - ) - .when(SF.col("age") < SF.lit(40), "less than 40") - .otherwise("greater than 60") - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_case_when_no_otherwise(self): - df = self.df_spark_employee.select( - F.when( - (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), - F.lit("between 40 and 60"), - ).when(F.col("age") < F.lit(40), "less than 40") - ) - - dfs = self.df_sqlglot_employee.select( - SF.when( - (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), - SF.lit("between 40 and 60"), - ).when(SF.col("age") < SF.lit(40), "less than 40") - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_where_clause_single(self): - df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)) - dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_clause_multiple_and(self): - df_employee = self.df_spark_employee.where( - (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")) - ) - dfs_employee = self.df_sqlglot_employee.where( - (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack")) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_many_and(self): - df_employee = self.df_spark_employee.where( - (F.col("age") == F.lit(37)) - & (F.col("fname") == F.lit("Jack")) - & (F.col("lname") == F.lit("Shephard")) - & (F.col("employee_id") == F.lit(1)) - ) - dfs_employee = self.df_sqlglot_employee.where( - (SF.col("age") == SF.lit(37)) - & (SF.col("fname") == SF.lit("Jack")) - & (SF.col("lname") == SF.lit("Shephard")) - & (SF.col("employee_id") == SF.lit(1)) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_clause_multiple_or(self): - df_employee = self.df_spark_employee.where( - (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")) - ) - dfs_employee = self.df_sqlglot_employee.where( - (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate")) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_many_or(self): - df_employee = self.df_spark_employee.where( - (F.col("age") == F.lit(37)) - | (F.col("fname") == F.lit("Kate")) - | (F.col("lname") == F.lit("Littleton")) - | (F.col("employee_id") == F.lit(2)) - ) - dfs_employee = self.df_sqlglot_employee.where( - (SF.col("age") == SF.lit(37)) - | (SF.col("fname") == SF.lit("Kate")) - | (SF.col("lname") == SF.lit("Littleton")) - | (SF.col("employee_id") == SF.lit(2)) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_mixed_and_or(self): - df_employee = self.df_spark_employee.where( - ((F.col("age") == F.lit(65)) & (F.col("fname") == F.lit("John"))) - | ((F.col("lname") == F.lit("Shephard")) & (F.col("age") == F.lit(37))) - ) - dfs_employee = self.df_sqlglot_employee.where( - ((SF.col("age") == SF.lit(65)) & (SF.col("fname") == SF.lit("John"))) - | ((SF.col("lname") == SF.lit("Shephard")) & (SF.col("age") == SF.lit(37))) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_where_multiple_chained(self): - df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)).where( - self.df_spark_employee.fname == F.lit("Jack") - ) - dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)).where( - self.df_sqlglot_employee.fname == SF.lit("Jack") - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_operators(self): - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] < F.lit(50)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] < SF.lit(50)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] <= F.lit(37)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] <= SF.lit(37)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] > F.lit(50)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] > SF.lit(50)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] >= F.lit(37)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] >= SF.lit(37)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] != F.lit(50)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] != SF.lit(50)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] == F.lit(37)) - dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37)) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where( - self.df_spark_employee["age"] % F.lit(5) == F.lit(0) - ) - dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where( - self.df_spark_employee["age"] + F.lit(5) > F.lit(28) - ) - dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where( - self.df_spark_employee["age"] - F.lit(5) > F.lit(28) - ) - dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - df_employee = self.df_spark_employee.where( - self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2) - ) - dfs_employee = self.df_sqlglot_employee.where( - self.df_sqlglot_employee["age"] * SF.lit(0.5) - == self.df_sqlglot_employee["age"] / SF.lit(2) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_join_inner(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, on=["store_id"], how="inner" - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - F.col("store_id"), - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=["store_id"], how="inner" - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - SF.col("store_id"), - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_inner_no_select(self): - df_joined = self.df_spark_employee.select( - F.col("store_id"), F.col("fname"), F.col("lname") - ).join( - self.df_spark_store.select(F.col("store_id"), F.col("store_name")), - on=["store_id"], - how="inner", - ) - dfs_joined = self.df_sqlglot_employee.select( - SF.col("store_id"), SF.col("fname"), SF.col("lname") - ).join( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), - on=["store_id"], - how="inner", - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_inner_equality_single(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, - on=self.df_spark_employee.store_id == self.df_spark_store.store_id, - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - F.lit("literal_value"), - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, - on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - SF.lit("literal_value"), - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_inner_equality_multiple(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, - on=[ - self.df_spark_employee.store_id == self.df_spark_store.store_id, - self.df_spark_employee.age == self.df_spark_store.num_sales, - ], - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, - on=[ - self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, - self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales, - ], - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_inner_equality_multiple_bitwise_and(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, - on=(self.df_spark_store.store_id == self.df_spark_employee.store_id) - & (self.df_spark_store.num_sales == self.df_spark_employee.age), - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, - on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id) - & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age), - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_left_outer(self): - df_joined = ( - self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="left_outer") - .select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - F.col("store_id"), - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - .orderBy(F.col("employee_id")) - ) - dfs_joined = ( - self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="left_outer") - .select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - SF.col("store_id"), - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - .orderBy(SF.col("employee_id")) - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_join_full_outer(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, on=["store_id"], how="full_outer" - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - F.col("store_id"), - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, on=["store_id"], how="full_outer" - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - SF.col("store_id"), - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_triple_join(self): - df = ( - self.df_employee.join( - self.df_store, on=self.df_employee.employee_id == self.df_store.store_id - ) - .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id) - .select( - self.df_employee.employee_id, - self.df_store.store_id, - self.df_district.district_id, - self.df_employee.fname, - self.df_store.store_name, - self.df_district.district_name, - ) - ) - dfs = ( - self.dfs_employee.join( - self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id - ) - .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id) - .select( - self.dfs_employee.employee_id, - self.dfs_store.store_id, - self.dfs_district.district_id, - self.dfs_employee.fname, - self.dfs_store.store_name, - self.dfs_district.district_name, - ) - ) - self.compare_spark_with_sqlglot(df, dfs) - - def test_triple_join_no_select(self): - df = ( - self.df_employee.join( - self.df_store, - on=self.df_employee["employee_id"] == self.df_store["store_id"], - how="left", - ) - .join( - self.df_district, - on=self.df_store["store_id"] == self.df_district["district_id"], - how="left", - ) - .orderBy(F.col("employee_id")) - ) - dfs = ( - self.dfs_employee.join( - self.dfs_store, - on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], - how="left", - ) - .join( - self.dfs_district, - on=self.dfs_store["store_id"] == self.dfs_district["district_id"], - how="left", - ) - .orderBy(SF.col("employee_id")) - ) - self.compare_spark_with_sqlglot(df, dfs) - - def test_triple_joins_filter(self): - df = ( - self.df_employee.join( - self.df_store, - on=self.df_employee["employee_id"] == self.df_store["store_id"], - how="left", - ).join( - self.df_district, - on=self.df_store["store_id"] == self.df_district["district_id"], - how="left", - ) - ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100) - dfs = ( - self.dfs_employee.join( - self.dfs_store, - on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], - how="left", - ).join( - self.dfs_district, - on=self.dfs_store["store_id"] == self.dfs_district["district_id"], - how="left", - ) - ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100) - self.compare_spark_with_sqlglot(df, dfs) - - def test_triple_join_column_name_only(self): - df = ( - self.df_employee.join( - self.df_store, - on=self.df_employee["employee_id"] == self.df_store["store_id"], - how="left", - ) - .join(self.df_district, on="district_id", how="left") - .orderBy(F.col("employee_id")) - ) - dfs = ( - self.dfs_employee.join( - self.dfs_store, - on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], - how="left", - ) - .join(self.dfs_district, on="district_id", how="left") - .orderBy(SF.col("employee_id")) - ) - self.compare_spark_with_sqlglot(df, dfs) - - def test_join_select_and_select_start(self): - df = self.df_spark_employee.select( - F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") - ).join(self.df_spark_store, "store_id", "inner") - - dfs = self.df_sqlglot_employee.select( - SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") - ).join(self.df_sqlglot_store, "store_id", "inner") - - self.compare_spark_with_sqlglot(df, dfs) - - def test_branching_root_dataframes(self): - """ - Test a pattern that has non-intuitive behavior in spark - - Scenario: You do a self-join in a dataframe using an original dataframe and then a modified version - of it. You then reference the columns by the dataframe name instead of the column function. - Spark will use the root dataframe's column in the result. - """ - df_hydra_employees_only = self.df_spark_employee.where(F.col("store_id") == F.lit(1)) - df_joined = ( - self.df_spark_employee.where(F.col("store_id") == F.lit(2)) - .alias("df_arrow_employees_only") - .join( - df_hydra_employees_only.alias("df_hydra_employees_only"), - on=["store_id"], - how="full_outer", - ) - .select( - self.df_spark_employee.fname, - F.col("df_arrow_employees_only.fname"), - df_hydra_employees_only.fname, - F.col("df_hydra_employees_only.fname"), - ) - ) - - dfs_hydra_employees_only = self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(1)) - dfs_joined = ( - self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(2)) - .alias("dfs_arrow_employees_only") - .join( - dfs_hydra_employees_only.alias("dfs_hydra_employees_only"), - on=["store_id"], - how="full_outer", - ) - .select( - self.df_sqlglot_employee.fname, - SF.col("dfs_arrow_employees_only.fname"), - dfs_hydra_employees_only.fname, - SF.col("dfs_hydra_employees_only.fname"), - ) - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_basic_union(self): - df_unioned = self.df_spark_employee.select(F.col("employee_id"), F.col("age")).union( - self.df_spark_store.select(F.col("store_id"), F.col("num_sales")) - ) - - dfs_unioned = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")).union( - self.df_sqlglot_store.select(SF.col("store_id"), SF.col("num_sales")) - ) - self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) - - def test_union_with_join(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store, - on="store_id", - how="inner", - ) - df_unioned = df_joined.select(F.col("store_id"), F.col("store_name")).union( - self.df_spark_district.select(F.col("district_id"), F.col("district_name")) - ) - - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store, - on="store_id", - how="inner", - ) - dfs_unioned = dfs_joined.select(SF.col("store_id"), SF.col("store_name")).union( - self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) - ) - - self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) - - def test_double_union_all(self): - df_unioned = ( - self.df_spark_employee.select(F.col("employee_id"), F.col("fname")) - .unionAll(self.df_spark_store.select(F.col("store_id"), F.col("store_name"))) - .unionAll(self.df_spark_district.select(F.col("district_id"), F.col("district_name"))) - ) - - dfs_unioned = ( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname")) - .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))) - .unionAll( - self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")) - ) - ) - - self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) - - def test_union_by_name(self): - df = self.df_spark_employee.select( - F.col("employee_id"), F.col("fname"), F.col("lname") - ).unionByName( - self.df_spark_store.select( - F.col("store_name").alias("lname"), - F.col("store_id").alias("employee_id"), - F.col("store_name").alias("fname"), - ) - ) - - dfs = self.df_sqlglot_employee.select( - SF.col("employee_id"), SF.col("fname"), SF.col("lname") - ).unionByName( - self.df_sqlglot_store.select( - SF.col("store_name").alias("lname"), - SF.col("store_id").alias("employee_id"), - SF.col("store_name").alias("fname"), - ) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_union_by_name_allow_missing(self): - df = self.df_spark_employee.select( - F.col("age"), F.col("employee_id"), F.col("fname"), F.col("lname") - ).unionByName( - self.df_spark_store.select( - F.col("store_name").alias("lname"), - F.col("store_id").alias("employee_id"), - F.col("store_name").alias("fname"), - F.col("num_sales"), - ), - allowMissingColumns=True, - ) - - dfs = self.df_sqlglot_employee.select( - SF.col("age"), SF.col("employee_id"), SF.col("fname"), SF.col("lname") - ).unionByName( - self.df_sqlglot_store.select( - SF.col("store_name").alias("lname"), - SF.col("store_id").alias("employee_id"), - SF.col("store_name").alias("fname"), - SF.col("num_sales"), - ), - allowMissingColumns=True, - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_order_by_default(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales")) - .orderBy(F.col("district_id")) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales")) - .orderBy(SF.col("district_id")) - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_order_by_array_bool(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.col("total_sales"), F.col("district_id"), ascending=[1, 0]) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=[1, 0]) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_order_by_single_bool(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.col("total_sales"), F.col("district_id"), ascending=False) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=False) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_order_by_column_sort_method(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales").alias("total_sales")) - .orderBy(F.col("total_sales").asc(), F.col("district_id").desc()) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales").alias("total_sales")) - .orderBy(SF.col("total_sales").asc(), SF.col("district_id").desc()) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_order_by_column_sort_method_nulls_last(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales").alias("total_sales")) - .orderBy( - F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last() - ) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales").alias("total_sales")) - .orderBy( - SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last() - ) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_order_by_column_sort_method_nulls_first(self): - df = ( - self.df_spark_store.groupBy(F.col("district_id")) - .agg(F.min("num_sales").alias("total_sales")) - .orderBy( - F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first() - ) - ) - - dfs = ( - self.df_sqlglot_store.groupBy(SF.col("district_id")) - .agg(SF.min("num_sales").alias("total_sales")) - .orderBy( - SF.when( - SF.col("district_id") == SF.lit(1), SF.col("district_id") - ).desc_nulls_first() - ) - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_intersect(self): - df_employee_duplicate = self.df_spark_employee.select( - F.col("employee_id"), F.col("store_id") - ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - - df_store_duplicate = self.df_spark_store.select( - F.col("store_id"), F.col("district_id") - ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) - - df = df_employee_duplicate.intersect(df_store_duplicate) - - dfs_employee_duplicate = self.df_sqlglot_employee.select( - SF.col("employee_id"), SF.col("store_id") - ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - - dfs_store_duplicate = self.df_sqlglot_store.select( - SF.col("store_id"), SF.col("district_id") - ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) - - dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_intersect_all(self): - df_employee_duplicate = self.df_spark_employee.select( - F.col("employee_id"), F.col("store_id") - ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - - df_store_duplicate = self.df_spark_store.select( - F.col("store_id"), F.col("district_id") - ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) - - df = df_employee_duplicate.intersectAll(df_store_duplicate) - - dfs_employee_duplicate = self.df_sqlglot_employee.select( - SF.col("employee_id"), SF.col("store_id") - ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - - dfs_store_duplicate = self.df_sqlglot_store.select( - SF.col("store_id"), SF.col("district_id") - ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) - - dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_except_all(self): - df_employee_duplicate = self.df_spark_employee.select( - F.col("employee_id"), F.col("store_id") - ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))) - - df_store_duplicate = self.df_spark_store.select( - F.col("store_id"), F.col("district_id") - ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id"))) - - df = df_employee_duplicate.exceptAll(df_store_duplicate) - - dfs_employee_duplicate = self.df_sqlglot_employee.select( - SF.col("employee_id"), SF.col("store_id") - ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))) - - dfs_store_duplicate = self.df_sqlglot_store.select( - SF.col("store_id"), SF.col("district_id") - ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))) - - dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_distinct(self): - df = self.df_spark_employee.select(F.col("age")).distinct() - - dfs = self.df_sqlglot_employee.select(SF.col("age")).distinct() - - self.compare_spark_with_sqlglot(df, dfs) - - def test_union_distinct(self): - df_unioned = ( - self.df_spark_employee.select(F.col("employee_id"), F.col("age")) - .union(self.df_spark_employee.select(F.col("employee_id"), F.col("age"))) - .distinct() - ) - - dfs_unioned = ( - self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")) - .union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age"))) - .distinct() - ) - self.compare_spark_with_sqlglot(df_unioned, dfs_unioned) - - def test_drop_duplicates_no_subset(self): - df = self.df_spark_employee.select("age").dropDuplicates() - dfs = self.df_sqlglot_employee.select("age").dropDuplicates() - self.compare_spark_with_sqlglot(df, dfs) - - def test_drop_duplicates_subset(self): - df = self.df_spark_employee.dropDuplicates(["age"]) - dfs = self.df_sqlglot_employee.dropDuplicates(["age"]) - self.compare_spark_with_sqlglot(df, dfs) - - def test_drop_na_default(self): - df = self.df_spark_employee.select( - F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).dropna() - - dfs = self.df_sqlglot_employee.select( - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") - ).dropna() - - self.compare_spark_with_sqlglot(df, dfs) - - def test_dropna_how(self): - df = self.df_spark_employee.select( - F.lit(None), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).dropna(how="all") - - dfs = self.df_sqlglot_employee.select( - SF.lit(None), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") - ).dropna(how="all") - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_dropna_thresh(self): - df = self.df_spark_employee.select( - F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).dropna(how="any", thresh=2) - - dfs = self.df_sqlglot_employee.select( - SF.lit(None), - SF.lit(1), - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), - ).dropna(how="any", thresh=2) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_dropna_subset(self): - df = self.df_spark_employee.select( - F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).dropna(thresh=1, subset="the_age") - - dfs = self.df_sqlglot_employee.select( - SF.lit(None), - SF.lit(1), - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), - ).dropna(thresh=1, subset="the_age") - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_dropna_na_function(self): - df = self.df_spark_employee.select( - F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).na.drop() - - dfs = self.df_sqlglot_employee.select( - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") - ).na.drop() - - self.compare_spark_with_sqlglot(df, dfs) - - def test_fillna_default(self): - df = self.df_spark_employee.select( - F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).fillna(100) - - dfs = self.df_sqlglot_employee.select( - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") - ).fillna(100) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_fillna_dict_replacement(self): - df = self.df_spark_employee.select( - F.col("fname"), - F.when(F.col("lname").startswith("L"), F.col("lname")).alias("l_lname"), - F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age"), - ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) - - dfs = self.df_sqlglot_employee.select( - SF.col("fname"), - SF.when(SF.col("lname").startswith("L"), SF.col("lname")).alias("l_lname"), - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"), - ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"}) - - # For some reason the sqlglot results sets a column as nullable when it doesn't need to - # This seems to be a nuance in how spark dataframe from sql works so we can ignore - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_fillna_na_func(self): - df = self.df_spark_employee.select( - F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age") - ).na.fill(100) - - dfs = self.df_sqlglot_employee.select( - SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age") - ).na.fill(100) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_replace_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( - to_replace=37, value=100 - ) - - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( - to_replace=37, value=100 - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_replace_basic_subset(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( - to_replace=37, value=100, subset="age" - ) - - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( - to_replace=37, value=100, subset="age" - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_replace_mapping(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace( - {37: 100} - ) - - dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace( - {37: 100} - ) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_replace_mapping_subset(self): - df = self.df_spark_employee.select( - F.col("age"), F.lit(37).alias("test_col"), F.lit(50).alias("test_col_2") - ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) - - dfs = self.df_sqlglot_employee.select( - SF.col("age"), SF.lit(37).alias("test_col"), SF.lit(50).alias("test_col_2") - ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"]) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_replace_na_func_basic(self): - df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).na.replace( - to_replace=37, value=100 - ) - - dfs = self.df_sqlglot_employee.select( - SF.col("age"), SF.lit(37).alias("test_col") - ).na.replace(to_replace=37, value=100) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_with_column(self): - df = self.df_spark_employee.withColumn("test", F.col("age")) - - dfs = self.df_sqlglot_employee.withColumn("test", SF.col("age")) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_with_column_existing_name(self): - df = self.df_spark_employee.withColumn("fname", F.lit("blah")) - - dfs = self.df_sqlglot_employee.withColumn("fname", SF.lit("blah")) - - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_with_column_renamed(self): - df = self.df_spark_employee.withColumnRenamed("fname", "first_name") - - dfs = self.df_sqlglot_employee.withColumnRenamed("fname", "first_name") - - self.compare_spark_with_sqlglot(df, dfs) - - def test_with_column_renamed_double(self): - df = self.df_spark_employee.select(F.col("fname").alias("first_name")).withColumnRenamed( - "first_name", "first_name_again" - ) - - dfs = self.df_sqlglot_employee.select( - SF.col("fname").alias("first_name") - ).withColumnRenamed("first_name", "first_name_again") - - self.compare_spark_with_sqlglot(df, dfs) - - def test_drop_column_single(self): - df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age") - - dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop( - "age" - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_drop_column_reference_join(self): - df_spark_employee_cols = self.df_spark_employee.select( - F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") - ) - df_spark_store_cols = self.df_spark_store.select(F.col("store_id"), F.col("store_name")) - df = df_spark_employee_cols.join(df_spark_store_cols, on="store_id", how="inner").drop( - df_spark_employee_cols.age, - ) - - df_sqlglot_employee_cols = self.df_sqlglot_employee.select( - SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id") - ) - df_sqlglot_store_cols = self.df_sqlglot_store.select( - SF.col("store_id"), SF.col("store_name") - ) - dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop( - df_sqlglot_employee_cols.age, - ) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_limit(self): - df = self.df_spark_employee.limit(1) - - dfs = self.df_sqlglot_employee.limit(1) - - self.compare_spark_with_sqlglot(df, dfs) - - def test_hint_broadcast_alias(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store.alias("store").hint("broadcast", "store"), - on=self.df_spark_employee.store_id == self.df_spark_store.store_id, - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store.alias("store").hint("broadcast", "store"), - on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) - - def test_hint_broadcast_no_alias(self): - df_joined = self.df_spark_employee.join( - self.df_spark_store.hint("broadcast"), - on=self.df_spark_employee.store_id == self.df_spark_store.store_id, - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - self.df_sqlglot_store.hint("broadcast"), - on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) - self.assertEqual( - "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1] - ) - - def test_broadcast_func(self): - df_joined = self.df_spark_employee.join( - F.broadcast(self.df_spark_store), - on=self.df_spark_employee.store_id == self.df_spark_store.store_id, - how="inner", - ).select( - self.df_spark_employee.employee_id, - self.df_spark_employee["fname"], - F.col("lname"), - F.col("age"), - self.df_spark_employee.store_id, - self.df_spark_store.store_name, - self.df_spark_store["num_sales"], - ) - dfs_joined = self.df_sqlglot_employee.join( - SF.broadcast(self.df_sqlglot_store), - on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, - how="inner", - ).select( - self.df_sqlglot_employee.employee_id, - self.df_sqlglot_employee["fname"], - SF.col("lname"), - SF.col("age"), - self.df_sqlglot_employee.store_id, - self.df_sqlglot_store.store_name, - self.df_sqlglot_store["num_sales"], - ) - df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df)) - self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs)) - self.assertEqual( - "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1] - ) - - def test_repartition_by_num(self): - """ - The results are different when doing the repartition on a table created using VALUES in SQL. - So I just use the views instead for these tests - """ - df = self.df_spark_employee.repartition(63) - - dfs = self.sqlglot.read.table("employee").repartition(63) - df, dfs = self.compare_spark_with_sqlglot(df, dfs) - spark_num_partitions = df.rdd.getNumPartitions() - sqlglot_num_partitions = dfs.rdd.getNumPartitions() - self.assertEqual(spark_num_partitions, 63) - self.assertEqual(spark_num_partitions, sqlglot_num_partitions) - - def test_repartition_name_only(self): - """ - We use the view here to help ensure the explain plans are similar enough to compare - """ - df = self.df_spark_employee.repartition("age") - - dfs = self.sqlglot.read.table("employee").repartition("age") - df, dfs = self.compare_spark_with_sqlglot(df, dfs) - self.assertIn("RepartitionByExpression [age", self.get_explain_plan(df)) - self.assertIn("RepartitionByExpression [age", self.get_explain_plan(dfs)) - - def test_repartition_num_and_multiple_names(self): - """ - We use the view here to help ensure the explain plans are similar enough to compare - """ - df = self.df_spark_employee.repartition(53, "age", "fname") - - dfs = self.sqlglot.read.table("employee").repartition(53, "age", "fname") - df, dfs = self.compare_spark_with_sqlglot(df, dfs) - spark_num_partitions = df.rdd.getNumPartitions() - sqlglot_num_partitions = dfs.rdd.getNumPartitions() - self.assertEqual(spark_num_partitions, 53) - self.assertEqual(spark_num_partitions, sqlglot_num_partitions) - self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(df)) - self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(dfs)) - - def test_coalesce(self): - df = self.df_spark_employee.coalesce(1) - dfs = self.df_sqlglot_employee.coalesce(1) - df, dfs = self.compare_spark_with_sqlglot(df, dfs) - spark_num_partitions = df.rdd.getNumPartitions() - sqlglot_num_partitions = dfs.rdd.getNumPartitions() - self.assertEqual(spark_num_partitions, 1) - self.assertEqual(spark_num_partitions, sqlglot_num_partitions) - - def test_cache_select(self): - df_employee = ( - self.df_spark_employee.groupBy("store_id") - .agg(F.countDistinct("employee_id").alias("num_employees")) - .cache() - ) - df_joined = df_employee.join(self.df_spark_store, on="store_id").select( - self.df_spark_store.store_id, df_employee.num_employees - ) - dfs_employee = ( - self.df_sqlglot_employee.groupBy("store_id") - .agg(SF.countDistinct("employee_id").alias("num_employees")) - .cache() - ) - dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( - self.df_sqlglot_store.store_id, dfs_employee.num_employees - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) - - def test_persist_select(self): - df_employee = ( - self.df_spark_employee.groupBy("store_id") - .agg(F.countDistinct("employee_id").alias("num_employees")) - .persist() - ) - df_joined = df_employee.join(self.df_spark_store, on="store_id").select( - self.df_spark_store.store_id, df_employee.num_employees - ) - dfs_employee = ( - self.df_sqlglot_employee.groupBy("store_id") - .agg(SF.countDistinct("employee_id").alias("num_employees")) - .persist() - ) - dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select( - self.df_sqlglot_store.store_id, dfs_employee.num_employees - ) - self.compare_spark_with_sqlglot(df_joined, dfs_joined) diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py deleted file mode 100644 index 2768dda..0000000 --- a/tests/dataframe/integration/test_grouped_data.py +++ /dev/null @@ -1,71 +0,0 @@ -from pyspark.sql import functions as F - -from sqlglot.dataframe.sql import functions as SF -from tests.dataframe.integration.dataframe_validator import DataFrameValidator - - -class TestDataframeFunc(DataFrameValidator): - def test_group_by(self): - df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg( - F.min(self.df_spark_employee.employee_id) - ) - dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg( - SF.min(self.df_sqlglot_employee.employee_id) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True) - - def test_group_by_where_non_aggregate(self): - df_employee = ( - self.df_spark_employee.groupBy(self.df_spark_employee.age) - .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) - .where(F.col("age") > F.lit(50)) - ) - dfs_employee = ( - self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) - .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) - .where(SF.col("age") > SF.lit(50)) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_group_by_where_aggregate_like_having(self): - df_employee = ( - self.df_spark_employee.groupBy(self.df_spark_employee.age) - .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id")) - .where(F.col("min_employee_id") > F.lit(1)) - ) - dfs_employee = ( - self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age) - .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id")) - .where(SF.col("min_employee_id") > SF.lit(1)) - ) - self.compare_spark_with_sqlglot(df_employee, dfs_employee) - - def test_count(self): - df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count() - dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count() - self.compare_spark_with_sqlglot(df, dfs) - - def test_mean(self): - df = self.df_spark_employee.groupBy().mean("age", "store_id") - dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id") - self.compare_spark_with_sqlglot(df, dfs) - - def test_avg(self): - df = self.df_spark_employee.groupBy("age").avg("store_id") - dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id") - self.compare_spark_with_sqlglot(df, dfs) - - def test_max(self): - df = self.df_spark_employee.groupBy("age").max("store_id") - dfs = self.df_sqlglot_employee.groupBy("age").max("store_id") - self.compare_spark_with_sqlglot(df, dfs) - - def test_min(self): - df = self.df_spark_employee.groupBy("age").min("store_id") - dfs = self.df_sqlglot_employee.groupBy("age").min("store_id") - self.compare_spark_with_sqlglot(df, dfs) - - def test_sum(self): - df = self.df_spark_employee.groupBy("age").sum("store_id") - dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id") - self.compare_spark_with_sqlglot(df, dfs) diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py deleted file mode 100644 index 3bb3e20..0000000 --- a/tests/dataframe/integration/test_session.py +++ /dev/null @@ -1,43 +0,0 @@ -from pyspark.sql import functions as F - -from sqlglot.dataframe.sql import functions as SF -from tests.dataframe.integration.dataframe_validator import DataFrameValidator - - -class TestSessionFunc(DataFrameValidator): - def test_sql_simple_select(self): - query = "SELECT fname, lname FROM employee" - df = self.spark.sql(query) - dfs = self.sqlglot.sql(query) - self.compare_spark_with_sqlglot(df, dfs) - - def test_sql_with_join(self): - query = """ - SELECT - e.employee_id - , s.store_id - FROM - employee e - INNER JOIN - store s - ON - e.store_id = s.store_id - """ - df = ( - self.spark.sql(query) - .groupBy(F.col("store_id")) - .agg(F.countDistinct(F.col("employee_id"))) - ) - dfs = ( - self.sqlglot.sql(query) - .groupBy(SF.col("store_id")) - .agg(SF.countDistinct(SF.col("employee_id"))) - ) - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) - - def test_nameless_column(self): - query = "SELECT MAX(age) FROM employee" - df = self.spark.sql(query) - dfs = self.sqlglot.sql(query) - # Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ - self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True) |