diff options
Diffstat (limited to 'tests/dataframe/unit/dataframe_sql_validator.py')
-rw-r--r-- | tests/dataframe/unit/dataframe_sql_validator.py | 18 |
1 files changed, 3 insertions, 15 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index 2dcdb39..4363b0d 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -1,14 +1,11 @@ -import typing as t -import unittest - from sqlglot.dataframe.sql import types -from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.helper import ensure_list +from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase -class DataFrameSQLValidator(unittest.TestCase): +class DataFrameSQLValidator(DataFrameTestBase): def setUp(self) -> None: + super().setUp() self.spark = SparkSession() self.employee_schema = types.StructType( [ @@ -29,12 +26,3 @@ class DataFrameSQLValidator(unittest.TestCase): self.df_employee = self.spark.createDataFrame( data=employee_data, schema=self.employee_schema ) - - def compare_sql( - self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False - ): - actual_sqls = df.sql(pretty=pretty) - expected_statements = ensure_list(expected_statements) - self.assertEqual(len(expected_statements), len(actual_sqls)) - for expected, actual in zip(expected_statements, actual_sqls): - self.assertEqual(expected, actual) |