diff options
Diffstat (limited to 'tests/dataframe/unit/test_session_case_sensitivity.py')
-rw-r--r-- | tests/dataframe/unit/test_session_case_sensitivity.py | 81 |
1 files changed, 81 insertions, 0 deletions
# Reconstructed from a collapsed diff that adds
# tests/dataframe/unit/test_session_case_sensitivity.py as a new file
# (mode 100644, index 0000000..7e35289).
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.errors import OptimizeError
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase


class TestSessionCaseSensitivity(DataFrameTestBase):
    """Check identifier case handling for a session configured with "snowflake".

    Each scenario registers a table schema, selects one column from that table
    through the DataFrame API, and then either compares the generated SQL with
    an expected string or expects ``OptimizeError`` from ``df.sql()``.

    In the expected SQL below, names written without double quotes appear
    upper-cased, while names written inside double quotes keep their spelling.
    """

    # Scenario table. Every entry is a 6-tuple, in this order:
    #   1. label passed to ``subTest``
    #   2. table name registered through ``sqlglot.schema.add_table``
    #   3. table name handed to ``self.spark.table``
    #   4. column-name -> type mapping registered for the table
    #   5. column name handed to ``F.col``
    #   6. expected SQL string, or an ``OptimizeError`` instance when generating
    #      the SQL is expected to fail
    tests = [
        (
            "All lower no intention of CS",
            "test",
            "test",
            {"name": "VARCHAR"},
            "name",
            'SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"',
        ),
        (
            "Table has CS while column does not",
            '"Test"',
            '"Test"',
            {"name": "VARCHAR"},
            "name",
            'SELECT "TEST"."NAME" AS "NAME" FROM "Test" AS "TEST"',
        ),
        (
            "Column has CS while table does not",
            "test",
            "test",
            {'"Name"': "VARCHAR"},
            '"Name"',
            'SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"',
        ),
        (
            "Both Table and column have CS",
            '"Test"',
            '"Test"',
            {'"Name"': "VARCHAR"},
            '"Name"',
            'SELECT "TEST"."Name" AS "Name" FROM "Test" AS "TEST"',
        ),
        (
            "Lowercase CS table and column",
            '"test"',
            '"test"',
            {'"name"': "VARCHAR"},
            '"name"',
            'SELECT "TEST"."name" AS "name" FROM "test" AS "TEST"',
        ),
        (
            "CS table and column and query table but no CS in query column",
            '"test"',
            '"test"',
            {'"name"': "VARCHAR"},
            "name",
            OptimizeError(),
        ),
        (
            "CS table and column and query column but no CS in query table",
            '"test"',
            "test",
            {'"name"': "VARCHAR"},
            '"name"',
            OptimizeError(),
        ),
    ]

    def setUp(self) -> None:
        # Run the base-class setup first, then build the session with the
        # "sqlframe.dialect" option set to "snowflake".
        super().setUp()
        snowflake_builder = SparkSession.builder.config("sqlframe.dialect", "snowflake")
        self.spark = snowflake_builder.getOrCreate()

    def test_basic_case_sensitivity(self):
        # One sub-test per scenario so a failure is reported under its label
        # and the remaining scenarios still run.
        for label, registered_table, queried_table, columns, queried_column, outcome in self.tests:
            with self.subTest(label):
                sqlglot.schema.add_table(registered_table, columns, dialect=self.spark.dialect)
                frame = self.spark.table(queried_table).select(F.col(queried_column))

                if not isinstance(outcome, OptimizeError):
                    # Success scenario: the generated SQL must match the expectation.
                    self.compare_sql(frame, outcome)
                    continue

                # Failure scenario: generating the SQL must raise OptimizeError.
                with self.assertRaises(OptimizeError):
                    frame.sql()