summaryrefslogtreecommitdiffstats
path: root/tests/parseutils/test_ctes.py
blob: 3e89ccafeef33b8c25cd0b52d4370ca5137c8c3d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
    token_start_pos,
    extract_ctes,
    extract_column_names as _extract_column_names,
)


def extract_column_names(sql):
    """Parse *sql* and return the column names the library extracts from it.

    Thin adapter so the tests can pass raw SQL text: the first statement
    produced by ``sqlparse.parse`` is handed to the real
    ``extract_column_names`` (imported here as ``_extract_column_names``).
    """
    first_statement = parse(sql)[0]
    names = _extract_column_names(first_statement)
    return names


def test_token_str_pos():
    """token_start_pos gives the character offset of the statement's last token.

    The trailing identifier ``xxx`` must start exactly where the preceding
    text ends, including when that text ends with a newline.
    """
    for prefix in ("SELECT * FROM ", "SELECT * FROM \n"):
        statement = parse(prefix + "xxx")[0]
        last_index = statement.token_index(statement.tokens[-1])
        assert token_start_pos(statement.tokens, last_index) == len(prefix)


def test_single_column_name_extraction():
    """A lone bare column yields a one-element tuple of its name."""
    assert extract_column_names("SELECT abc FROM xxx") == ("abc",)


def test_aliased_single_column_name_extraction():
    """When a column is aliased, the alias (not the source column) is reported."""
    expected = ("def",)
    assert extract_column_names("SELECT abc def FROM xxx") == expected


def test_aliased_expression_name_extraction():
    """An aliased literal expression is reported under its alias."""
    expected = ("abc",)
    assert extract_column_names("SELECT 99 abc FROM xxx") == expected


def test_multiple_column_name_extraction():
    """Comma-separated columns come back in select-list order."""
    query = "SELECT abc, def FROM xxx"
    assert extract_column_names(query) == ("abc", "def")


def test_missing_column_name_handled_gracefully():
    """Unnamed select items (bare literals) are skipped; named ones are kept."""
    cases = [
        ("SELECT abc, 99 FROM xxx", ("abc",)),
        ("SELECT abc, 99, def FROM xxx", ("abc", "def")),
    ]
    for query, expected in cases:
        assert extract_column_names(query) == expected


def test_aliased_multiple_column_name_extraction():
    """Each aliased column in a multi-column list is reported by its alias."""
    query = "SELECT abc def, ghi jkl FROM xxx"
    assert extract_column_names(query) == ("def", "jkl")


def test_table_qualified_column_name_extraction():
    """For ``table.column`` references only the column part is reported."""
    query = "SELECT abc.def, ghi.jkl FROM xxx"
    assert extract_column_names(query) == ("def", "jkl")


@pytest.mark.parametrize(
    "sql",
    (
        "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
        "DELETE FROM foo WHERE x > y RETURNING x, y",
        "UPDATE foo SET x = 9 RETURNING x, y",
    ),
)
def test_extract_column_names_from_returning_clause(sql):
    """DML statements expose the columns listed in their RETURNING clause."""
    returned_columns = extract_column_names(sql)
    assert returned_columns == ("x", "y")


def test_simple_cte_extraction():
    """A single CTE is extracted with its columns and the span of its body.

    The expected span runs from the opening parenthesis of the CTE body up to
    and including its closing parenthesis; both offsets are located directly
    in ``sql`` so they cannot drift from the query text.
    """
    sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
    body_start = sql.index("(")
    body_stop = sql.index(")") + 1

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("abc",), body_start, body_stop),)
    assert remainder.strip() == "SELECT * FROM a"


def test_cte_extraction_around_comments():
    """A leading SQL comment does not disturb CTE detection or its offsets.

    Offsets are located in ``sql`` itself: the CTE body starts at the first
    ``(`` and ends just past the first ``)``; the comment contains neither.
    """
    sql = """--blah blah blah
            WITH a AS (SELECT abc def FROM x)
            SELECT * FROM a"""
    body_start = sql.index("(")
    body_stop = sql.index(")") + 1

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("def",), body_start, body_stop),)
    assert remainder.strip() == "SELECT * FROM a"


def test_multiple_cte_extraction():
    """Several comma-separated CTEs are each extracted with their own span.

    Each expected span covers one CTE body, from its opening parenthesis up to
    and including its closing parenthesis. The offsets are located directly in
    ``sql`` rather than by retyping the query prefixes, so they stay in sync
    with the query text. The query that follows the CTEs must be returned as
    the remainder, as the single-CTE tests already check.
    """
    sql = """WITH
            x AS (SELECT abc, def FROM x),
            y AS (SELECT ghi, jkl FROM y)
            SELECT * FROM a, b"""

    # First CTE body: first "(" through the first ")".
    start1 = sql.index("(")
    stop1 = sql.index(")") + 1
    # Second CTE body: next "(" after the first body, through its ")".
    start2 = sql.index("(", stop1)
    stop2 = sql.index(")", start2) + 1

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (
        ("x", ("abc", "def"), start1, stop1),
        ("y", ("ghi", "jkl"), start2, stop2),
    )
    # Previously unchecked: the trailing query must survive CTE stripping.
    assert remainder.strip() == "SELECT * FROM a, b"