summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/dataframe_sql_validator.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:33 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:33 +0000
commit579e404567dfff42e64325a8c79f03ac627ea341 (patch)
tree12d101aa5d1b70a69132e5cbd3307741c00d097f /tests/dataframe/unit/dataframe_sql_validator.py
parentAdding upstream version 10.1.3. (diff)
downloadsqlglot-a2bc9c35bd58ff935bf37447572939c713546e14.tar.xz
sqlglot-a2bc9c35bd58ff935bf37447572939c713546e14.zip
Adding upstream version 10.2.6.upstream/10.2.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/unit/dataframe_sql_validator.py')
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index 32ff8f2..2dcdb39 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -4,6 +4,7 @@ import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.helper import ensure_list
class DataFrameSQLValidator(unittest.TestCase):
@@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase):
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = (
- [expected_statements] if isinstance(expected_statements, str) else expected_statements
- )
+ expected_statements = ensure_list(expected_statements)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)