summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/test_session_case_sensitivity.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dataframe/unit/test_session_case_sensitivity.py')
-rw-r--r--tests/dataframe/unit/test_session_case_sensitivity.py81
1 files changed, 81 insertions, 0 deletions
diff --git a/tests/dataframe/unit/test_session_case_sensitivity.py b/tests/dataframe/unit/test_session_case_sensitivity.py
new file mode 100644
index 0000000..7e35289
--- /dev/null
+++ b/tests/dataframe/unit/test_session_case_sensitivity.py
@@ -0,0 +1,81 @@
+import sqlglot
+from sqlglot.dataframe.sql import functions as F
+from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.errors import OptimizeError
+from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase
+
+
+class TestSessionCaseSensitivity(DataFrameTestBase):
+ def setUp(self) -> None:
+ super().setUp()
+ self.spark = SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate()
+
+ tests = [
+ (
+ "All lower no intention of CS",
+ "test",
+ "test",
+ {"name": "VARCHAR"},
+ "name",
+ '''SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"''',
+ ),
+ (
+ "Table has CS while column does not",
+ '"Test"',
+ '"Test"',
+ {"name": "VARCHAR"},
+ "name",
+ '''SELECT "TEST"."NAME" AS "NAME" FROM "Test" AS "TEST"''',
+ ),
+ (
+ "Column has CS while table does not",
+ "test",
+ "test",
+ {'"Name"': "VARCHAR"},
+ '"Name"',
+ '''SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"''',
+ ),
+ (
+ "Both Table and column have CS",
+ '"Test"',
+ '"Test"',
+ {'"Name"': "VARCHAR"},
+ '"Name"',
+ '''SELECT "TEST"."Name" AS "Name" FROM "Test" AS "TEST"''',
+ ),
+ (
+ "Lowercase CS table and column",
+ '"test"',
+ '"test"',
+ {'"name"': "VARCHAR"},
+ '"name"',
+ '''SELECT "TEST"."name" AS "name" FROM "test" AS "TEST"''',
+ ),
+ (
+ "CS table and column and query table but no CS in query column",
+ '"test"',
+ '"test"',
+ {'"name"': "VARCHAR"},
+ "name",
+ OptimizeError(),
+ ),
+ (
+ "CS table and column and query column but no CS in query table",
+ '"test"',
+ "test",
+ {'"name"': "VARCHAR"},
+ '"name"',
+ OptimizeError(),
+ ),
+ ]
+
+ def test_basic_case_sensitivity(self):
+ for test_name, table_name, spark_table, schema, spark_column, expected in self.tests:
+ with self.subTest(test_name):
+ sqlglot.schema.add_table(table_name, schema, dialect=self.spark.dialect)
+ df = self.spark.table(spark_table).select(F.col(spark_column))
+ if isinstance(expected, OptimizeError):
+ with self.assertRaises(OptimizeError):
+ df.sql()
+ else:
+ self.compare_sql(df, expected)