From 9ebe8c99ba4be74ccebf1b013f4e56ec09e023c1 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:30 +0100 Subject: Adding upstream version 10.0.1. Signed-off-by: Daniel Baumann --- tests/dataframe/integration/dataframe_validator.py | 52 ++++++++++++++++------ 1 file changed, 39 insertions(+), 13 deletions(-) (limited to 'tests/dataframe/integration/dataframe_validator.py') diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 4a89c78..16f8922 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,9 +1,9 @@ -import sys import typing as t import unittest import warnings import sqlglot +from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: @@ -11,7 +11,8 @@ if t.TYPE_CHECKING: @unittest.skipIf( - SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" + SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), + "Skipping Integration Tests since `SKIP_INTEGRATION` is set", ) class DataFrameValidator(unittest.TestCase): spark = None @@ -36,7 +37,12 @@ class DataFrameValidator(unittest.TestCase): # This is for test `test_branching_root_dataframes` config = SparkConf().setAll([("spark.sql.analyzer.failAmbiguousSelfJoin", "false")]) - cls.spark = SparkSession.builder.master("local[*]").appName("Unit-tests").config(conf=config).getOrCreate() + cls.spark = ( + SparkSession.builder.master("local[*]") + .appName("Unit-tests") + .config(conf=config) + .getOrCreate() + ) cls.spark.sparkContext.setLogLevel("ERROR") cls.sqlglot = SqlglotSparkSession() cls.spark_employee_schema = types.StructType( @@ -50,7 +56,9 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_employee_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("employee_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "employee_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("fname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("lname", sqlglotSparkTypes.StringType(), False), sqlglotSparkTypes.StructField("age", sqlglotSparkTypes.IntegerType(), False), @@ -64,8 +72,12 @@ class DataFrameValidator(unittest.TestCase): (4, "Claire", "Littleton", 27, 2), (5, "Hugo", "Reyes", 29, 100), ] - cls.df_employee = cls.spark.createDataFrame(data=employee_data, schema=cls.spark_employee_schema) - cls.dfs_employee = cls.sqlglot.createDataFrame(data=employee_data, schema=cls.sqlglot_employee_schema) + cls.df_employee = cls.spark.createDataFrame( + data=employee_data, schema=cls.spark_employee_schema + ) + cls.dfs_employee = cls.sqlglot.createDataFrame( + data=employee_data, schema=cls.sqlglot_employee_schema + ) cls.df_employee.createOrReplaceTempView("employee") cls.spark_store_schema = types.StructType( @@ -80,7 +92,9 @@ class DataFrameValidator(unittest.TestCase): [ sqlglotSparkTypes.StructField("store_id", sqlglotSparkTypes.IntegerType(), False), sqlglotSparkTypes.StructField("store_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), sqlglotSparkTypes.StructField("num_sales", sqlglotSparkTypes.IntegerType(), False), ] ) @@ -89,7 +103,9 @@ class DataFrameValidator(unittest.TestCase): (2, "Arrow", 2, 2000), ] cls.df_store = cls.spark.createDataFrame(data=store_data, schema=cls.spark_store_schema) - cls.dfs_store = cls.sqlglot.createDataFrame(data=store_data, schema=cls.sqlglot_store_schema) + cls.dfs_store = cls.sqlglot.createDataFrame( + data=store_data, schema=cls.sqlglot_store_schema + ) cls.df_store.createOrReplaceTempView("store") cls.spark_district_schema = types.StructType( @@ -101,17 +117,27 @@ class DataFrameValidator(unittest.TestCase): ) cls.sqlglot_district_schema = sqlglotSparkTypes.StructType( [ - sqlglotSparkTypes.StructField("district_id", sqlglotSparkTypes.IntegerType(), False), - sqlglotSparkTypes.StructField("district_name", sqlglotSparkTypes.StringType(), False), - sqlglotSparkTypes.StructField("manager_name", sqlglotSparkTypes.StringType(), False), + sqlglotSparkTypes.StructField( + "district_id", sqlglotSparkTypes.IntegerType(), False + ), + sqlglotSparkTypes.StructField( + "district_name", sqlglotSparkTypes.StringType(), False + ), + sqlglotSparkTypes.StructField( + "manager_name", sqlglotSparkTypes.StringType(), False + ), ] ) district_data = [ (1, "Temple", "Dogen"), (2, "Lighthouse", "Jacob"), ] - cls.df_district = cls.spark.createDataFrame(data=district_data, schema=cls.spark_district_schema) - cls.dfs_district = cls.sqlglot.createDataFrame(data=district_data, schema=cls.sqlglot_district_schema) + cls.df_district = cls.spark.createDataFrame( + data=district_data, schema=cls.spark_district_schema + ) + cls.dfs_district = cls.sqlglot.createDataFrame( + data=district_data, schema=cls.sqlglot_district_schema + ) cls.df_district.createOrReplaceTempView("district") sqlglot.schema.add_table("employee", cls.sqlglot_employee_schema) sqlglot.schema.add_table("store", cls.sqlglot_store_schema) -- cgit v1.2.3