summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/dataframe_test_base.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:43 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:43 +0000
commit341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87 (patch)
tree61fb7eca2238fb5d41d3906f4af41de03abd25ea /tests/dataframe/unit/dataframe_test_base.py
parentAdding upstream version 17.12.0. (diff)
downloadsqlglot-upstream/18.2.0.tar.xz
sqlglot-upstream/18.2.0.zip
Adding upstream version 18.2.0.upstream/18.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dataframe/unit/dataframe_test_base.py')
-rw-r--r--tests/dataframe/unit/dataframe_test_base.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/dataframe/unit/dataframe_test_base.py b/tests/dataframe/unit/dataframe_test_base.py
new file mode 100644
index 0000000..6b07df9
--- /dev/null
+++ b/tests/dataframe/unit/dataframe_test_base.py
@@ -0,0 +1,23 @@
+import typing as t
+import unittest
+
+import sqlglot
+from sqlglot import MappingSchema
+from sqlglot.dataframe.sql import SparkSession
+from sqlglot.dataframe.sql.dataframe import DataFrame
+from sqlglot.helper import ensure_list
+
+
+class DataFrameTestBase(unittest.TestCase):
+ def setUp(self) -> None:
+ sqlglot.schema = MappingSchema()
+ SparkSession._instance = None
+
+ def compare_sql(
+ self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
+ ):
+ actual_sqls = df.sql(pretty=pretty)
+ expected_statements = ensure_list(expected_statements)
+ self.assertEqual(len(expected_statements), len(actual_sqls))
+ for expected, actual in zip(expected_statements, actual_sqls):
+ self.assertEqual(expected, actual)