summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/dataframe_sql_validator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/unit/dataframe_sql_validator.py')
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py18
1 files changed, 3 insertions, 15 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index 2dcdb39..4363b0d 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -1,14 +1,11 @@
-import typing as t
-import unittest
-
from sqlglot.dataframe.sql import types
-from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
-from sqlglot.helper import ensure_list
+from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
-class DataFrameSQLValidator(unittest.TestCase):
+class DataFrameSQLValidator(DataFrameTestBase):
def setUp(self) -> None:
+ super().setUp()
self.spark = SparkSession()
self.employee_schema = types.StructType(
[
@@ -29,12 +26,3 @@ class DataFrameSQLValidator(unittest.TestCase):
self.df_employee = self.spark.createDataFrame(
data=employee_data, schema=self.employee_schema
)
-
- def compare_sql(
- self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
- ):
- actual_sqls = df.sql(pretty=pretty)
- expected_statements = ensure_list(expected_statements)
- self.assertEqual(len(expected_statements), len(actual_sqls))
- for expected, actual in zip(expected_statements, actual_sqls):
- self.assertEqual(expected, actual)