summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:30 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-11-11 08:54:30 +0000
commit9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1 (patch)
tree7ab2f39fbb6fd832aeea5cef45b54bfd59ba5ba5 /tests/dataframe/integration
parentAdding upstream version 9.0.6. (diff)
downloadsqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.tar.xz
sqlglot-9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1.zip
Adding upstream version 10.0.1.upstream/10.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/integration')
-rw-r--r--tests/dataframe/integration/dataframe_validator.py52
-rw-r--r--tests/dataframe/integration/test_dataframe.py295
-rw-r--r--tests/dataframe/integration/test_session.py12
3 files changed, 245 insertions, 114 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index 4a89c78..16f8922 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -1,9 +1,9 @@
-import sys
import typing as t
import unittest
import warnings
import sqlglot
+from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
- SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+ SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+ "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+ cls.spark = (
+ SparkSession.builder.master("local[*]")
+ .appName("Unit-tests")
+ .config(conf=config)
+ .getOrCreate()
+ )
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "employee_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
- cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+ cls.df_employee = cls.spark.createDataFrame(
+ data=employee_data, schema=cls.spark_employee_schema
+ )
+ cls.dfs_employee = cls.sqlglot.createDataFrame(
+ data=employee_data, schema=cls.sqlglot_employee_schema
+ )
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+ cls.dfs_store = cls.sqlglot.createDataFrame(
+ data=store_data, schema=cls.sqlglot_store_schema
+ )
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "district_name", sqlglotSparkTypes.StringType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "manager_name", sqlglotSparkTypes.StringType(), False
+ ),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
- cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
- cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+ cls.df_district = cls.spark.createDataFrame(
+ data=district_data, schema=cls.spark_district_schema
+ )
+ cls.dfs_district = cls.sqlglot.createDataFrame(
+ data=district_data, schema=cls.sqlglot_district_schema
+ )
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index c740bec..19e3b89 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+ self.df_spark_employee["employee_id"],
+ F.col("df_employee.fname"),
+ self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+ self.df_sqlglot_employee["employee_id"],
+ SF.col("dfs_employee.fname"),
+ self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ )
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ )
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
- F.col("age") < F.lit(40), "less than 40"
- )
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ ).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
- SF.col("age") < SF.lit(40), "less than 40"
- )
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ ).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+ self.df_sqlglot_employee["age"] * SF.lit(0.5)
+ == self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="inner"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="inner"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+ df_joined = self.df_spark_employee.select(
+ F.col("store_id"), F.col("fname"), F.col("lname")
+ ).join(
+ self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
- dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+ dfs_joined = self.df_sqlglot_employee.select(
+ SF.col("store_id"), SF.col("fname"), SF.col("lname")
+ ).join(
+ self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+ self.df_spark_store,
+ on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+ how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+ self.df_sqlglot_store,
+ on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+ how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
- self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+ self.df_employee.join(
+ self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+ )
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
- self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+ self.dfs_employee.join(
+ self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+ )
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
- self.df_spark_store, "store_id", "inner"
- )
+ df = self.df_spark_employee.select(
+ F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+ ).join(self.df_spark_store, "store_id", "inner")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
- self.df_sqlglot_store, "store_id", "inner"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+ ).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+ .unionAll(
+ self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+ )
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
- df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+ df = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("fname"), F.col("lname")
+ ).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
- dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+ ).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
- df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+ df = (
+ self.df_spark_store.groupBy(F.col("district_id"))
+ .agg(F.min("num_sales"))
+ .orderBy(F.col("district_id"))
+ )
dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+ self.df_sqlglot_store.groupBy(SF.col("district_id"))
+ .agg(SF.min("num_sales"))
+ .orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+ .orderBy(
+ SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+ .orderBy(
+ SF.when(
+ SF.col("district_id") == SF.lit(1), SF.col("district_id")
+ ).desc_nulls_first()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ to_replace=37, value=100
+ )
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+ dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("age"), SF.lit(37).alias("test_col")
+ ).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
- dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname").alias("first_name")
+ ).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+ dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+ "age"
+ )
self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
- df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+ df_sqlglot_store_cols = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("store_name")
+ )
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
index ff1477b..ec50034 100644
--- a/tests/dataframe/integration/test_session.py
+++ b/tests/dataframe/integration/test_session.py
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
ON
e.store_id = s.store_id
"""
- df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
- dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+ df = (
+ self.spark.sql(query)
+ .groupBy(F.col("store_id"))
+ .agg(F.countDistinct(F.col("employee_id")))
+ )
+ dfs = (
+ self.sqlglot.sql(query)
+ .groupBy(SF.col("store_id"))
+ .agg(SF.countDistinct(SF.col("employee_id")))
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)