1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
from __future__ import annotations
import typing as t
from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
class PRQL(Dialect):
    """Dialect that tokenizes and parses PRQL-style pipelines.

    A query is a ``from <table>`` clause followed by any number of the
    transforms listed in ``Parser.TRANSFORM_PARSERS`` (``derive``, ``select``,
    ``take``). The parser turns such a pipeline into a sqlglot select
    expression that starts out as ``SELECT * FROM <table>``.
    """

    class Tokenizer(tokens.Tokenizer):
        # Backticks delimit identifiers; both quote styles delimit strings.
        IDENTIFIERS = ["`"]
        QUOTES = ["'", '"']

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
            # `name = expr` is how an alias is written, so `=` becomes ALIAS.
            "=": TokenType.ALIAS,
            "'": TokenType.QUOTE,
            '"': TokenType.QUOTE,
            "`": TokenType.IDENTIFIER,
            "#": TokenType.COMMENT,
        }

        # No additions yet: this is a copy of the base tokenizer keywords.
        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
        }

    class Parser(parser.Parser):
        # Maps an upper-cased transform name to a callable taking
        # (parser, query) and returning the updated query.
        TRANSFORM_PARSERS = {
            "DERIVE": lambda self, query: self._parse_selection(query),
            "SELECT": lambda self, query: self._parse_selection(query, append=False),
            "TAKE": lambda self, query: self._parse_take(query),
        }

        def _parse_statement(self) -> t.Optional[exp.Expression]:
            """Parse a statement: an expression if one is found, else a query."""
            expression = self._parse_expression()
            expression = expression if expression else self._parse_query()
            return expression

        def _parse_query(self) -> t.Optional[exp.Query]:
            """Parse ``from <table>`` followed by zero or more transforms.

            Returns:
                The resulting query, or ``None`` when no ``from`` clause is
                present.
            """
            from_ = self._parse_from()

            if not from_:
                return None

            # Every pipeline starts as `SELECT * FROM <table>`; each transform
            # then receives the query built so far and returns the next one.
            query = exp.select("*").from_(from_, copy=False)

            while self._match_texts(self.TRANSFORM_PARSERS):
                query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)

            return query

        def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
            """Parse the projections of a ``derive``/``select`` transform.

            Accepts either a brace-delimited, comma-separated list of
            expressions or a single bare expression.

            Args:
                query: The query built so far.
                append: Forwarded to ``query.select``; ``derive`` uses the
                    default and ``select`` passes ``False``.

            Returns:
                The result of ``query.select`` with the parsed projections.
            """
            if self._match(TokenType.L_BRACE):
                selects = self._parse_csv(self._parse_expression)

                if not self._match(TokenType.R_BRACE, expression=query):
                    self.raise_error("Expecting }")
            else:
                expression = self._parse_expression()
                selects = [expression] if expression else []

            # Existing projections keyed by output name, with the alias
            # wrapper stripped so only the underlying expression remains.
            projections = {
                select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
                for select in query.selects
            }

            # Replace column references that name an existing projection with
            # a copy of that projection's expression.
            selects = [
                select.transform(
                    lambda s: (projections[s.name].copy() if s.name in projections else s)
                    if isinstance(s, exp.Column)
                    else s,
                    copy=False,
                )
                for select in selects
            ]

            return query.select(*selects, append=append, copy=False)

        def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
            """Parse the argument of a ``take`` transform as a row limit.

            Args:
                query: The query built so far.

            Returns:
                ``query.limit(<number>)`` when a number follows ``take``. When
                it does not, the problem is reported through ``raise_error``
                and, if that call returns instead of raising, ``query`` is
                returned unchanged.
            """
            num = self._parse_number()  # TODO: TAKE for ranges a..b

            if not num:
                # Returning None here would be assigned to `query` by the
                # transform loop in `_parse_query`, and the next transform
                # would then fail with an AttributeError on None.
                self.raise_error("Expecting a number after take")
                return query

            return query.limit(num)

        def _parse_expression(self) -> t.Optional[exp.Expression]:
            """Parse ``alias = expression`` or a plain expression.

            The alias form is recognised by looking one token ahead for the
            ALIAS token (``=`` in this dialect's tokenizer).
            """
            if self._next and self._next.token_type == TokenType.ALIAS:
                alias = self._parse_id_var(True)
                self._match(TokenType.ALIAS)
                return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
            return self._parse_conjunction()

        def _parse_table(
            self,
            schema: bool = False,
            joins: bool = False,
            alias_tokens: t.Optional[t.Collection[TokenType]] = None,
            parse_bracket: bool = False,
            is_db_reference: bool = False,
        ) -> t.Optional[exp.Expression]:
            """Parse a table reference using ``_parse_table_parts`` only.

            All keyword parameters are accepted but not used here.
            """
            return self._parse_table_parts()

        def _parse_from(
            self, joins: bool = False, skip_from_token: bool = False
        ) -> t.Optional[exp.From]:
            """Parse a ``from`` clause into an ``exp.From`` node.

            Args:
                joins: Forwarded to ``_parse_table``.
                skip_from_token: When true, do not require a FROM token.

            Returns:
                The ``exp.From`` node, or ``None`` when a FROM token is
                required but not matched.
            """
            if not skip_from_token and not self._match(TokenType.FROM):
                return None

            return self.expression(
                exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
            )
|