summaryrefslogtreecommitdiffstats
path: root/tests/dataframe/unit/test_session_case_sensitivity.py
blob: 462edb69a24527dc9a48c94f8f7999df46f80550 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import sqlglot
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession
from sqlglot.errors import OptimizeError
from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase


class TestSessionCaseSensitivity(DataFrameTestBase):
    """Check how table/column identifier case is resolved under the Snowflake dialect.

    With the Snowflake dialect, unquoted identifiers are normalized to upper case while
    quoted identifiers keep their exact spelling. The expected SQL below reflects that:
    an unquoted ``test`` / ``name`` is emitted as ``"TEST"`` / ``"NAME"``, whereas
    ``"Test"`` / ``"Name"`` are preserved verbatim.
    """

    def setUp(self) -> None:
        super().setUp()
        # The session dialect drives how identifiers are normalized in every case below.
        self.spark = SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate()

    # Each case is a tuple of:
    #   (subtest name,
    #    table name registered in the schema,
    #    table name passed to ``spark.table``,
    #    column schema registered for that table,
    #    column name passed to ``F.col``,
    #    expected result)
    # ``expected`` is either the exact SQL ``df.sql()`` must generate, or an exception
    # instance whose *type* ``df.sql()`` must raise.
    tests = [
        (
            "All lower no intention of CS",
            "test",
            "test",
            {"name": "VARCHAR"},
            "name",
            '''SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"''',
        ),
        (
            "Table has CS while column does not",
            '"Test"',
            '"Test"',
            {"name": "VARCHAR"},
            "name",
            '''SELECT "Test"."NAME" AS "NAME" FROM "Test" AS "Test"''',
        ),
        (
            "Column has CS while table does not",
            "test",
            "test",
            {'"Name"': "VARCHAR"},
            '"Name"',
            '''SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"''',
        ),
        (
            "Both Table and column have CS",
            '"Test"',
            '"Test"',
            {'"Name"': "VARCHAR"},
            '"Name"',
            '''SELECT "Test"."Name" AS "Name" FROM "Test" AS "Test"''',
        ),
        (
            "Lowercase CS table and column",
            '"test"',
            '"test"',
            {'"name"': "VARCHAR"},
            '"name"',
            '''SELECT "test"."name" AS "name" FROM "test" AS "test"''',
        ),
        (
            "CS table and column and query table but no CS in query column",
            '"test"',
            '"test"',
            {'"name"': "VARCHAR"},
            "name",
            OptimizeError(),
        ),
        (
            "CS table and column and query column but no CS in query table",
            '"test"',
            "test",
            {'"name"': "VARCHAR"},
            '"name"',
            OptimizeError(),
        ),
    ]

    def test_basic_case_sensitivity(self):
        """Register each case's table, select its column, and check the generated SQL."""
        for test_name, table_name, spark_table, schema, spark_column, expected in self.tests:
            with self.subTest(test_name):
                # Start every subtest from an empty global schema. Several cases register
                # tables that normalize to the same name (e.g. ``test`` -> ``TEST``). Without
                # a reset, a case could pass or fail only because of what an earlier case
                # left in ``sqlglot.schema``.
                sqlglot.schema = sqlglot.MappingSchema()
                sqlglot.schema.add_table(table_name, schema, dialect=self.spark.dialect)
                df = self.spark.table(spark_table).select(F.col(spark_column))
                if isinstance(expected, Exception):
                    # Only SQL generation is expected to fail; building the DataFrame
                    # above happens outside the assertion and must succeed.
                    with self.assertRaises(type(expected)):
                        df.sql()
                else:
                    self.compare_sql(df, expected)

    def test_alias(self):
        """Aliases follow the same rules: unquoted is upper-cased, quoted is kept verbatim."""
        col = F.col('"Name"')
        self.assertEqual(col.sql(dialect=self.spark.dialect), '"Name"')
        self.assertEqual(col.alias("nAME").sql(dialect=self.spark.dialect), '"Name" AS NAME')
        self.assertEqual(col.alias('"nAME"').sql(dialect=self.spark.dialect), '"Name" AS "nAME"')