summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration/test_dataframe.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration/test_dataframe.py')
-rw-r--r--tests/dataframe/integration/test_dataframe.py295
1 files changed, 196 insertions, 99 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index c740bec..19e3b89 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -41,22 +41,32 @@ class TestDataframeFunc(DataFrameValidator):
def test_alias_with_select(self):
df_employee = self.df_spark_employee.alias("df_employee").select(
- self.df_spark_employee["employee_id"], F.col("df_employee.fname"), self.df_spark_employee.lname
+ self.df_spark_employee["employee_id"],
+ F.col("df_employee.fname"),
+ self.df_spark_employee.lname,
)
dfs_employee = self.df_sqlglot_employee.alias("dfs_employee").select(
- self.df_sqlglot_employee["employee_id"], SF.col("dfs_employee.fname"), self.df_sqlglot_employee.lname
+ self.df_sqlglot_employee["employee_id"],
+ SF.col("dfs_employee.fname"),
+ self.df_sqlglot_employee.lname,
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_case_when_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60"))
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ )
.when(F.col("age") < F.lit(40), "less than 40")
.otherwise("greater than 60")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60"))
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ )
.when(SF.col("age") < SF.lit(40), "less than 40")
.otherwise("greater than 60")
)
@@ -65,15 +75,17 @@ class TestDataframeFunc(DataFrameValidator):
def test_case_when_no_otherwise(self):
df = self.df_spark_employee.select(
- F.when((F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)), F.lit("between 40 and 60")).when(
- F.col("age") < F.lit(40), "less than 40"
- )
+ F.when(
+ (F.col("age") >= F.lit(40)) & (F.col("age") <= F.lit(60)),
+ F.lit("between 40 and 60"),
+ ).when(F.col("age") < F.lit(40), "less than 40")
)
dfs = self.df_sqlglot_employee.select(
- SF.when((SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)), SF.lit("between 40 and 60")).when(
- SF.col("age") < SF.lit(40), "less than 40"
- )
+ SF.when(
+ (SF.col("age") >= SF.lit(40)) & (SF.col("age") <= SF.lit(60)),
+ SF.lit("between 40 and 60"),
+ ).when(SF.col("age") < SF.lit(40), "less than 40")
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -84,7 +96,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_and(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) & (F.col("fname") == F.lit("Jack"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) & (SF.col("fname") == SF.lit("Jack"))
)
@@ -106,7 +120,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_where_clause_multiple_or(self):
- df_employee = self.df_spark_employee.where((F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate")))
+ df_employee = self.df_spark_employee.where(
+ (F.col("age") == F.lit(37)) | (F.col("fname") == F.lit("Kate"))
+ )
dfs_employee = self.df_sqlglot_employee.where(
(SF.col("age") == SF.lit(37)) | (SF.col("fname") == SF.lit("Kate"))
)
@@ -172,28 +188,43 @@ class TestDataframeFunc(DataFrameValidator):
dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] == SF.lit(37))
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] % F.lit(5) == F.lit(0))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] % F.lit(5) == F.lit(0)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] % SF.lit(5) == SF.lit(0)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] + F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] + F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] + SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
- df_employee = self.df_spark_employee.where(self.df_spark_employee["age"] - F.lit(5) > F.lit(28))
- dfs_employee = self.df_sqlglot_employee.where(self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28))
+ df_employee = self.df_spark_employee.where(
+ self.df_spark_employee["age"] - F.lit(5) > F.lit(28)
+ )
+ dfs_employee = self.df_sqlglot_employee.where(
+ self.df_sqlglot_employee["age"] - SF.lit(5) > SF.lit(28)
+ )
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
df_employee = self.df_spark_employee.where(
self.df_spark_employee["age"] * F.lit(0.5) == self.df_spark_employee["age"] / F.lit(2)
)
dfs_employee = self.df_sqlglot_employee.where(
- self.df_sqlglot_employee["age"] * SF.lit(0.5) == self.df_sqlglot_employee["age"] / SF.lit(2)
+ self.df_sqlglot_employee["age"] * SF.lit(0.5)
+ == self.df_sqlglot_employee["age"] / SF.lit(2)
)
self.compare_spark_with_sqlglot(df_employee, dfs_employee)
def test_join_inner(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="inner").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="inner"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -202,7 +233,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="inner").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="inner"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -214,17 +247,27 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_no_select(self):
- df_joined = self.df_spark_employee.select(F.col("store_id"), F.col("fname"), F.col("lname")).join(
- self.df_spark_store.select(F.col("store_id"), F.col("store_name")), on=["store_id"], how="inner"
+ df_joined = self.df_spark_employee.select(
+ F.col("store_id"), F.col("fname"), F.col("lname")
+ ).join(
+ self.df_spark_store.select(F.col("store_id"), F.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
- dfs_joined = self.df_sqlglot_employee.select(SF.col("store_id"), SF.col("fname"), SF.col("lname")).join(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")), on=["store_id"], how="inner"
+ dfs_joined = self.df_sqlglot_employee.select(
+ SF.col("store_id"), SF.col("fname"), SF.col("lname")
+ ).join(
+ self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")),
+ on=["store_id"],
+ how="inner",
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_inner_equality_single(self):
df_joined = self.df_spark_employee.join(
- self.df_spark_store, on=self.df_spark_employee.store_id == self.df_spark_store.store_id, how="inner"
+ self.df_spark_store,
+ on=self.df_spark_employee.store_id == self.df_spark_store.store_id,
+ how="inner",
).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
@@ -235,7 +278,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store["num_sales"],
)
dfs_joined = self.df_sqlglot_employee.join(
- self.df_sqlglot_store, on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id, how="inner"
+ self.df_sqlglot_store,
+ on=self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id,
+ how="inner",
).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
@@ -343,7 +388,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
def test_join_full_outer(self):
- df_joined = self.df_spark_employee.join(self.df_spark_store, on=["store_id"], how="full_outer").select(
+ df_joined = self.df_spark_employee.join(
+ self.df_spark_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_spark_employee.employee_id,
self.df_spark_employee["fname"],
F.col("lname"),
@@ -352,7 +399,9 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
)
- dfs_joined = self.df_sqlglot_employee.join(self.df_sqlglot_store, on=["store_id"], how="full_outer").select(
+ dfs_joined = self.df_sqlglot_employee.join(
+ self.df_sqlglot_store, on=["store_id"], how="full_outer"
+ ).select(
self.df_sqlglot_employee.employee_id,
self.df_sqlglot_employee["fname"],
SF.col("lname"),
@@ -365,7 +414,9 @@ class TestDataframeFunc(DataFrameValidator):
def test_triple_join(self):
df = (
- self.df_employee.join(self.df_store, on=self.df_employee.employee_id == self.df_store.store_id)
+ self.df_employee.join(
+ self.df_store, on=self.df_employee.employee_id == self.df_store.store_id
+ )
.join(self.df_district, on=self.df_store.store_id == self.df_district.district_id)
.select(
self.df_employee.employee_id,
@@ -377,7 +428,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
dfs = (
- self.dfs_employee.join(self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id)
+ self.dfs_employee.join(
+ self.dfs_store, on=self.dfs_employee.employee_id == self.dfs_store.store_id
+ )
.join(self.dfs_district, on=self.dfs_store.store_id == self.dfs_district.district_id)
.select(
self.dfs_employee.employee_id,
@@ -391,13 +444,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_join_select_and_select_start(self):
- df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")).join(
- self.df_spark_store, "store_id", "inner"
- )
+ df = self.df_spark_employee.select(
+ F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
+ ).join(self.df_spark_store, "store_id", "inner")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")).join(
- self.df_sqlglot_store, "store_id", "inner"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
+ ).join(self.df_sqlglot_store, "store_id", "inner")
self.compare_spark_with_sqlglot(df, dfs)
@@ -485,13 +538,17 @@ class TestDataframeFunc(DataFrameValidator):
dfs_unioned = (
self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"))
.unionAll(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name")))
- .unionAll(self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name")))
+ .unionAll(
+ self.df_sqlglot_district.select(SF.col("district_id"), SF.col("district_name"))
+ )
)
self.compare_spark_with_sqlglot(df_unioned, dfs_unioned)
def test_union_by_name(self):
- df = self.df_spark_employee.select(F.col("employee_id"), F.col("fname"), F.col("lname")).unionByName(
+ df = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("fname"), F.col("lname")
+ ).unionByName(
self.df_spark_store.select(
F.col("store_name").alias("lname"),
F.col("store_id").alias("employee_id"),
@@ -499,7 +556,9 @@ class TestDataframeFunc(DataFrameValidator):
)
)
- dfs = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("fname"), SF.col("lname")).unionByName(
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("fname"), SF.col("lname")
+ ).unionByName(
self.df_sqlglot_store.select(
SF.col("store_name").alias("lname"),
SF.col("store_id").alias("employee_id"),
@@ -537,10 +596,16 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_order_by_default(self):
- df = self.df_spark_store.groupBy(F.col("district_id")).agg(F.min("num_sales")).orderBy(F.col("district_id"))
+ df = (
+ self.df_spark_store.groupBy(F.col("district_id"))
+ .agg(F.min("num_sales"))
+ .orderBy(F.col("district_id"))
+ )
dfs = (
- self.df_sqlglot_store.groupBy(SF.col("district_id")).agg(SF.min("num_sales")).orderBy(SF.col("district_id"))
+ self.df_sqlglot_store.groupBy(SF.col("district_id"))
+ .agg(SF.min("num_sales"))
+ .orderBy(SF.col("district_id"))
)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -594,13 +659,17 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(2), F.col("district_id")).asc_nulls_last()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last())
+ .orderBy(
+ SF.when(SF.col("district_id") == SF.lit(2), SF.col("district_id")).asc_nulls_last()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
@@ -609,81 +678,87 @@ class TestDataframeFunc(DataFrameValidator):
df = (
self.df_spark_store.groupBy(F.col("district_id"))
.agg(F.min("num_sales").alias("total_sales"))
- .orderBy(F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first())
+ .orderBy(
+ F.when(F.col("district_id") == F.lit(1), F.col("district_id")).desc_nulls_first()
+ )
)
dfs = (
self.df_sqlglot_store.groupBy(SF.col("district_id"))
.agg(SF.min("num_sales").alias("total_sales"))
- .orderBy(SF.when(SF.col("district_id") == SF.lit(1), SF.col("district_id")).desc_nulls_first())
+ .orderBy(
+ SF.when(
+ SF.col("district_id") == SF.lit(1), SF.col("district_id")
+ ).desc_nulls_first()
+ )
)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersect(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersect(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_intersect_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.intersectAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.intersectAll(dfs_store_duplicate)
self.compare_spark_with_sqlglot(df, dfs)
def test_except_all(self):
- df_employee_duplicate = self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")).union(
- self.df_spark_employee.select(F.col("employee_id"), F.col("store_id"))
- )
+ df_employee_duplicate = self.df_spark_employee.select(
+ F.col("employee_id"), F.col("store_id")
+ ).union(self.df_spark_employee.select(F.col("employee_id"), F.col("store_id")))
- df_store_duplicate = self.df_spark_store.select(F.col("store_id"), F.col("district_id")).union(
- self.df_spark_store.select(F.col("store_id"), F.col("district_id"))
- )
+ df_store_duplicate = self.df_spark_store.select(
+ F.col("store_id"), F.col("district_id")
+ ).union(self.df_spark_store.select(F.col("store_id"), F.col("district_id")))
df = df_employee_duplicate.exceptAll(df_store_duplicate)
- dfs_employee_duplicate = self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")).union(
- self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id"))
- )
+ dfs_employee_duplicate = self.df_sqlglot_employee.select(
+ SF.col("employee_id"), SF.col("store_id")
+ ).union(self.df_sqlglot_employee.select(SF.col("employee_id"), SF.col("store_id")))
- dfs_store_duplicate = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")).union(
- self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id"))
- )
+ dfs_store_duplicate = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("district_id")
+ ).union(self.df_sqlglot_store.select(SF.col("store_id"), SF.col("district_id")))
dfs = dfs_employee_duplicate.exceptAll(dfs_store_duplicate)
@@ -721,7 +796,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_na_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).dropna()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).dropna()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -746,7 +823,9 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(how="any", thresh=2)
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(how="any", thresh=2)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -757,13 +836,17 @@ class TestDataframeFunc(DataFrameValidator):
).dropna(thresh=1, subset="the_age")
dfs = self.df_sqlglot_employee.select(
- SF.lit(None), SF.lit(1), SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
+ SF.lit(None),
+ SF.lit(1),
+ SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age"),
).dropna(thresh=1, subset="the_age")
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_dropna_na_function(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.drop()
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.drop()
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -772,7 +855,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_fillna_default(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).fillna(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).fillna(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -798,7 +883,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_fillna_na_func(self):
- df = self.df_spark_employee.select(F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")).na.fill(100)
+ df = self.df_spark_employee.select(
+ F.when(F.col("age") < F.lit(50), F.col("age")).alias("the_age")
+ ).na.fill(100)
dfs = self.df_sqlglot_employee.select(
SF.when(SF.col("age") < SF.lit(50), SF.col("age")).alias("the_age")
@@ -807,7 +894,9 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs)
def test_replace_basic(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(to_replace=37, value=100)
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ to_replace=37, value=100
+ )
dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
to_replace=37, value=100
@@ -827,9 +916,13 @@ class TestDataframeFunc(DataFrameValidator):
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
def test_replace_mapping(self):
- df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace({37: 100})
+ df = self.df_spark_employee.select(F.col("age"), F.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace({37: 100})
+ dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).replace(
+ {37: 100}
+ )
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -849,9 +942,9 @@ class TestDataframeFunc(DataFrameValidator):
to_replace=37, value=100
)
- dfs = self.df_sqlglot_employee.select(SF.col("age"), SF.lit(37).alias("test_col")).na.replace(
- to_replace=37, value=100
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("age"), SF.lit(37).alias("test_col")
+ ).na.replace(to_replace=37, value=100)
self.compare_spark_with_sqlglot(df, dfs, skip_schema_compare=True)
@@ -881,16 +974,18 @@ class TestDataframeFunc(DataFrameValidator):
"first_name", "first_name_again"
)
- dfs = self.df_sqlglot_employee.select(SF.col("fname").alias("first_name")).withColumnRenamed(
- "first_name", "first_name_again"
- )
+ dfs = self.df_sqlglot_employee.select(
+ SF.col("fname").alias("first_name")
+ ).withColumnRenamed("first_name", "first_name_again")
self.compare_spark_with_sqlglot(df, dfs)
def test_drop_column_single(self):
df = self.df_spark_employee.select(F.col("fname"), F.col("lname"), F.col("age")).drop("age")
- dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop("age")
+ dfs = self.df_sqlglot_employee.select(SF.col("fname"), SF.col("lname"), SF.col("age")).drop(
+ "age"
+ )
self.compare_spark_with_sqlglot(df, dfs)
@@ -906,7 +1001,9 @@ class TestDataframeFunc(DataFrameValidator):
df_sqlglot_employee_cols = self.df_sqlglot_employee.select(
SF.col("fname"), SF.col("lname"), SF.col("age"), SF.col("store_id")
)
- df_sqlglot_store_cols = self.df_sqlglot_store.select(SF.col("store_id"), SF.col("store_name"))
+ df_sqlglot_store_cols = self.df_sqlglot_store.select(
+ SF.col("store_id"), SF.col("store_name")
+ )
dfs = df_sqlglot_employee_cols.join(df_sqlglot_store_cols, on="store_id", how="inner").drop(
df_sqlglot_employee_cols.age,
)