summary refs log tree commit diff stats
path: root/tests/dataframe/integration/test_dataframe.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration/test_dataframe.py')
-rw-r--r-- tests/dataframe/integration/test_dataframe.py 85
1 files changed, 81 insertions, 4 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index 19e3b89..d00464b 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -276,6 +276,7 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_employee.store_id,
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
+ F.lit("literal_value"),
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store,
@@ -289,6 +290,7 @@ class TestDataframeFunc(DataFrameValidator):
self.df_sqlglot_employee.store_id,
self.df_sqlglot_store.store_name,
self.df_sqlglot_store["num_sales"],
+ SF.lit("literal_value"),
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
@@ -330,8 +332,8 @@ class TestDataframeFunc(DataFrameValidator):
def test_join_inner_equality_multiple_bitwise_and(self):
df_joined = self.df_spark_employee.join(
self.df_spark_store,
- on=(self.df_spark_employee.store_id == self.df_spark_store.store_id)
- & (self.df_spark_employee.age == self.df_spark_store.num_sales),
+ on=(self.df_spark_store.store_id == self.df_spark_employee.store_id)
+ & (self.df_spark_store.num_sales == self.df_spark_employee.age),
how="inner",
).select(
self.df_spark_employee.employee_id,
@@ -344,8 +346,8 @@ class TestDataframeFunc(DataFrameValidator):
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store,
- on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id)
- & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales),
+ on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id)
+ & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age),
how="inner",
).select(
self.df_sqlglot_employee.employee_id,
@@ -443,6 +445,81 @@ class TestDataframeFunc(DataFrameValidator):
)
self.compare_spark_with_sqlglot(df, dfs)
+ def test_triple_join_no_select(self):  # Chain two left joins (employee, store, district) without a select() and compare both results.
+ df = (  # Pipeline built from the self.df_* frames and F functions (presumably the PySpark side; TODO confirm).
+ self.df_employee.join(  # NOTE(review): surrounding tests use self.df_spark_* attributes; confirm self.df_* exists on the validator.
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],  # First join condition pairs employee_id with store_id.
+ how="left",
+ )
+ .join(
+ self.df_district,
+ on=self.df_store["store_id"] == self.df_district["district_id"],  # Second join condition pairs store_id with district_id.
+ how="left",
+ )
+ .orderBy(F.col("employee_id"))  # Sort by employee_id before the comparison.
+ )
+ dfs = (  # Same pipeline built from the self.dfs_* frames and SF functions (presumably the sqlglot side; TODO confirm).
+ self.dfs_employee.join(  # NOTE(review): surrounding tests use self.df_sqlglot_* attributes; confirm self.dfs_* exists on the validator.
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ )
+ .join(
+ self.dfs_district,
+ on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
+ how="left",
+ )
+ .orderBy(SF.col("employee_id"))
+ )
+ self.compare_spark_with_sqlglot(df, dfs)  # Hand both frames to the shared comparison helper.
+
+ def test_triple_joins_filter(self):  # Chain two left joins, then filter on a coalesced store column, and compare both results.
+ df = (  # Pipeline built from the self.df_* frames and F functions (presumably the PySpark side; TODO confirm).
+ self.df_employee.join(
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],  # First join condition pairs employee_id with store_id.
+ how="left",
+ ).join(
+ self.df_district,
+ on=self.df_store["store_id"] == self.df_district["district_id"],  # Second join condition pairs store_id with district_id.
+ how="left",
+ )
+ ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100)  # Keep rows where coalesce(num_sales, 0) > 100; the filter references the store frame's column after both joins.
+ dfs = (  # Same pipeline built from the self.dfs_* frames and SF functions (presumably the sqlglot side; TODO confirm).
+ self.dfs_employee.join(
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ ).join(
+ self.dfs_district,
+ on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
+ how="left",
+ )
+ ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100)
+ self.compare_spark_with_sqlglot(df, dfs)  # NOTE(review): no orderBy here, unlike the sibling triple-join tests; confirm the helper does not depend on row order.
+
+ def test_triple_join_column_name_only(self):  # Chain two left joins where the second join key is given as a plain column-name string.
+ df = (  # Pipeline built from the self.df_* frames and F functions (presumably the PySpark side; TODO confirm).
+ self.df_employee.join(
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],  # First join uses an explicit column-equality condition.
+ how="left",
+ )
+ .join(self.df_district, on="district_id", how="left")  # Second join passes only the column name "district_id" as the key.
+ .orderBy(F.col("employee_id"))  # Sort by employee_id before the comparison.
+ )
+ dfs = (  # Same pipeline built from the self.dfs_* frames and SF functions (presumably the sqlglot side; TODO confirm).
+ self.dfs_employee.join(
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ )
+ .join(self.dfs_district, on="district_id", how="left")
+ .orderBy(SF.col("employee_id"))
+ )
+ self.compare_spark_with_sqlglot(df, dfs)  # Hand both frames to the shared comparison helper.
+
def test_join_select_and_select_start(self):
df = self.df_spark_employee.select(
F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")