"""Tests for the CTE helpers in ``pgcli.packages.parseutils.ctes``."""

import pytest
from sqlparse import parse

from pgcli.packages.parseutils.ctes import (
    token_start_pos,
    extract_ctes,
    extract_column_names as _extract_column_names,
)


def extract_column_names(sql):
    """Parse *sql* and return the column names of its first statement.

    Thin wrapper so the tests can pass raw SQL text instead of a parsed
    sqlparse statement.
    """
    p = parse(sql)[0]
    return _extract_column_names(p)


def test_token_str_pos():
    """``token_start_pos`` equals the length of the text before the token."""
    sql = "SELECT * FROM xxx"
    p = parse(sql)[0]
    idx = p.token_index(p.tokens[-1])
    assert token_start_pos(p.tokens, idx) == len("SELECT * FROM ")

    # Same check with a newline in the text preceding the last token.
    sql = "SELECT * FROM \nxxx"
    p = parse(sql)[0]
    idx = p.token_index(p.tokens[-1])
    assert token_start_pos(p.tokens, idx) == len("SELECT * FROM \n")


def test_single_column_name_extraction():
    """A bare column yields its own name."""
    sql = "SELECT abc FROM xxx"
    assert extract_column_names(sql) == ("abc",)


def test_aliased_single_column_name_extraction():
    """An aliased column yields the alias, not the underlying name."""
    sql = "SELECT abc def FROM xxx"
    assert extract_column_names(sql) == ("def",)


def test_aliased_expression_name_extraction():
    """An aliased literal expression yields the alias."""
    sql = "SELECT 99 abc FROM xxx"
    assert extract_column_names(sql) == ("abc",)


def test_multiple_column_name_extraction():
    """Comma-separated columns are returned in order."""
    sql = "SELECT abc, def FROM xxx"
    assert extract_column_names(sql) == ("abc", "def")


def test_missing_column_name_handled_gracefully():
    """Unnamed expressions are skipped rather than raising."""
    sql = "SELECT abc, 99 FROM xxx"
    assert extract_column_names(sql) == ("abc",)

    sql = "SELECT abc, 99, def FROM xxx"
    assert extract_column_names(sql) == ("abc", "def")


def test_aliased_multiple_column_name_extraction():
    """Each aliased column in a list yields its alias."""
    sql = "SELECT abc def, ghi jkl FROM xxx"
    assert extract_column_names(sql) == ("def", "jkl")


def test_table_qualified_column_name_extraction():
    """Table qualifiers are dropped from the reported column names."""
    sql = "SELECT abc.def, ghi.jkl FROM xxx"
    assert extract_column_names(sql) == ("def", "jkl")


@pytest.mark.parametrize(
    "sql",
    [
        "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
        "DELETE FROM foo WHERE x > y RETURNING x, y",
        "UPDATE foo SET x = 9 RETURNING x, y",
    ],
)
def test_extract_column_names_from_returning_clause(sql):
    """Column names come from the RETURNING clause of DML statements."""
    assert extract_column_names(sql) == ("x", "y")


def test_simple_cte_extraction():
    """A single CTE is reported with its name, columns and body span."""
    sql = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
    # The expected span runs from the opening paren of the CTE body to
    # just past its closing paren.
    start_pos = len("WITH a AS ")
    stop_pos = len("WITH a AS (SELECT abc FROM xxx)")
    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("abc",), start_pos, stop_pos),)
    assert remainder.strip() == "SELECT * FROM a"


def test_cte_extraction_around_comments():
    """A leading ``--`` comment shifts the offsets but not the result."""
    # The line comment must be terminated by a newline; otherwise the
    # whole statement would be part of the comment.
    sql = (
        "--blah blah blah\n"
        "WITH a AS (SELECT abc def FROM x) "
        "SELECT * FROM a"
    )
    start_pos = len("--blah blah blah\nWITH a AS ")
    stop_pos = len("--blah blah blah\nWITH a AS (SELECT abc def FROM x)")
    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (("a", ("def",), start_pos, stop_pos),)
    assert remainder.strip() == "SELECT * FROM a"


def test_multiple_cte_extraction():
    """Several comma-separated CTEs are each reported with their own span."""
    sql = (
        "WITH x AS (SELECT abc, def FROM x), "
        "y AS (SELECT ghi, jkl FROM y) "
        "SELECT * FROM a, b"
    )
    # Expected spans are expressed as lengths of prefixes of ``sql``.
    start1 = len("WITH x AS ")
    stop1 = len("WITH x AS (SELECT abc, def FROM x)")
    start2 = len("WITH x AS (SELECT abc, def FROM x), y AS ")
    stop2 = len(
        "WITH x AS (SELECT abc, def FROM x), y AS (SELECT ghi, jkl FROM y)"
    )
    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (
        ("x", ("abc", "def"), start1, stop1),
        ("y", ("ghi", "jkl"), start2, stop2),
    )