From f73e9af131151f1e058446361c35b05c4c90bf10 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 7 Sep 2023 13:39:48 +0200 Subject: Merging upstream version 18.2.0. Signed-off-by: Daniel Baumann --- tests/dataframe/unit/dataframe_sql_validator.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) (limited to 'tests/dataframe/unit/dataframe_sql_validator.py') diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index 2dcdb39..4363b0d 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -1,14 +1,11 @@ -import typing as t -import unittest - from sqlglot.dataframe.sql import types -from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.helper import ensure_list +from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase -class DataFrameSQLValidator(unittest.TestCase): +class DataFrameSQLValidator(DataFrameTestBase): def setUp(self) -> None: + super().setUp() self.spark = SparkSession() self.employee_schema = types.StructType( [ @@ -29,12 +26,3 @@ class DataFrameSQLValidator(unittest.TestCase): self.df_employee = self.spark.createDataFrame( data=employee_data, schema=self.employee_schema ) - - def compare_sql( - self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False - ): - actual_sqls = df.sql(pretty=pretty) - expected_statements = ensure_list(expected_statements) - self.assertEqual(len(expected_statements), len(actual_sqls)) - for expected, actual in zip(expected_statements, actual_sqls): - self.assertEqual(expected, actual) -- cgit v1.2.3