diff options
Diffstat (limited to 'tests/dataframe/integration')
-rw-r--r-- | tests/dataframe/integration/test_dataframe.py | 85 |
1 files changed, 81 insertions, 4 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index 19e3b89..d00464b 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -276,6 +276,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_employee.store_id, self.df_spark_store.store_name, self.df_spark_store["num_sales"], + F.lit("literal_value"), ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, @@ -289,6 +290,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_sqlglot_employee.store_id, self.df_sqlglot_store.store_name, self.df_sqlglot_store["num_sales"], + SF.lit("literal_value"), ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) @@ -330,8 +332,8 @@ class TestDataframeFunc(DataFrameValidator): def test_join_inner_equality_multiple_bitwise_and(self): df_joined = self.df_spark_employee.join( self.df_spark_store, - on=(self.df_spark_employee.store_id == self.df_spark_store.store_id) - & (self.df_spark_employee.age == self.df_spark_store.num_sales), + on=(self.df_spark_store.store_id == self.df_spark_employee.store_id) + & (self.df_spark_store.num_sales == self.df_spark_employee.age), how="inner", ).select( self.df_spark_employee.employee_id, @@ -344,8 +346,8 @@ class TestDataframeFunc(DataFrameValidator): ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, - on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id) - & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales), + on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id) + & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age), how="inner", ).select( self.df_sqlglot_employee.employee_id, @@ -443,6 +445,81 @@ class TestDataframeFunc(DataFrameValidator): ) self.compare_spark_with_sqlglot(df, dfs) + def test_triple_join_no_select(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_joins_filter(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ).join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ).join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_join_column_name_only(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join(self.df_district, on="district_id", how="left") + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join(self.dfs_district, on="district_id", how="left") + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + def test_join_select_and_select_start(self): df = self.df_spark_employee.select( F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") |