summaryrefslogtreecommitdiffstats
path: root/tests/dataframe
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dataframe/integration/dataframe_validator.py52
-rw-r--r--tests/dataframe/integration/test_dataframe.py295
-rw-r--r--tests/dataframe/integration/test_session.py12
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py12
-rw-r--r--tests/dataframe/unit/test_column.py14
-rw-r--r--tests/dataframe/unit/test_dataframe.py4
-rw-r--r--tests/dataframe/unit/test_functions.py55
-rw-r--r--tests/dataframe/unit/test_session.py11
-rw-r--r--tests/dataframe/unit/test_types.py5
-rw-r--r--tests/dataframe/unit/test_window.py32
10 files changed, 340 insertions, 152 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index 4a89c78..16f8922 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -1,9 +1,9 @@
-import sys
import typing as t
import unittest
import warnings
import sqlglot
+from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
- SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+ SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+ "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+ cls.spark = (
+ SparkSession.builder.master("local[*]")
+ .appName("Unit-tests")
+ .config(conf=config)
+ .getOrCreate()
+ )
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "employee_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
- cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+ cls.df_employee = cls.spark.createDataFrame(
+ data=employee_data, schema=cls.spark_employee_schema
+ )
+ cls.dfs_employee = cls.sqlglot.createDataFrame(
+ data=employee_data, schema=cls.sqlglot_employee_schema
+ )
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+ cls.dfs_store = cls.sqlglot.createDataFrame(
+ data=store_data, schema=cls.sqlglot_store_schema
+ )
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "district_name", sqlglotSparkTypes.StringType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "manager_name", sqlglotSparkTypes.StringType(), False
+ ),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
- cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
- cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+ cls.df_district = cls.spark.createDataFrame(
+ data=district_data, schema=cls.spark_district_schema
+ )
+ cls.dfs_district = cls.sqlglot.createDataFrame(
+ data=district_data, schema=cls.sqlglot_district_schema
+ )
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index c740bec..19e3b89 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+ self.df_spark_employee["employee_id"],
+ F.col("df_employee.fname"),
+ self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+ self.df_sqlglot_employee["employee_id"],
+ SF.col("dfs_employee.fname"),
+ self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ )
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ )
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
- F.col("age") < F.lit(40), "less than 40"
- )
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ ).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
- SF.col("age") < SF.lit(40), "less than 40"
- )
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ ).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+ self.df_sqlglot_employee["age"] * SF.lit(0.5)
+ == self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="inner"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="inner"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+ df_joined = self.df_spark_employee.select(
+ F.col("store_id"), F.col("fname"), F.col("lname")
+ ).join(
+ self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
- dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+ dfs_joined = self.df_sqlglot_employee.select(
+ SF.col("store_id"), SF.col("fname"), SF.col("lname")
+ ).join(
+ self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+ self.df_spark_store,
+ on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+ how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+ self.df_sqlglot_store,
+ on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+ how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
- self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+ self.df_employee.join(
+ self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+ )
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
- self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+ self.dfs_employee.join(
+ self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+ )
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
- self.df_spark_store, "store_id", "inner"
- )
+ df = self.df_spark_employee.select(
+ F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+ ).join(self.df_spark_store, "store_id", "inner")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
- self.df_sqlglot_store, "store_id", "inner"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+ ).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+ .unionAll(
+ self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+ )
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
- df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+ df = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("fname"), F.col("lname")
+ ).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
- dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+ ).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
- df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+ df = (
+ self.df_spark_store.groupBy(F.col("district_id"))
+ .agg(F.min("num_sales"))
+ .orderBy(F.col("district_id"))
+ )
dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+ self.df_sqlglot_store.groupBy(SF.col("district_id"))
+ .agg(SF.min("num_sales"))
+ .orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+ .orderBy(
+ SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+ .orderBy(
+ SF.when(
+ SF.col("district_id") == SF.lit(1), SF.col("district_id")
+ ).desc_nulls_first()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ to_replace=37, value=100
+ )
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+ dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("age"), SF.lit(37).alias("test_col")
+ ).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
- dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname").alias("first_name")
+ ).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+ dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+ "age"
+ )
self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
- df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+ df_sqlglot_store_cols = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("store_name")
+ )
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)
diff --git a/tests/dataframe/integration/test_session.py b/tests/dataframe/integration/test_session.py
index ff1477b..ec50034 100644
--- a/tests/dataframe/integration/test_session.py
+++ b/tests/dataframe/integration/test_session.py
@@ -23,6 +23,14 @@ class TestSessionFunc(DataFrameValidator):
ON
e.store_id = s.store_id
"""
- df = self.spark.sql(query).groupBy(F.col("store_id")).agg(F.countDistinct(F.col("employee_id")))
- dfs = self.sqlglot.sql(query).groupBy(SF.col("store_id")).agg(SF.countDistinct(SF.col("employee_id")))
+ df = (
+ self.spark.sql(query)
+ .groupBy(F.col("store_id"))
+ .agg(F.countDistinct(F.col("employee_id")))
+ )
+ dfs = (
+ self.sqlglot.sql(query)
+ .groupBy(SF.col("store_id"))
+ .agg(SF.countDistinct(SF.col("employee_id")))
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index fc56553..32ff8f2 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -25,11 +25,17 @@ class DataFrameSQLValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema)
+ self.df_employee = self.spark.createDataFrame(
+ data=employee_data, schema=self.employee_schema
+ )
- def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
+ def compare_sql(
+ self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+ ):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ expected_statements = (
+ [expected_statements] if isinstance(expected_statements, str) else expected_statements
+ )
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py
index 977971e..da18502 100644
--- a/tests/dataframe/unit/test_column.py
+++ b/tests/dataframe/unit/test_column.py
@@ -26,12 +26,14 @@ class TestDataframeColumn(unittest.TestCase):
def test_and(self):
self.assertEqual(
- "cola = colb AND colc = cold", ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb AND colc = cold",
+ ((F.col("cola") == F.col("colb")) & (F.col("colc") == F.col("cold"))).sql(),
)
def test_or(self):
self.assertEqual(
- "cola = colb OR colc = cold", ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql()
+ "cola = colb OR colc = cold",
+ ((F.col("cola") == F.col("colb")) | (F.col("colc") == F.col("cold"))).sql(),
)
def test_mod(self):
@@ -112,7 +114,9 @@ class TestDataframeColumn(unittest.TestCase):
def test_when_otherwise(self):
self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.when(F.col("cola") == 1, 2).sql())
- self.assertEqual("CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql())
+ self.assertEqual(
+ "CASE WHEN cola = 1 THEN 2 END", F.col("cola").when(F.col("cola") == 1, 2).sql()
+ )
self.assertEqual(
"CASE WHEN cola = 1 THEN 2 WHEN colb = 2 THEN 3 END",
(F.when(F.col("cola") == 1, 2).when(F.col("colb") == 2, 3)).sql(),
@@ -148,7 +152,9 @@ class TestDataframeColumn(unittest.TestCase):
self.assertEqual(
"cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) "
"AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)",
- F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(),
+ F.col("cola")
+ .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1))
+ .sql(),
)
def test_over(self):
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
index c222cac..e36667b 100644
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -9,7 +9,9 @@ class TestDataframe(DataFrameSQLValidator):
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
def test_columns(self):
- self.assertEqual(["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns)
+ self.assertEqual(
+ ["employee_id", "fname", "lname", "age", "store_id"], self.df_employee.columns
+ )
def test_cache(self):
df = self.df_employee.select("fname").cache()
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index eadbb93..8e5e5cd 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -925,12 +925,17 @@ class TestFunctions(unittest.TestCase):
col = SF.window(SF.col("cola"), "10 minutes")
self.assertEqual("WINDOW(cola, '10 minutes')", col.sql())
col_all_values = SF.window("cola", "2 minutes 30 seconds", "30 seconds", "15 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds', '15 seconds')", col_all_values.sql()
+ )
col_no_start_time = SF.window("cola", "2 minutes 30 seconds", "30 seconds")
- self.assertEqual("WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql())
+ self.assertEqual(
+ "WINDOW(cola, '2 minutes 30 seconds', '30 seconds')", col_no_start_time.sql()
+ )
col_no_slide = SF.window("cola", "2 minutes 30 seconds", startTime="15 seconds")
self.assertEqual(
- "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')", col_no_slide.sql()
+ "WINDOW(cola, '2 minutes 30 seconds', '2 minutes 30 seconds', '15 seconds')",
+ col_no_slide.sql(),
)
def test_session_window(self):
@@ -1359,9 +1364,13 @@ class TestFunctions(unittest.TestCase):
def test_from_json(self):
col_str = SF.from_json("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_json(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_JSON(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_json("cola", "cola INT")
self.assertEqual("FROM_JSON(cola, 'cola INT')", col_no_option.sql())
@@ -1375,7 +1384,9 @@ class TestFunctions(unittest.TestCase):
def test_schema_of_json(self):
col_str = SF.schema_of_json("cola", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.schema_of_json(SF.col("cola"), dict(timestampFormat="dd/MM/yyyy"))
self.assertEqual("SCHEMA_OF_JSON(cola, MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
col_no_option = SF.schema_of_json("cola")
@@ -1429,7 +1440,10 @@ class TestFunctions(unittest.TestCase):
col = SF.array_sort(SF.col("cola"))
self.assertEqual("ARRAY_SORT(cola)", col.sql())
col_comparator = SF.array_sort(
- "cola", lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(SF.length(y) - SF.length(x))
+ "cola",
+ lambda x, y: SF.when(x.isNull() | y.isNull(), SF.lit(0)).otherwise(
+ SF.length(y) - SF.length(x)
+ ),
)
self.assertEqual(
"ARRAY_SORT(cola, (x, y) -> CASE WHEN x IS NULL OR y IS NULL THEN 0 ELSE LENGTH(y) - LENGTH(x) END)",
@@ -1504,9 +1518,13 @@ class TestFunctions(unittest.TestCase):
def test_from_csv(self):
col_str = SF.from_csv("cola", "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col_str.sql()
+ )
col = SF.from_csv(SF.col("cola"), "cola INT", dict(timestampFormat="dd/MM/yyyy"))
- self.assertEqual("FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql())
+ self.assertEqual(
+ "FROM_CSV(cola, 'cola INT', MAP('timestampFormat', 'dd/MM/yyyy'))", col.sql()
+ )
col_no_option = SF.from_csv("cola", "cola INT")
self.assertEqual("FROM_CSV(cola, 'cola INT')", col_no_option.sql())
@@ -1535,7 +1553,9 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("TRANSFORM(cola, (x, i) -> x * i)", col.sql())
col_custom_names = SF.transform("cola", lambda target, row_count: target * row_count)
- self.assertEqual("TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM(cola, (target, row_count) -> target * row_count)", col_custom_names.sql()
+ )
def test_exists(self):
col_str = SF.exists("cola", lambda x: x % 2 == 0)
@@ -1558,10 +1578,13 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("FILTER(cola, x -> MONTH(TO_DATE(x)) > 6)", col_str.sql())
col = SF.filter(SF.col("cola"), lambda x, i: SF.month(SF.to_date(x)) > SF.lit(i))
self.assertEqual("FILTER(cola, (x, i) -> MONTH(TO_DATE(x)) > i)", col.sql())
- col_custom_names = SF.filter("cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count))
+ col_custom_names = SF.filter(
+ "cola", lambda target, row_count: SF.month(SF.to_date(target)) > SF.lit(row_count)
+ )
self.assertEqual(
- "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)", col_custom_names.sql()
+ "FILTER(cola, (target, row_count) -> MONTH(TO_DATE(target)) > row_count)",
+ col_custom_names.sql(),
)
def test_zip_with(self):
@@ -1570,7 +1593,9 @@ class TestFunctions(unittest.TestCase):
col = SF.zip_with(SF.col("cola"), SF.col("colb"), lambda x, y: SF.concat_ws("_", x, y))
self.assertEqual("ZIP_WITH(cola, colb, (x, y) -> CONCAT_WS('_', x, y))", col.sql())
col_custom_names = SF.zip_with("cola", "colb", lambda l, r: SF.concat_ws("_", l, r))
- self.assertEqual("ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql())
+ self.assertEqual(
+ "ZIP_WITH(cola, colb, (l, r) -> CONCAT_WS('_', l, r))", col_custom_names.sql()
+ )
def test_transform_keys(self):
col_str = SF.transform_keys("cola", lambda k, v: SF.upper(k))
@@ -1586,7 +1611,9 @@ class TestFunctions(unittest.TestCase):
col = SF.transform_values(SF.col("cola"), lambda k, v: SF.upper(v))
self.assertEqual("TRANSFORM_VALUES(cola, (k, v) -> UPPER(v))", col.sql())
col_custom_names = SF.transform_values("cola", lambda _, value: SF.upper(value))
- self.assertEqual("TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql())
+ self.assertEqual(
+ "TRANSFORM_VALUES(cola, (_, value) -> UPPER(value))", col_custom_names.sql()
+ )
def test_map_filter(self):
col_str = SF.map_filter("cola", lambda k, v: k > v)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 158dcec..7e8bfad 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -21,9 +21,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = (
- "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
- )
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
@@ -77,7 +75,8 @@ class TestDataframeSession(DataFrameSQLValidator):
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
self.assertIn(
- "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`", df.sql(pretty=False)
+ "SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
+ df.sql(pretty=False),
)
@mock.patch("sqlglot.schema", MappingSchema())
@@ -104,9 +103,7 @@ class TestDataframeSession(DataFrameSQLValidator):
query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
- expected = (
- "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
- )
+ expected = "INSERT INTO new_table SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`"
self.compare_sql(df, expected)
def test_session_create_builder_patterns(self):
diff --git a/tests/dataframe/unit/test_types.py b/tests/dataframe/unit/test_types.py
index 1f6c5dc..52f5d72 100644
--- a/tests/dataframe/unit/test_types.py
+++ b/tests/dataframe/unit/test_types.py
@@ -53,7 +53,10 @@ class TestDataframeTypes(unittest.TestCase):
self.assertEqual("array<int>", types.ArrayType(types.IntegerType()).simpleString())
def test_map(self):
- self.assertEqual("map<int, string>", types.MapType(types.IntegerType(), types.StringType()).simpleString())
+ self.assertEqual(
+ "map<int, string>",
+ types.MapType(types.IntegerType(), types.StringType()).simpleString(),
+ )
def test_struct_field(self):
self.assertEqual("cola:int", types.StructField("cola", types.IntegerType()).simpleString())
diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py
index eea4582..70a868a 100644
--- a/tests/dataframe/unit/test_window.py
+++ b/tests/dataframe/unit/test_window.py
@@ -39,22 +39,38 @@ class TestDataframeWindow(unittest.TestCase):
def test_window_rows_unbounded(self):
rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2)
- self.assertEqual("OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql())
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ rows_between_unbounded_start.sql(),
+ )
rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql())
- rows_between_unbounded_both = Window.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql()
+ "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_end.sql(),
+ )
+ rows_between_unbounded_both = Window.rowsBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ rows_between_unbounded_both.sql(),
)
def test_window_range_unbounded(self):
range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql()
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)",
+ range_between_unbounded_start.sql(),
)
range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing)
- self.assertEqual("OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql())
- range_between_unbounded_both = Window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
self.assertEqual(
- "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql()
+ "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_end.sql(),
+ )
+ range_between_unbounded_both = Window.rangeBetween(
+ Window.unboundedPreceding, Window.unboundedFollowing
+ )
+ self.assertEqual(
+ "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
+ range_between_unbounded_both.sql(),
)