summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration')
-rw-r--r--tests/dataframe/integration/__init__.py0
-rw-r--r--tests/dataframe/integration/dataframe_validator.py174
-rw-r--r--tests/dataframe/integration/test_dataframe.py1281
-rw-r--r--tests/dataframe/integration/test_grouped_data.py71
-rw-r--r--tests/dataframe/integration/test_session.py43
5 files changed, 0 insertions, 1569 deletions
diff --git a/tests/dataframe/integration/__init__.py b/tests/dataframe/integration/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/tests/dataframe/integration/__init__.py
+++ /dev/null
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
deleted file mode 100644
index 22d4982..0000000
--- a/tests/dataframe/integration/dataframe_validator.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import typing as t
-import unittest
-import warnings
-
-import sqlglot
-from tests.helpers import SKIP_INTEGRATION
-
-if t.TYPE_CHECKING:
- from pyspark.sql import DataFrame as SparkDataFrame
-
-
-@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set")
-class DataFrameValidator(unittest.TestCase):
- spark = None
- sqlglot = None
- df_employee = None
- df_store = None
- df_district = None
- spark_employee_schema = None
- sqlglot_employee_schema = None
- spark_store_schema = None
- sqlglot_store_schema = None
- spark_district_schema = None
- sqlglot_district_schema = None
-
- @classmethod
- def setUpClass(cls):
- from pyspark import SparkConf
- from pyspark.sql import SparkSession, types
-
- from sqlglot.dataframe.sql import types as sqlglotSparkTypes
- from sqlglot.dataframe.sql.session import SparkSession as SqlglotSparkSession
-
- # This is for test `test_branching_root_dataframes`
- config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = (
- SparkSession.builder.master("local[*]")
- .appName("Unit-tests")
- .config(conf=config)
- .getOrCreate()
- )
- cls.spark.sparkContext.setLogLevel("ERROR")
- cls.sqlglot = SqlglotSparkSession()
- cls.spark_employee_schema = types.StructType(
- [
- types.StructField("employee_id", types.IntegerType(), False),
- types.StructField("fname", types.StringType(), False),
- types.StructField("lname", types.StringType(), False),
- types.StructField("age", types.IntegerType(), False),
- types.StructField("store_id", types.IntegerType(), False),
- ]
- )
- cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField(
- "employee_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
- ]
- )
- employee_data = [
- (1, "Jack", "Shephard", 37, 1),
- (2, "John", "Locke", 65, 1),
- (3, "Kate", "Austen", 37, 2),
- (4, "Claire", "Littleton", 27, 2),
- (5, "Hugo", "Reyes", 29, 100),
- ]
- cls.df_employee = cls.spark.createDataFrame(
- data=employee_data, schema=cls.spark_employee_schema
- )
- cls.dfs_employee = cls.sqlglot.createDataFrame(
- data=employee_data, schema=cls.sqlglot_employee_schema
- )
- cls.df_employee.createOrReplaceTempView("employee")
-
- cls.spark_store_schema = types.StructType(
- [
- types.StructField("store_id", types.IntegerType(), False),
- types.StructField("store_name", types.StringType(), False),
- types.StructField("district_id", types.IntegerType(), False),
- types.StructField("num_sales", types.IntegerType(), False),
- ]
- )
- cls.sqlglot_store_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField(
- "district_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
- ]
- )
- store_data = [
- (1, "Hydra", 1, 37),
- (2, "Arrow", 2, 2000),
- ]
- cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(
- data=store_data, schema=cls.sqlglot_store_schema
- )
- cls.df_store.createOrReplaceTempView("store")
-
- cls.spark_district_schema = types.StructType(
- [
- types.StructField("district_id", types.IntegerType(), False),
- types.StructField("district_name", types.StringType(), False),
- types.StructField("manager_name", types.StringType(), False),
- ]
- )
- cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
- [
- sqlglotSparkTypes.StructField(
- "district_id", sqlglotSparkTypes.IntegerType(), False
- ),
- sqlglotSparkTypes.StructField(
- "district_name", sqlglotSparkTypes.StringType(), False
- ),
- sqlglotSparkTypes.StructField(
- "manager_name", sqlglotSparkTypes.StringType(), False
- ),
- ]
- )
- district_data = [
- (1, "Temple", "Dogen"),
- (2, "Lighthouse", "Jacob"),
- ]
- cls.df_district = cls.spark.createDataFrame(
- data=district_data, schema=cls.spark_district_schema
- )
- cls.dfs_district = cls.sqlglot.createDataFrame(
- data=district_data, schema=cls.sqlglot_district_schema
- )
- cls.df_district.createOrReplaceTempView("district")
- sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema, dialect="spark")
- sqlglot.schema.add_table("store", cls.sqlglot_store_schema, dialect="spark")
- sqlglot.schema.add_table("district", cls.sqlglot_district_schema, dialect="spark")
-
- def setUp(self) -> None:
- warnings.filterwarnings("ignore", category=ResourceWarning)
- self.df_spark_store = self.df_store.alias("df_store") # type: ignore
- self.df_spark_employee = self.df_employee.alias("df_employee") # type: ignore
- self.df_spark_district = self.df_district.alias("df_district") # type: ignore
- self.df_sqlglot_store = self.dfs_store.alias("store") # type: ignore
- self.df_sqlglot_employee = self.dfs_employee.alias("employee") # type: ignore
- self.df_sqlglot_district = self.dfs_district.alias("district") # type: ignore
-
- def compare_spark_with_sqlglot(
- self, df_spark, df_sqlglot, no_empty=True, skip_schema_compare=False
- ) -> t.Tuple["SparkDataFrame", "SparkDataFrame"]:
- def compare_schemas(schema_1, schema_2):
- for schema in [schema_1, schema_2]:
- for struct_field in schema.fields:
- struct_field.metadata = {}
- self.assertEqual(schema_1, schema_2)
-
- for statement in df_sqlglot.sql():
- actual_df_sqlglot = self.spark.sql(statement) # type: ignore
- df_sqlglot_results = actual_df_sqlglot.collect()
- df_spark_results = df_spark.collect()
- if not skip_schema_compare:
- compare_schemas(df_spark.schema, actual_df_sqlglot.schema)
- self.assertEqual(df_spark_results, df_sqlglot_results)
- if no_empty:
- self.assertNotEqual(len(df_spark_results), 0)
- self.assertNotEqual(len(df_sqlglot_results), 0)
- return df_spark, actual_df_sqlglot
-
- @classmethod
- def get_explain_plan(cls, df: "SparkDataFrame", mode: str = "extended") -> str:
- return df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), mode) # type: ignore
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
deleted file mode 100644
index 702c6ee..0000000
--- a/tests/dataframe/integration/test_dataframe.py
+++ /dev/null
@@ -1,1281 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestDataframeFunc(DataFrameValidator):
- def test_simple_select(self):
- df_employee = self.df_spark_employee.select(F.col("employee_id"))
- dfs_employee = self.df_sqlglot_employee.select(SF.col("employee_id"))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_simple_select_from_table(self):
- df = self.df_spark_employee
- dfs = self.sqlglot.read.table("employee")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_simple_select_df_attribute(self):
- df_employee = self.df_spark_employee.select(self.df_spark_employee.employee_id)
- dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee.employee_id)
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_simple_select_df_dict(self):
- df_employee = self.df_spark_employee.select(self.df_spark_employee["employee_id"])
- dfs_employee = self.df_sqlglot_employee.select(self.df_sqlglot_employee["employee_id"])
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_multiple_selects(self):
- df_employee = self.df_spark_employee.select(
- self.df_spark_employee["employee_id"], F.col("fname"), self.df_spark_employee.lname
- )
- dfs_employee = self.df_sqlglot_employee.select(
- self.df_sqlglot_employee["employee_id"], SF.col("fname"), self.df_sqlglot_employee.lname
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_alias_no_op(self):
- df_employee = self.df_spark_employee.alias("df_employee")
- dfs_employee = self.df_sqlglot_employee.alias("dfs_employee")
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_alias_with_select(self):
- df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"],
- F.col("df_employee.fname"),
- self.df_spark_employee.lname,
- )
- dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"],
- SF.col("dfs_employee.fname"),
- self.df_sqlglot_employee.lname,
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_case_when_otherwise(self):
- df = self.df_spark_employee.select(
- F.when(
- (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
- F.lit("between 40 and 60"),
- )
- .when(F.col("age") < F.lit(40), "less than 40")
- .otherwise("greater than 60")
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(
- (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
- SF.lit("between 40 and 60"),
- )
- .when(SF.col("age") < SF.lit(40), "less than 40")
- .otherwise("greater than 60")
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_case_when_no_otherwise(self):
- df = self.df_spark_employee.select(
- F.when(
- (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
- F.lit("between 40 and 60"),
- ).when(F.col("age") < F.lit(40), "less than 40")
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(
- (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
- SF.lit("between 40 and 60"),
- ).when(SF.col("age") < SF.lit(40), "less than 40")
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_where_clause_single(self):
- df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_many_and(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37))
- & (F.col("fname") == F.lit("Jack"))
- & (F.col("lname") == F.lit("Shephard"))
- & (F.col("employee_id") == F.lit(1))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37))
- & (SF.col("fname") == SF.lit("Jack"))
- & (SF.col("lname") == SF.lit("Shephard"))
- & (SF.col("employee_id") == SF.lit(1))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_many_or(self):
- df_employee = self.df_spark_employee.where(
- (F.col("age") == F.lit(37))
- | (F.col("fname") == F.lit("Kate"))
- | (F.col("lname") == F.lit("Littleton"))
- | (F.col("employee_id") == F.lit(2))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- (SF.col("age") == SF.lit(37))
- | (SF.col("fname") == SF.lit("Kate"))
- | (SF.col("lname") == SF.lit("Littleton"))
- | (SF.col("employee_id") == SF.lit(2))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_mixed_and_or(self):
- df_employee = self.df_spark_employee.where(
- ((F.col("age") == F.lit(65)) & (F.col("fname") == F.lit("John")))
- | ((F.col("lname") == F.lit("Shephard")) & (F.col("age") == F.lit(37)))
- )
- dfs_employee = self.df_sqlglot_employee.where(
- ((SF.col("age") == SF.lit(65)) & (SF.col("fname") == SF.lit("John")))
- | ((SF.col("lname") == SF.lit("Shephard")) & (SF.col("age") == SF.lit(37)))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_where_multiple_chained(self):
- df_employee = self.df_spark_employee.where(F.col("age") == F.lit(37)).where(
- self.df_spark_employee.fname == F.lit("Jack")
- )
- dfs_employee = self.df_sqlglot_employee.where(SF.col("age") == SF.lit(37)).where(
- self.df_sqlglot_employee.fname == SF.lit("Jack")
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_operators(self):
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] < F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] < SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] <= F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] <= SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] > F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] > SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] >= F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] >= SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] != F.lit(50))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] != SF.lit(50))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] == F.lit(37))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- df_employee = self.df_spark_employee.where(
- self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
- )
- dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5)
- == self.df_sqlglot_employee["age"] / SF.lit(2)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_join_inner(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=["store_id"], how="inner"
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=["store_id"], how="inner"
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(
- F.col("store_id"), F.col("fname"), F.col("lname")
- ).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
- on=["store_id"],
- how="inner",
- )
- dfs_joined = self.df_sqlglot_employee.select(
- SF.col("store_id"), SF.col("fname"), SF.col("lname")
- ).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
- on=["store_id"],
- how="inner",
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_single(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- F.lit("literal_value"),
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- SF.lit("literal_value"),
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_multiple(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=[
- self.df_spark_employee.store_id == self.df_spark_store.store_id,
- self.df_spark_employee.age == self.df_spark_store.num_sales,
- ],
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=[
- self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales,
- ],
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_inner_equality_multiple_bitwise_and(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on=(self.df_spark_store.store_id == self.df_spark_employee.store_id)
- & (self.df_spark_store.num_sales == self.df_spark_employee.age),
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id)
- & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age),
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_left_outer(self):
- df_joined = (
- self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="left_outer")
- .select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- .orderBy(F.col("employee_id"))
- )
- dfs_joined = (
- self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="left_outer")
- .select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=["store_id"], how="full_outer"
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- F.col("store_id"),
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=["store_id"], how="full_outer"
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- SF.col("store_id"),
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_triple_join(self):
- df = (
- self.df_employee.join(
- self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
- )
- .join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
- .select(
- self.df_employee.employee_id,
- self.df_store.store_id,
- self.df_district.district_id,
- self.df_employee.fname,
- self.df_store.store_name,
- self.df_district.district_name,
- )
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
- )
- .join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
- .select(
- self.dfs_employee.employee_id,
- self.dfs_store.store_id,
- self.dfs_district.district_id,
- self.dfs_employee.fname,
- self.dfs_store.store_name,
- self.dfs_district.district_name,
- )
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_join_no_select(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- )
- .join(
- self.df_district,
- on=self.df_store["store_id"] == self.df_district["district_id"],
- how="left",
- )
- .orderBy(F.col("employee_id"))
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- )
- .join(
- self.dfs_district,
- on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
- how="left",
- )
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_joins_filter(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- ).join(
- self.df_district,
- on=self.df_store["store_id"] == self.df_district["district_id"],
- how="left",
- )
- ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100)
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- ).join(
- self.dfs_district,
- on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
- how="left",
- )
- ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100)
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_triple_join_column_name_only(self):
- df = (
- self.df_employee.join(
- self.df_store,
- on=self.df_employee["employee_id"] == self.df_store["store_id"],
- how="left",
- )
- .join(self.df_district, on="district_id", how="left")
- .orderBy(F.col("employee_id"))
- )
- dfs = (
- self.dfs_employee.join(
- self.dfs_store,
- on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
- how="left",
- )
- .join(self.dfs_district, on="district_id", how="left")
- .orderBy(SF.col("employee_id"))
- )
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(
- F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
- ).join(self.df_spark_store, "store_id", "inner")
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
- ).join(self.df_sqlglot_store, "store_id", "inner")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_branching_root_dataframes(self):
- """
- Test a pattern that has non-intuitive behavior in spark
-
- Scenario: You do a self-join in a dataframe using an original dataframe and then a modified version
- of it. You then reference the columns by the dataframe name instead of the column function.
- Spark will use the root dataframe's column in the result.
- """
- df_hydra_employees_only = self.df_spark_employee.where(F.col("store_id") == F.lit(1))
- df_joined = (
- self.df_spark_employee.where(F.col("store_id") == F.lit(2))
- .alias("df_arrow_employees_only")
- .join(
- df_hydra_employees_only.alias("df_hydra_employees_only"),
- on=["store_id"],
- how="full_outer",
- )
- .select(
- self.df_spark_employee.fname,
- F.col("df_arrow_employees_only.fname"),
- df_hydra_employees_only.fname,
- F.col("df_hydra_employees_only.fname"),
- )
- )
-
- dfs_hydra_employees_only = self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(1))
- dfs_joined = (
- self.df_sqlglot_employee.where(SF.col("store_id") == SF.lit(2))
- .alias("dfs_arrow_employees_only")
- .join(
- dfs_hydra_employees_only.alias("dfs_hydra_employees_only"),
- on=["store_id"],
- how="full_outer",
- )
- .select(
- self.df_sqlglot_employee.fname,
- SF.col("dfs_arrow_employees_only.fname"),
- dfs_hydra_employees_only.fname,
- SF.col("dfs_hydra_employees_only.fname"),
- )
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_basic_union(self):
- df_unioned = self.df_spark_employee.select(F.col("employee_id"), F.col("age")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("num_sales"))
- )
-
- dfs_unioned = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("num_sales"))
- )
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_union_with_join(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store,
- on="store_id",
- how="inner",
- )
- df_unioned = df_joined.select(F.col("store_id"), F.col("store_name")).union(
- self.df_spark_district.select(F.col("district_id"), F.col("district_name"))
- )
-
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store,
- on="store_id",
- how="inner",
- )
- dfs_unioned = dfs_joined.select(SF.col("store_id"), SF.col("store_name")).union(
- self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
- )
-
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_double_union_all(self):
- df_unioned = (
- self.df_spark_employee.select(F.col("employee_id"), F.col("fname"))
- .unionAll(self.df_spark_store.select(F.col("store_id"), F.col("store_name")))
- .unionAll(self.df_spark_district.select(F.col("district_id"), F.col("district_name")))
- )
-
- dfs_unioned = (
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
- .unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(
- self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
- )
- )
-
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_union_by_name(self):
- df = self.df_spark_employee.select(
- F.col("employee_id"), F.col("fname"), F.col("lname")
- ).unionByName(
- self.df_spark_store.select(
- F.col("store_name").alias("lname"),
- F.col("store_id").alias("employee_id"),
- F.col("store_name").alias("fname"),
- )
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("fname"), SF.col("lname")
- ).unionByName(
- self.df_sqlglot_store.select(
- SF.col("store_name").alias("lname"),
- SF.col("store_id").alias("employee_id"),
- SF.col("store_name").alias("fname"),
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_union_by_name_allow_missing(self):
- df = self.df_spark_employee.select(
- F.col("age"), F.col("employee_id"), F.col("fname"), F.col("lname")
- ).unionByName(
- self.df_spark_store.select(
- F.col("store_name").alias("lname"),
- F.col("store_id").alias("employee_id"),
- F.col("store_name").alias("fname"),
- F.col("num_sales"),
- ),
- allowMissingColumns=True,
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.col("employee_id"), SF.col("fname"), SF.col("lname")
- ).unionByName(
- self.df_sqlglot_store.select(
- SF.col("store_name").alias("lname"),
- SF.col("store_id").alias("employee_id"),
- SF.col("store_name").alias("fname"),
- SF.col("num_sales"),
- ),
- allowMissingColumns=True,
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_default(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales"))
- .orderBy(F.col("district_id"))
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales"))
- .orderBy(SF.col("district_id"))
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_order_by_array_bool(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales"), F.col("district_id"), ascending=[1, 0])
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=[1, 0])
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_single_bool(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales"), F.col("district_id"), ascending=False)
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales"), SF.col("district_id"), ascending=False)
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.col("total_sales").asc(), F.col("district_id").desc())
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.col("total_sales").asc(), SF.col("district_id").desc())
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method_nulls_last(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(
- F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
- )
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(
- SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_order_by_column_sort_method_nulls_first(self):
- df = (
- self.df_spark_store.groupBy(F.col("district_id"))
- .agg(F.min("num_sales").alias("total_sales"))
- .orderBy(
- F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
- )
- )
-
- dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id"))
- .agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(
- SF.when(
- SF.col("district_id") == SF.lit(1), SF.col("district_id")
- ).desc_nulls_first()
- )
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.intersect(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.intersectAll(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(
- F.col("employee_id"), F.col("store_id")
- ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
-
- df_store_duplicate = self.df_spark_store.select(
- F.col("store_id"), F.col("district_id")
- ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
-
- df = df_employee_duplicate.exceptAll(df_store_duplicate)
-
- dfs_employee_duplicate = self.df_sqlglot_employee.select(
- SF.col("employee_id"), SF.col("store_id")
- ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
-
- dfs_store_duplicate = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("district_id")
- ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
-
- dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_distinct(self):
- df = self.df_spark_employee.select(F.col("age")).distinct()
-
- dfs = self.df_sqlglot_employee.select(SF.col("age")).distinct()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_union_distinct(self):
- df_unioned = (
- self.df_spark_employee.select(F.col("employee_id"), F.col("age"))
- .union(self.df_spark_employee.select(F.col("employee_id"), F.col("age")))
- .distinct()
- )
-
- dfs_unioned = (
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age"))
- .union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("age")))
- .distinct()
- )
- self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
-
- def test_drop_duplicates_no_subset(self):
- df = self.df_spark_employee.select("age").dropDuplicates()
- dfs = self.df_sqlglot_employee.select("age").dropDuplicates()
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_duplicates_subset(self):
- df = self.df_spark_employee.dropDuplicates(["age"])
- dfs = self.df_sqlglot_employee.dropDuplicates(["age"])
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_na_default(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna()
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).dropna()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_dropna_how(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(how="all")
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).dropna(how="all")
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_thresh(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(how="any", thresh=2)
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None),
- SF.lit(1),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).dropna(how="any", thresh=2)
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_subset(self):
- df = self.df_spark_employee.select(
- F.lit(None), F.lit(1), F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).dropna(thresh=1, subset="the_age")
-
- dfs = self.df_sqlglot_employee.select(
- SF.lit(None),
- SF.lit(1),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).dropna(thresh=1, subset="the_age")
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_dropna_na_function(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).na.drop()
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).na.drop()
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_fillna_default(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).fillna(100)
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).fillna(100)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_fillna_dict_replacement(self):
- df = self.df_spark_employee.select(
- F.col("fname"),
- F.when(F.col("lname").startswith("L"), F.col("lname")).alias("l_lname"),
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age"),
- ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"})
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname"),
- SF.when(SF.col("lname").startswith("L"), SF.col("lname")).alias("l_lname"),
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
- ).fillna({"fname": "Jacob", "l_lname": "NOT_LNAME"})
-
- # For some reason the sqlglot results sets a column as nullable when it doesn't need to
- # This seems to be a nuance in how spark dataframe from sql works so we can ignore
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_fillna_na_func(self):
- df = self.df_spark_employee.select(
- F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
- ).na.fill(100)
-
- dfs = self.df_sqlglot_employee.select(
- SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
- ).na.fill(100)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- to_replace=37, value=100
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- to_replace=37, value=100
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_basic_subset(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- to_replace=37, value=100, subset="age"
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- to_replace=37, value=100, subset="age"
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
- {37: 100}
- )
-
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
- {37: 100}
- )
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_mapping_subset(self):
- df = self.df_spark_employee.select(
- F.col("age"), F.lit(37).alias("test_col"), F.lit(50).alias("test_col_2")
- ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"])
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.lit(37).alias("test_col"), SF.lit(50).alias("test_col_2")
- ).replace({37: 100, 50: 1}, subset=["age", "test_col_2"])
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_replace_na_func_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("age"), SF.lit(37).alias("test_col")
- ).na.replace(to_replace=37, value=100)
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_with_column(self):
- df = self.df_spark_employee.withColumn("test", F.col("age"))
-
- dfs = self.df_sqlglot_employee.withColumn("test", SF.col("age"))
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_with_column_existing_name(self):
- df = self.df_spark_employee.withColumn("fname", F.lit("blah"))
-
- dfs = self.df_sqlglot_employee.withColumn("fname", SF.lit("blah"))
-
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_with_column_renamed(self):
- df = self.df_spark_employee.withColumnRenamed("fname", "first_name")
-
- dfs = self.df_sqlglot_employee.withColumnRenamed("fname", "first_name")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_with_column_renamed_double(self):
- df = self.df_spark_employee.select(F.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
-
- dfs = self.df_sqlglot_employee.select(
- SF.col("fname").alias("first_name")
- ).withColumnRenamed("first_name", "first_name_again")
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_column_single(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
-
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
- "age"
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_drop_column_reference_join(self):
- df_spark_employee_cols = self.df_spark_employee.select(
- F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
- )
- df_spark_store_cols = self.df_spark_store.select(F.col("store_id"), F.col("store_name"))
- df = df_spark_employee_cols.join(df_spark_store_cols, on="store_id", how="inner").drop(
- df_spark_employee_cols.age,
- )
-
- df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
- SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
- )
- df_sqlglot_store_cols = self.df_sqlglot_store.select(
- SF.col("store_id"), SF.col("store_name")
- )
- dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
- df_sqlglot_employee_cols.age,
- )
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_limit(self):
- df = self.df_spark_employee.limit(1)
-
- dfs = self.df_sqlglot_employee.limit(1)
-
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_hint_broadcast_alias(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store.alias("store").hint("broadcast", "store"),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store.alias("store").hint("broadcast", "store"),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
-
- def test_hint_broadcast_no_alias(self):
- df_joined = self.df_spark_employee.join(
- self.df_spark_store.hint("broadcast"),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store.hint("broadcast"),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
- self.assertEqual(
- "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
- )
-
- def test_broadcast_func(self):
- df_joined = self.df_spark_employee.join(
- F.broadcast(self.df_spark_store),
- on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
- how="inner",
- ).select(
- self.df_spark_employee.employee_id,
- self.df_spark_employee["fname"],
- F.col("lname"),
- F.col("age"),
- self.df_spark_employee.store_id,
- self.df_spark_store.store_name,
- self.df_spark_store["num_sales"],
- )
- dfs_joined = self.df_sqlglot_employee.join(
- SF.broadcast(self.df_sqlglot_store),
- on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
- how="inner",
- ).select(
- self.df_sqlglot_employee.employee_id,
- self.df_sqlglot_employee["fname"],
- SF.col("lname"),
- SF.col("age"),
- self.df_sqlglot_employee.store_id,
- self.df_sqlglot_store.store_name,
- self.df_sqlglot_store["num_sales"],
- )
- df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
- self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
- self.assertEqual(
- "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
- )
-
- def test_repartition_by_num(self):
- """
- The results are different when doing the repartition on a table created using VALUES in SQL.
- So I just use the views instead for these tests
- """
- df = self.df_spark_employee.repartition(63)
-
- dfs = self.sqlglot.read.table("employee").repartition(63)
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 63)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
-
- def test_repartition_name_only(self):
- """
- We use the view here to help ensure the explain plans are similar enough to compare
- """
- df = self.df_spark_employee.repartition("age")
-
- dfs = self.sqlglot.read.table("employee").repartition("age")
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- self.assertIn("RepartitionByExpression [age", self.get_explain_plan(df))
- self.assertIn("RepartitionByExpression [age", self.get_explain_plan(dfs))
-
- def test_repartition_num_and_multiple_names(self):
- """
- We use the view here to help ensure the explain plans are similar enough to compare
- """
- df = self.df_spark_employee.repartition(53, "age", "fname")
-
- dfs = self.sqlglot.read.table("employee").repartition(53, "age", "fname")
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 53)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
- self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(df))
- self.assertIn("RepartitionByExpression [age#3, fname#1], 53", self.get_explain_plan(dfs))
-
- def test_coalesce(self):
- df = self.df_spark_employee.coalesce(1)
- dfs = self.df_sqlglot_employee.coalesce(1)
- df, dfs = self.compare_spark_with_sqlglot(df, dfs)
- spark_num_partitions = df.rdd.getNumPartitions()
- sqlglot_num_partitions = dfs.rdd.getNumPartitions()
- self.assertEqual(spark_num_partitions, 1)
- self.assertEqual(spark_num_partitions, sqlglot_num_partitions)
-
- def test_cache_select(self):
- df_employee = (
- self.df_spark_employee.groupBy("store_id")
- .agg(F.countDistinct("employee_id").alias("num_employees"))
- .cache()
- )
- df_joined = df_employee.join(self.df_spark_store, on="store_id").select(
- self.df_spark_store.store_id, df_employee.num_employees
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy("store_id")
- .agg(SF.countDistinct("employee_id").alias("num_employees"))
- .cache()
- )
- dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select(
- self.df_sqlglot_store.store_id, dfs_employee.num_employees
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
-
- def test_persist_select(self):
- df_employee = (
- self.df_spark_employee.groupBy("store_id")
- .agg(F.countDistinct("employee_id").alias("num_employees"))
- .persist()
- )
- df_joined = df_employee.join(self.df_spark_store, on="store_id").select(
- self.df_spark_store.store_id, df_employee.num_employees
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy("store_id")
- .agg(SF.countDistinct("employee_id").alias("num_employees"))
- .persist()
- )
- dfs_joined = dfs_employee.join(self.df_sqlglot_store, on="store_id").select(
- self.df_sqlglot_store.store_id, dfs_employee.num_employees
- )
- self.compare_spark_with_sqlglot(df_joined, dfs_joined)
diff --git a/tests/dataframe/integration/test_grouped_data.py b/tests/dataframe/integration/test_grouped_data.py
deleted file mode 100644
index 2768dda..0000000
--- a/tests/dataframe/integration/test_grouped_data.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestDataframeFunc(DataFrameValidator):
- def test_group_by(self):
- df_employee = self.df_spark_employee.groupBy(self.df_spark_employee.age).agg(
- F.min(self.df_spark_employee.employee_id)
- )
- dfs_employee = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).agg(
- SF.min(self.df_sqlglot_employee.employee_id)
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee, skip_schema_compare=True)
-
- def test_group_by_where_non_aggregate(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("age") > F.lit(50))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("age") > SF.lit(50))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_group_by_where_aggregate_like_having(self):
- df_employee = (
- self.df_spark_employee.groupBy(self.df_spark_employee.age)
- .agg(F.min(self.df_spark_employee.employee_id).alias("min_employee_id"))
- .where(F.col("min_employee_id") > F.lit(1))
- )
- dfs_employee = (
- self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age)
- .agg(SF.min(self.df_sqlglot_employee.employee_id).alias("min_employee_id"))
- .where(SF.col("min_employee_id") > SF.lit(1))
- )
- self.compare_spark_with_sqlglot(df_employee, dfs_employee)
-
- def test_count(self):
- df = self.df_spark_employee.groupBy(self.df_spark_employee.age).count()
- dfs = self.df_sqlglot_employee.groupBy(self.df_sqlglot_employee.age).count()
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_mean(self):
- df = self.df_spark_employee.groupBy().mean("age", "store_id")
- dfs = self.df_sqlglot_employee.groupBy().mean("age", "store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_avg(self):
- df = self.df_spark_employee.groupBy("age").avg("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").avg("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_max(self):
- df = self.df_spark_employee.groupBy("age").max("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").max("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_min(self):
- df = self.df_spark_employee.groupBy("age").min("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").min("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_sum(self):
- df = self.df_spark_employee.groupBy("age").sum("store_id")
- dfs = self.df_sqlglot_employee.groupBy("age").sum("store_id")
- self.compare_spark_with_sqlglot(df, dfs)
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
deleted file mode 100644
index 3bb3e20..0000000
--- a/tests/dataframe/integration/test_session.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from pyspark.sql import functions as F
-
-from sqlglot.dataframe.sql import functions as SF
-from tests.dataframe.integration.dataframe_validator import DataFrameValidator
-
-
-class TestSessionFunc(DataFrameValidator):
- def test_sql_simple_select(self):
- query = "SELECT fname, lname FROM employee"
- df = self.spark.sql(query)
- dfs = self.sqlglot.sql(query)
- self.compare_spark_with_sqlglot(df, dfs)
-
- def test_sql_with_join(self):
- query = """
- SELECT
- e.employee_id
- , s.store_id
- FROM
- employee e
- INNER JOIN
- store s
- ON
- e.store_id = s.store_id
- """
- df = (
- self.spark.sql(query)
- .groupBy(F.col("store_id"))
- .agg(F.countDistinct(F.col("employee_id")))
- )
- dfs = (
- self.sqlglot.sql(query)
- .groupBy(SF.col("store_id"))
- .agg(SF.countDistinct(SF.col("employee_id")))
- )
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
-
- def test_nameless_column(self):
- query = "SELECT MAX(age) FROM employee"
- df = self.spark.sql(query)
- dfs = self.sqlglot.sql(query)
- # Spark will alias the column to `max(age)` while sqlglot will alias to `_col_0` so their schemas will differ
- self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)