1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
from __future__ import annotations
import typing as t
from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
return exp.select("*").from_(table, copy=False) if table else None
class PRQL(Dialect):
    """Dialect whose parser turns FROM-led transform pipelines into select expressions.

    A statement is either a plain expression or a query that starts with
    ``FROM <table>`` followed by any number of the transforms listed in
    ``Parser.TRANSFORM_PARSERS``. The query starts out as ``SELECT * FROM <table>``
    and each transform is applied to the result of the previous one.
    """

    # The parser below maps the DPIPE token to a logical OR (see Parser.CONJUNCTION),
    # so it is not treated as string concatenation.
    DPIPE_IS_STRING_CONCAT = False

    class Tokenizer(tokens.Tokenizer):
        """Tokenizer configuration: backtick identifiers, single/double quoted strings."""

        IDENTIFIERS = ["`"]
        QUOTES = ["'", '"']

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
            # `=` is tokenized as ALIAS; Parser._parse_expression relies on this
            # to recognize `name = expression`.
            "=": TokenType.ALIAS,
            "'": TokenType.QUOTE,
            '"': TokenType.QUOTE,
            "`": TokenType.IDENTIFIER,
            "#": TokenType.COMMENT,
        }

        # No keywords are added or removed relative to the base tokenizer.
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
        }

    class Parser(parser.Parser):
        """Parser for FROM-led pipelines of transforms."""

        # `&&` (DAMP) and `||` (DPIPE) act as logical AND / OR.
        CONJUNCTION = {
            **parser.Parser.CONJUNCTION,
            TokenType.DAMP: exp.And,
            TokenType.DPIPE: exp.Or,
        }

        # Maps an upper-cased transform keyword to a callable taking
        # (parser, current query) and returning the transformed query.
        TRANSFORM_PARSERS = {
            # DERIVE adds to the current projections; SELECT passes append=False.
            "DERIVE": lambda self, query: self._parse_selection(query),
            "SELECT": lambda self, query: self._parse_selection(query, append=False),
            "TAKE": lambda self, query: self._parse_take(query),
            "FILTER": lambda self, query: query.where(self._parse_conjunction()),
            # The set operations combine the query with `SELECT * FROM <table>`,
            # always with distinct=False.
            "APPEND": lambda self, query: query.union(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "REMOVE": lambda self, query: query.except_(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "INTERSECT": lambda self, query: query.intersect(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "SORT": lambda self, query: self._parse_order_by(query),
        }

        def _parse_statement(self) -> t.Optional[exp.Expression]:
            """Parse a plain expression, falling back to a FROM-led query."""
            return self._parse_expression() or self._parse_query()

        def _parse_query(self) -> t.Optional[exp.Query]:
            """Parse ``FROM <table>`` followed by zero or more transforms.

            Returns:
                The resulting query, or ``None`` when there is no FROM clause or
                when a transform produced no query.
            """
            from_ = self._parse_from()

            if not from_:
                return None

            query: t.Optional[exp.Query] = exp.select("*").from_(from_, copy=False)

            # A transform may return None (``_parse_take`` does when no number
            # follows). Stop there instead of invoking the next transform on None,
            # which would fail with an AttributeError. The remaining tokens are
            # left unconsumed.
            # NOTE(review): presumably the base parser reports leftover tokens as
            # a parse error -- confirm.
            while query is not None and self._match_texts(self.TRANSFORM_PARSERS):
                query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)

            return query

        def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
            """Parse one expression, or a brace-delimited list, into projections.

            Column references whose name matches an existing projection of
            ``query`` are replaced by a copy of that projection's expression.

            Args:
                query: The query whose projections are extended or replaced.
                append: Forwarded to ``query.select``.

            Returns:
                The result of ``query.select`` with the parsed projections.
            """
            if self._match(TokenType.L_BRACE):
                selects = self._parse_csv(self._parse_expression)

                if not self._match(TokenType.R_BRACE, expression=query):
                    self.raise_error("Expecting }")
            else:
                expression = self._parse_expression()
                selects = [expression] if expression else []

            # Existing projections by output name, with any alias wrapper stripped.
            projections = {
                select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
                for select in query.selects
            }

            def _inline_projection(node: exp.Expression) -> exp.Expression:
                # Swap a column that names an existing projection for a copy of
                # that projection's expression; leave every other node alone.
                if isinstance(node, exp.Column) and node.name in projections:
                    return projections[node.name].copy()
                return node

            selects = [select.transform(_inline_projection, copy=False) for select in selects]

            return query.select(*selects, append=append, copy=False)

        def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
            """Parse the number after TAKE and apply it as the query's limit.

            Returns:
                The limited query, or ``None`` when no number could be parsed.
            """
            num = self._parse_number()  # TODO: TAKE for ranges a..b
            return query.limit(num) if num else None

        def _parse_ordered(
            self, parse_method: t.Optional[t.Callable] = None
        ) -> t.Optional[exp.Ordered]:
            """Parse a sort term with an optional leading ``+`` or ``-``.

            A leading ``+`` is consumed and has no further effect here. A leading
            ``-`` marks the parsed term as descending with ``nulls_first`` unset.
            """
            self._match(TokenType.PLUS)
            desc = self._match(TokenType.DASH)
            term = super()._parse_ordered(parse_method=parse_method)

            if term and desc:
                term.set("desc", True)
                term.set("nulls_first", False)

            return term

        def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
            """Parse the sort terms after SORT, optionally wrapped in braces.

            Raises a parse error via ``raise_error`` when an opening brace has no
            matching closing brace.
            """
            l_brace = self._match(TokenType.L_BRACE)
            expressions = self._parse_csv(self._parse_ordered)

            if l_brace and not self._match(TokenType.R_BRACE):
                self.raise_error("Expecting }")

            return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)

        def _parse_expression(self) -> t.Optional[exp.Expression]:
            """Parse ``alias = expression`` into an Alias, or a bare conjunction."""
            # Look one token ahead: when the next token is ALIAS, the current
            # token is the alias name.
            if self._next and self._next.token_type == TokenType.ALIAS:
                alias = self._parse_id_var(True)
                self._match(TokenType.ALIAS)
                return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)

            return self._parse_conjunction()

        def _parse_table(
            self,
            schema: bool = False,
            joins: bool = False,
            alias_tokens: t.Optional[t.Collection[TokenType]] = None,
            parse_bracket: bool = False,
            is_db_reference: bool = False,
        ) -> t.Optional[exp.Expression]:
            """Parse a table reference via ``_parse_table_parts``.

            The parameters are accepted for signature compatibility but are not
            used by this implementation.
            """
            return self._parse_table_parts()

        def _parse_from(
            self, joins: bool = False, skip_from_token: bool = False
        ) -> t.Optional[exp.From]:
            """Parse a FROM clause.

            Args:
                joins: Forwarded to ``_parse_table``.
                skip_from_token: When true, do not require a FROM token first.

            Returns:
                The From expression, or ``None`` when a FROM token was required
                but is not the next token.
            """
            if not skip_from_token and not self._match(TokenType.FROM):
                return None

            return self.expression(
                exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
            )
|