summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/integration/dataframe_validator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/integration/dataframe_validator.py')
-rw-r--r--tests/dataframe/integration/dataframe_validator.py52
1 files changed, 39 insertions, 13 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py
index 4a89c78..16f8922 100644
--- a/tests/dataframe/integration/dataframe_validator.py
+++ b/tests/dataframe/integration/dataframe_validator.py
@@ -1,9 +1,9 @@
-import sys
import typing as t
import unittest
import warnings
import sqlglot
+from sqlglot.helper import PYTHON_VERSION
from tests.helpers import SKIP_INTEGRATION
if t.TYPE_CHECKING:
@@ -11,7 +11,8 @@ if t.TYPE_CHECKING:
@unittest.skipIf(
- SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set"
+ SKIP_INTEGRATION or PYTHON_VERSION > (3, 10),
+ "Skipping Integration Tests since `SKIP_INTEGRATION` is set",
)
class DataFrameValidator(unittest.TestCase):
spark = None
@@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase):
# This is for test `test_branching_root_dataframes`
config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")])
- cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate()
+ cls.spark = (
+ SparkSession.builder.master("local[*]")
+ .appName("Unit-tests")
+ .config(conf=config)
+ .getOrCreate()
+ )
cls.spark.sparkContext.setLogLevel("ERROR")
cls.sqlglot = SqlglotSparkSession()
cls.spark_employee_schema = types.StructType(
@@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "employee_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False),
sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False),
@@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase):
(4, "Claire", "Littleton", 27, 2),
(5, "Hugo", "Reyes", 29, 100),
]
- cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema)
- cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema)
+ cls.df_employee = cls.spark.createDataFrame(
+ data=employee_data, schema=cls.spark_employee_schema
+ )
+ cls.dfs_employee = cls.sqlglot.createDataFrame(
+ data=employee_data, schema=cls.sqlglot_employee_schema
+ )
cls.df_employee.createOrReplaceTempView("employee")
cls.spark_store_schema = types.StructType(
@@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase):
[
sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False),
sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False),
]
)
@@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase):
(2, "Arrow", 2, 2000),
]
cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema)
- cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema)
+ cls.dfs_store = cls.sqlglot.createDataFrame(
+ data=store_data, schema=cls.sqlglot_store_schema
+ )
cls.df_store.createOrReplaceTempView("store")
cls.spark_district_schema = types.StructType(
@@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase):
)
cls.sqlglot_district_schema = sqlglotSparkTypes.StructType(
[
- sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False),
- sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False),
- sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False),
+ sqlglotSparkTypes.StructField(
+ "district_id", sqlglotSparkTypes.IntegerType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "district_name", sqlglotSparkTypes.StringType(), False
+ ),
+ sqlglotSparkTypes.StructField(
+ "manager_name", sqlglotSparkTypes.StringType(), False
+ ),
]
)
district_data = [
(1, "Temple", "Dogen"),
(2, "Lighthouse", "Jacob"),
]
- cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema)
- cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema)
+ cls.df_district = cls.spark.createDataFrame(
+ data=district_data, schema=cls.spark_district_schema
+ )
+ cls.dfs_district = cls.sqlglot.createDataFrame(
+ data=district_data, schema=cls.sqlglot_district_schema
+ )
cls.df_district.createOrReplaceTempView("district")
sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema)
sqlglot.schema.add_table("store", cls.sqlglot_store_schema)