diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dataframe/integration/test_dataframe.py | 85 | ||||
-rw-r--r-- | tests/dataframe/unit/test_dataframe_writer.py | 7 | ||||
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 13 |
3 files changed, 96 insertions, 9 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index 19e3b89..d00464b 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -276,6 +276,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_employee.store_id, self.df_spark_store.store_name, self.df_spark_store["num_sales"], + F.lit("literal_value"), ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, @@ -289,6 +290,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_sqlglot_employee.store_id, self.df_sqlglot_store.store_name, self.df_sqlglot_store["num_sales"], + SF.lit("literal_value"), ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) @@ -330,8 +332,8 @@ class TestDataframeFunc(DataFrameValidator): def test_join_inner_equality_multiple_bitwise_and(self): df_joined = self.df_spark_employee.join( self.df_spark_store, - on=(self.df_spark_employee.store_id == self.df_spark_store.store_id) - & (self.df_spark_employee.age == self.df_spark_store.num_sales), + on=(self.df_spark_store.store_id == self.df_spark_employee.store_id) + & (self.df_spark_store.num_sales == self.df_spark_employee.age), how="inner", ).select( self.df_spark_employee.employee_id, @@ -344,8 +346,8 @@ class TestDataframeFunc(DataFrameValidator): ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, - on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id) - & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales), + on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id) + & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age), how="inner", ).select( self.df_sqlglot_employee.employee_id, @@ -443,6 +445,81 @@ class TestDataframeFunc(DataFrameValidator): ) self.compare_spark_with_sqlglot(df, dfs) + def test_triple_join_no_select(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_joins_filter(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ).join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ).join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_join_column_name_only(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join(self.df_district, on="district_id", how="left") + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join(self.dfs_district, on="district_id", how="left") + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + def test_join_select_and_select_start(self): df = self.df_spark_employee.select( F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 042b915..3f45468 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -86,3 +86,10 @@ class TestDataFrameWriter(DataFrameSQLValidator): "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) + + def test_quotes(self): + sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"}) + df = self.spark.table('"Test"') + self.compare_sql( + df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"] + ) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index d9a32c4..befa68b 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -807,14 +807,17 @@ class TestFunctions(unittest.TestCase): self.assertEqual("DATE_ADD(cola, 2)", col.sql()) col_col_for_add = SF.date_add("cola", "colb") self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql()) + current_date_add = SF.date_add(SF.current_date(), 5) + self.assertEqual("DATE_ADD(CURRENT_DATE, 5)", current_date_add.sql()) + self.assertEqual("DATEADD(day, 5, CURRENT_DATE)", current_date_add.sql(dialect="snowflake")) def test_date_sub(self): col_str = SF.date_sub("cola", 2) - self.assertEqual("DATE_SUB(cola, 2)", col_str.sql()) + self.assertEqual("DATE_ADD(cola, -2)", col_str.sql()) col = SF.date_sub(SF.col("cola"), 2) - self.assertEqual("DATE_SUB(cola, 2)", col.sql()) + self.assertEqual("DATE_ADD(cola, -2)", col.sql()) col_col_for_add = SF.date_sub("cola", "colb") - self.assertEqual("DATE_SUB(cola, colb)", col_col_for_add.sql()) + self.assertEqual("DATE_ADD(cola, colb * -1)", col_col_for_add.sql()) def test_date_diff(self): col_str = SF.date_diff("cola", "colb") @@ -957,9 +960,9 @@ class TestFunctions(unittest.TestCase): def test_sha1(self): col_str = SF.sha1("Spark") - self.assertEqual("SHA1('Spark')", col_str.sql()) + self.assertEqual("SHA('Spark')", col_str.sql()) col = SF.sha1(SF.col("cola")) - self.assertEqual("SHA1(cola)", col.sql()) + self.assertEqual("SHA(cola)", col.sql()) def test_sha2(self): col_str = SF.sha2("Spark", 256) |