summaryrefslogtreecommitdiffstats
path: root/tests/dataframe
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe')
-rw-r--r--tests/dataframe/integration/test_dataframe.py85
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py7
-rw-r--r--tests/dataframe/unit/test_functions.py13
3 files changed, 96 insertions, 9 deletions
diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py
index 19e3b89..d00464b 100644
--- a/tests/dataframe/integration/test_dataframe.py
+++ b/tests/dataframe/integration/test_dataframe.py
@@ -276,6 +276,7 @@ class TestDataframeFunc(DataFrameValidator):
self.df_spark_employee.store_id,
self.df_spark_store.store_name,
self.df_spark_store["num_sales"],
+ F.lit("literal_value"),
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store,
@@ -289,6 +290,7 @@ class TestDataframeFunc(DataFrameValidator):
self.df_sqlglot_employee.store_id,
self.df_sqlglot_store.store_name,
self.df_sqlglot_store["num_sales"],
+ SF.lit("literal_value"),
)
self.compare_spark_with_sqlglot(df_joined, dfs_joined)
@@ -330,8 +332,8 @@ class TestDataframeFunc(DataFrameValidator):
def test_join_inner_equality_multiple_bitwise_and(self):
df_joined = self.df_spark_employee.join(
self.df_spark_store,
- on=(self.df_spark_employee.store_id == self.df_spark_store.store_id)
- & (self.df_spark_employee.age == self.df_spark_store.num_sales),
+ on=(self.df_spark_store.store_id == self.df_spark_employee.store_id)
+ & (self.df_spark_store.num_sales == self.df_spark_employee.age),
how="inner",
).select(
self.df_spark_employee.employee_id,
@@ -344,8 +346,8 @@ class TestDataframeFunc(DataFrameValidator):
)
dfs_joined = self.df_sqlglot_employee.join(
self.df_sqlglot_store,
- on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id)
- & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales),
+ on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id)
+ & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age),
how="inner",
).select(
self.df_sqlglot_employee.employee_id,
@@ -443,6 +445,81 @@ class TestDataframeFunc(DataFrameValidator):
)
self.compare_spark_with_sqlglot(df, dfs)
+ def test_triple_join_no_select(self):
+ df = (
+ self.df_employee.join(
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],
+ how="left",
+ )
+ .join(
+ self.df_district,
+ on=self.df_store["store_id"] == self.df_district["district_id"],
+ how="left",
+ )
+ .orderBy(F.col("employee_id"))
+ )
+ dfs = (
+ self.dfs_employee.join(
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ )
+ .join(
+ self.dfs_district,
+ on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
+ how="left",
+ )
+ .orderBy(SF.col("employee_id"))
+ )
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_triple_joins_filter(self):
+ df = (
+ self.df_employee.join(
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],
+ how="left",
+ ).join(
+ self.df_district,
+ on=self.df_store["store_id"] == self.df_district["district_id"],
+ how="left",
+ )
+ ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100)
+ dfs = (
+ self.dfs_employee.join(
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ ).join(
+ self.dfs_district,
+ on=self.dfs_store["store_id"] == self.dfs_district["district_id"],
+ how="left",
+ )
+ ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100)
+ self.compare_spark_with_sqlglot(df, dfs)
+
+ def test_triple_join_column_name_only(self):
+ df = (
+ self.df_employee.join(
+ self.df_store,
+ on=self.df_employee["employee_id"] == self.df_store["store_id"],
+ how="left",
+ )
+ .join(self.df_district, on="district_id", how="left")
+ .orderBy(F.col("employee_id"))
+ )
+ dfs = (
+ self.dfs_employee.join(
+ self.dfs_store,
+ on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"],
+ how="left",
+ )
+ .join(self.dfs_district, on="district_id", how="left")
+ .orderBy(SF.col("employee_id"))
+ )
+ self.compare_spark_with_sqlglot(df, dfs)
+
def test_join_select_and_select_start(self):
df = self.df_spark_employee.select(
F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id")
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 042b915..3f45468 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -86,3 +86,10 @@ class TestDataFrameWriter(DataFrameSQLValidator):
"CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
+
+ def test_quotes(self):
+ sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"})
+ df = self.spark.table('"Test"')
+ self.compare_sql(
+ df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"]
+ )
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py
index d9a32c4..befa68b 100644
--- a/tests/dataframe/unit/test_functions.py
+++ b/tests/dataframe/unit/test_functions.py
@@ -807,14 +807,17 @@ class TestFunctions(unittest.TestCase):
self.assertEqual("DATE_ADD(cola, 2)", col.sql())
col_col_for_add = SF.date_add("cola", "colb")
self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql())
+ current_date_add = SF.date_add(SF.current_date(), 5)
+ self.assertEqual("DATE_ADD(CURRENT_DATE, 5)", current_date_add.sql())
+ self.assertEqual("DATEADD(day, 5, CURRENT_DATE)", current_date_add.sql(dialect="snowflake"))
def test_date_sub(self):
col_str = SF.date_sub("cola", 2)
- self.assertEqual("DATE_SUB(cola, 2)", col_str.sql())
+ self.assertEqual("DATE_ADD(cola, -2)", col_str.sql())
col = SF.date_sub(SF.col("cola"), 2)
- self.assertEqual("DATE_SUB(cola, 2)", col.sql())
+ self.assertEqual("DATE_ADD(cola, -2)", col.sql())
col_col_for_add = SF.date_sub("cola", "colb")
- self.assertEqual("DATE_SUB(cola, colb)", col_col_for_add.sql())
+ self.assertEqual("DATE_ADD(cola, colb * -1)", col_col_for_add.sql())
def test_date_diff(self):
col_str = SF.date_diff("cola", "colb")
@@ -957,9 +960,9 @@ class TestFunctions(unittest.TestCase):
def test_sha1(self):
col_str = SF.sha1("Spark")
- self.assertEqual("SHA1('Spark')", col_str.sql())
+ self.assertEqual("SHA('Spark')", col_str.sql())
col = SF.sha1(SF.col("cola"))
- self.assertEqual("SHA1(cola)", col.sql())
+ self.assertEqual("SHA(cola)", col.sql())
def test_sha2(self):
col_str = SF.sha2("Spark", 256)