summary refs log tree commit diff stats
path: root/tests/dataframe/unit/dataframe_sql_validator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/unit/dataframe_sql_validator.py')
-rw-r--r-- tests/dataframe/unit/dataframe_sql_validator.py 35
1 files changed, 35 insertions, 0 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
new file mode 100644
index 0000000..fc56553
--- /dev/null
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -0,0 +1,35 @@
+import typing as t
+import unittest
+
+from sqlglot.dataframe.sql import types
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.dataframe.sql.session import SparkSession
+
+
class DataFrameSQLValidator(unittest.TestCase):
    """Base test case for checking the SQL that a DataFrame renders to.

    ``setUp`` prepares a session plus a small employee DataFrame fixture, and
    ``compare_sql`` asserts that a DataFrame produces the expected statements.
    """

    def setUp(self) -> None:
        """Create the session, the employee schema and the employee DataFrame."""
        self.spark = SparkSession()

        # (column name, column type) pairs in schema order. Every field is
        # built with ``False`` as its third positional argument
        # (presumably ``nullable`` -- confirm against ``types.StructField``).
        column_specs = [
            ("employee_id", types.IntegerType()),
            ("fname", types.StringType()),
            ("lname", types.StringType()),
            ("age", types.IntegerType()),
            ("store_id", types.IntegerType()),
        ]
        self.employee_schema = types.StructType(
            [types.StructField(name, data_type, False) for name, data_type in column_specs]
        )

        # One tuple per employee, with values in the same order as the schema columns.
        rows = [
            (1, "Jack", "Shephard", 37, 1),
            (2, "John", "Locke", 65, 1),
            (3, "Kate", "Austen", 37, 2),
            (4, "Claire", "Littleton", 27, 2),
            (5, "Hugo", "Reyes", 29, 100),
        ]
        self.df_employee = self.spark.createDataFrame(data=rows, schema=self.employee_schema)

    def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False):
        """Assert that ``df.sql(pretty=pretty)`` yields exactly the expected statements.

        Args:
            df: The DataFrame whose generated SQL is checked.
            expected_statements: A single expected statement, or a list of
                expected statements in the order they should be produced.
            pretty: Forwarded unchanged to ``df.sql``.

        Raises:
            AssertionError: If the number of statements differs, or if any
                statement differs from the expected one at the same position.
        """
        generated = df.sql(pretty=pretty)

        # A bare string is treated as a one-element list of statements.
        if isinstance(expected_statements, str):
            wanted = [expected_statements]
        else:
            wanted = expected_statements

        # Check the counts first so the pairwise loop cannot silently stop early.
        self.assertEqual(len(wanted), len(generated))
        for expected_sql, actual_sql in zip(wanted, generated):
            self.assertEqual(expected_sql, actual_sql)