diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:26 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:26 +0000 |
commit | 8b4272814fb4585be120f183eb7c26bb8acde974 (patch) | |
tree | 85d56a8f5ac4ac94ab924d5bbc578586eeb2a998 /tests/dataframe/unit/dataframe_sql_validator.py | |
parent | Releasing debian version 7.1.3-1. (diff) | |
download | sqlglot-8b4272814fb4585be120f183eb7c26bb8acde974.tar.xz sqlglot-8b4272814fb4585be120f183eb7c26bb8acde974.zip |
Merging upstream version 9.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | tests/dataframe/unit/dataframe_sql_validator.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py new file mode 100644 index 0000000..fc56553 --- /dev/null +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -0,0 +1,35 @@ +import typing as t +import unittest + +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameSQLValidator(unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession() + self.employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + + def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + actual_sqls = df.sql(pretty=pretty) + expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + self.assertEqual(len(expected_statements), len(actual_sqls)) + for expected, actual in zip(expected_statements, actual_sqls): + self.assertEqual(expected, actual) |