diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:23 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-10-21 09:29:23 +0000 |
commit | dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e (patch) | |
tree | 0d209cfc6f7b9c794c254601c29aa5d8b9414876 /tests/dataframe/unit/dataframe_sql_validator.py | |
parent | Adding upstream version 7.1.3. (diff) | |
download | sqlglot-dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e.tar.xz sqlglot-dab6ba29e8eb9a5c2890ac3be8eab6e994aeb10e.zip |
Adding upstream version 9.0.1.upstream/9.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/unit/dataframe_sql_validator.py')
-rw-r--r-- | tests/dataframe/unit/dataframe_sql_validator.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py new file mode 100644 index 0000000..fc56553 --- /dev/null +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -0,0 +1,35 @@ +import typing as t +import unittest + +from sqlglot.dataframe.sql import types +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.dataframe.sql.session import SparkSession + + +class DataFrameSQLValidator(unittest.TestCase): + def setUp(self) -> None: + self.spark = SparkSession() + self.employee_schema = types.StructType( + [ + types.StructField("employee_id", types.IntegerType(), False), + types.StructField("fname", types.StringType(), False), + types.StructField("lname", types.StringType(), False), + types.StructField("age", types.IntegerType(), False), + types.StructField("store_id", types.IntegerType(), False), + ] + ) + employee_data = [ + (1, "Jack", "Shephard", 37, 1), + (2, "John", "Locke", 65, 1), + (3, "Kate", "Austen", 37, 2), + (4, "Claire", "Littleton", 27, 2), + (5, "Hugo", "Reyes", 29, 100), + ] + self.df_employee = self.spark.createDataFrame(data=employee_data, schema=self.employee_schema) + + def compare_sql(self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False): + actual_sqls = df.sql(pretty=pretty) + expected_statements = [expected_statements] if isinstance(expected_statements, str) else expected_statements + self.assertEqual(len(expected_statements), len(actual_sqls)) + for expected, actual in zip(expected_statements, actual_sqls): + self.assertEqual(expected, actual) |