1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
from sqlparse import parse
from sqlparse.tokens import Keyword, CTE, DML
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from collections import namedtuple
from .meta import TableMetadata, ColumnMetadata
# TableExpression describes one CTE found in a query; it is used internally.
#   name:    alias assigned to the CTE in the query
#   columns: tuple of the CTE's output column names
#   start:   index into the original string of the opening parenthesis of
#            the CTE body
#   stop:    index into the original string just past the closing
#            parenthesis of the CTE body (exclusive end)
TableExpression = namedtuple(
    "TableExpression", ["name", "columns", "start", "stop"]
)
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects.

    Returns a tuple ``(full_text, text_before_cursor, meta)`` where ``meta``
    is a tuple of TableMetadata objects, one per CTE made available to the
    text being returned:

    * Blank input, or a query without CTEs, is returned unchanged with an
      empty ``meta``.
    * If the cursor lies strictly inside a CTE's parenthesized body, that
      body becomes the new ``full_text`` (``text_before_cursor`` is the part
      of it before the cursor) and ``meta`` holds only the CTEs defined
      before it.
    * Otherwise the text after the last CTE (the main body of the query) is
      returned and ``meta`` describes every CTE.
    """
    if not full_text or not full_text.strip():
        return full_text, text_before_cursor, ()

    ctes, _remainder = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start : current_position]
            full_text = full_text[cte.start : cte.stop]
            # Return a tuple, matching every other return path.
            return full_text, text_before_cursor, tuple(meta)

        # Append this cte to the list of available table metadata.
        # Materialize the columns: a generator stored on the metadata could
        # only be iterated once, so later consumers would see no columns.
        cols = tuple(ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    last_stop = ctes[-1].stop
    full_text = full_text[last_stop:]
    text_before_cursor = text_before_cursor[last_stop:current_position]

    return full_text, text_before_cursor, tuple(meta)
def extract_ctes(sql):
    """Extract common table expressions (CTEs) from a query.

    Returns a tuple ``(ctes, remainder_sql)``:

    * ``ctes`` is a list of TableExpression namedtuples, one per CTE that
      could be recognized.
    * ``remainder_sql`` is the text of the original query that follows the
      CTE definitions.

    A query whose first meaningful token is not WITH yields ``([], sql)``.
    A bare WITH with nothing after it yields ``([], "")``.
    """
    statement = parse(sql)[0]

    # CTEs can only be defined when the first meaningful token is WITH.
    with_idx, with_tok = statement.token_next(-1, skip_ws=True, skip_cm=True)
    if not with_tok or with_tok.ttype != CTE:
        return [], sql

    # The following meaningful token holds the CTE definition(s).
    first_idx, first = statement.token_next(with_idx)
    if not first:
        return ([], "")
    body_start = token_start_pos(statement.tokens, first_idx)

    if isinstance(first, IdentifierList):
        # Several comma-separated CTEs: offset each one within the list.
        candidates = (
            get_cte_from_token(
                member,
                body_start
                + token_start_pos(first.tokens, first.token_index(member)),
            )
            for member in first.get_identifiers()
        )
        ctes = [cte for cte in candidates if cte]
    elif isinstance(first, Identifier):
        # Exactly one CTE.
        only = get_cte_from_token(first, body_start)
        ctes = [only] if only else []
    else:
        ctes = []

    # Everything after the CTE definitions forms the remainder query.
    trailing = statement.tokens[statement.token_index(first) + 1 :]
    remainder = "".join(map(str, trailing))

    return ctes, remainder
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from the token of a single CTE definition.

    ``pos0`` is the offset of ``tok`` within the original query string. The
    returned ``start``/``stop`` offsets span the CTE's parenthesized body,
    parentheses included (``stop`` is exclusive).

    Returns None when the token has no real name or contains no Parenthesis.
    """
    name = tok.get_real_name()
    if not name:
        return None

    # Locate the parenthesized body of the CTE.
    paren_idx, body = tok.token_next_by(Parenthesis)
    if not body:
        return None

    begin = pos0 + token_start_pos(tok.tokens, paren_idx)
    end = begin + len(str(body))  # length includes both parentheses
    return TableExpression(name, extract_column_names(body), begin, end)
def extract_column_names(parsed):
    """Return the output column names of a CTE body as a tuple of strings.

    ``parsed`` is a grouped sqlparse token list, for a CTE the Parenthesis
    holding its body. For SELECT the names are taken from the token that
    follows the SELECT keyword. For INSERT/UPDATE/DELETE they are taken from
    the token that follows RETURNING. Any other (or missing) DML yields an
    empty tuple. Only Identifier tokens contribute names.
    """
    # Find the first DML token to check if it's a SELECT or INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ("insert", "update", "delete"):
        # Jump ahead to the RETURNING clause where the list of column names is.
        # token_next_by's first positional parameter is the instance-class
        # filter ``i``, so the start index and match pattern must be passed by
        # keyword. Passing the index positionally makes sqlparse evaluate
        # isinstance(token, <int>), which raises TypeError whenever the DML
        # token is not at index 0 (e.g. inside a CTE's parentheses).
        idx, tok = parsed.token_next_by(m=(Keyword, "returning"), idx=idx)
    elif tok_val != "select":
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names.
    # NOTE(review): with no RETURNING clause idx is None here; sqlparse's
    # token_next presumably returns (None, None), giving an empty tuple.
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
def token_start_pos(tokens, idx):
    """Return the character offset at which ``tokens[idx]`` begins.

    The offset is the combined length of the string forms of all tokens
    that precede position ``idx`` (i.e. ``tokens[:idx]``).
    """
    preceding = tokens[:idx]
    return sum(map(len, map(str, preceding)))
def _identifiers(tok):
    """Yield the Identifier tokens contained in ``tok``.

    An IdentifierList yields its Identifier members in order, a lone
    Identifier yields itself, and anything else (including None) yields
    nothing.
    """
    if isinstance(tok, IdentifierList):
        # NB: IdentifierList.get_identifiers() can return non-identifiers!
        yield from (
            member
            for member in tok.get_identifiers()
            if isinstance(member, Identifier)
        )
    elif isinstance(tok, Identifier):
        yield tok
|