summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/prql.py
blob: 3ee91a822c06f195e71e5cb12c566992f13e874a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from __future__ import annotations

import typing as t

from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType


def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
    return exp.select("*").from_(table, copy=False) if table else None


class PRQL(Dialect):
    """Dialect that tokenizes and parses PRQL pipelines into sqlglot expressions.

    A query starts with a ``from`` clause, which becomes ``SELECT * FROM ...``, and is
    then rewritten by each transform keyword listed in ``Parser.TRANSFORM_PARSERS``.
    """

    # `||` is a logical OR in this dialect (see Parser.CONJUNCTION), not string concat.
    DPIPE_IS_STRING_CONCAT = False

    class Tokenizer(tokens.Tokenizer):
        """Tokenizer settings: backtick identifiers, single/double quoted strings."""

        IDENTIFIERS = ["`"]
        QUOTES = ["'", '"']

        SINGLE_TOKENS = {
            **tokens.Tokenizer.SINGLE_TOKENS,
            # `=` introduces an alias (`name = expr`); see Parser._parse_expression.
            "=": TokenType.ALIAS,
            "'": TokenType.QUOTE,
            '"': TokenType.QUOTE,
            "`": TokenType.IDENTIFIER,
            "#": TokenType.COMMENT,
        }

        KEYWORDS = {
            **tokens.Tokenizer.KEYWORDS,
        }

    class Parser(parser.Parser):
        """Parser that turns a PRQL pipeline into a query expression."""

        # `&&` and `||` are accepted as AND / OR in addition to the inherited entries.
        CONJUNCTION = {
            **parser.Parser.CONJUNCTION,
            TokenType.DAMP: exp.And,
            TokenType.DPIPE: exp.Or,
        }

        # Transform keyword (upper-cased) -> callable taking (parser, current query)
        # and returning the query after the transform has been applied.
        TRANSFORM_PARSERS = {
            "DERIVE": lambda self, query: self._parse_selection(query),
            "SELECT": lambda self, query: self._parse_selection(query, append=False),
            "TAKE": lambda self, query: self._parse_take(query),
            "FILTER": lambda self, query: query.where(self._parse_conjunction()),
            "APPEND": lambda self, query: query.union(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "REMOVE": lambda self, query: query.except_(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "INTERSECT": lambda self, query: query.intersect(
                _select_all(self._parse_table()), distinct=False, copy=False
            ),
            "SORT": lambda self, query: self._parse_order_by(query),
        }

        def _parse_statement(self) -> t.Optional[exp.Expression]:
            """Parse a statement: an expression if one is found, else a pipeline query."""
            return self._parse_expression() or self._parse_query()

        def _parse_query(self) -> t.Optional[exp.Query]:
            """Parse ``from <table>`` followed by any number of transforms.

            Returns None when no ``from`` clause is present, or when a transform
            yields None (``_parse_take`` does so when no number follows ``take``).
            """
            from_ = self._parse_from()

            if not from_:
                return None

            query: t.Optional[exp.Query] = exp.select("*").from_(from_, copy=False)

            # Stop as soon as a transform returns None: applying a further transform
            # to None would fail with an AttributeError instead of a parse error.
            while query is not None and self._match_texts(self.TRANSFORM_PARSERS):
                query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query)

            return query

        def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
            """Parse the projections of a ``derive`` / ``select`` transform.

            Accepts either a brace-delimited, comma-separated list of expressions or
            a single expression. Columns whose name matches an existing projection
            of ``query`` are replaced by a copy of that projection's expression.

            Args:
                query: The query built so far; modified in place.
                append: Passed through to ``query.select``; ``derive`` uses the
                    default (True) and ``select`` passes False.

            Returns:
                The result of ``query.select`` with the parsed projections.
            """
            if self._match(TokenType.L_BRACE):
                selects = self._parse_csv(self._parse_expression)

                if not self._match(TokenType.R_BRACE, expression=query):
                    self.raise_error("Expecting }")
            else:
                expression = self._parse_expression()
                selects = [expression] if expression else []

            # Existing projections by output name, with aliases unwrapped so that
            # the underlying expression is what gets substituted below.
            projections = {
                select.alias_or_name: select.this if isinstance(select, exp.Alias) else select
                for select in query.selects
            }

            selects = [
                select.transform(
                    lambda s: (projections[s.name].copy() if s.name in projections else s)
                    if isinstance(s, exp.Column)
                    else s,
                    copy=False,
                )
                for select in selects
            ]

            return query.select(*selects, append=append, copy=False)

        def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]:
            """Parse ``take <number>`` as a limit; return None if no number follows."""
            num = self._parse_number()  # TODO: TAKE for ranges a..b
            return query.limit(num) if num else None

        def _parse_ordered(
            self, parse_method: t.Optional[t.Callable] = None
        ) -> t.Optional[exp.Ordered]:
            """Parse one sort term with an optional leading ``+`` or ``-`` marker.

            A ``-`` marks the term as descending with ``nulls_first`` set to False.
            A ``+`` is accepted and consumed but does not alter the parsed term.
            """
            # The `+` marker must be consumed before looking for `-`; only `-`
            # changes the resulting term.
            self._match(TokenType.PLUS)
            desc = self._match(TokenType.DASH)
            term = super()._parse_ordered(parse_method=parse_method)
            if term and desc:
                term.set("desc", True)
                term.set("nulls_first", False)
            return term

        def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
            """Parse the terms of a ``sort`` transform and apply them as an ORDER BY.

            The terms are comma-separated and may optionally be wrapped in braces;
            an opening brace without a closing one is reported via ``raise_error``.
            """
            l_brace = self._match(TokenType.L_BRACE)
            expressions = self._parse_csv(self._parse_ordered)
            if l_brace and not self._match(TokenType.R_BRACE):
                self.raise_error("Expecting }")
            return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)

        def _parse_expression(self) -> t.Optional[exp.Expression]:
            """Parse ``alias = <conjunction>`` into an Alias, or a bare conjunction.

            The aliased form is chosen when ``self._next`` is an ALIAS (``=``) token.
            """
            if self._next and self._next.token_type == TokenType.ALIAS:
                alias = self._parse_id_var(True)
                self._match(TokenType.ALIAS)
                return self.expression(exp.Alias, this=self._parse_conjunction(), alias=alias)
            return self._parse_conjunction()

        def _parse_table(
            self,
            schema: bool = False,
            joins: bool = False,
            alias_tokens: t.Optional[t.Collection[TokenType]] = None,
            parse_bracket: bool = False,
            is_db_reference: bool = False,
        ) -> t.Optional[exp.Expression]:
            """Parse a table reference via ``_parse_table_parts``.

            All parameters are accepted but ignored by this override.
            """
            return self._parse_table_parts()

        def _parse_from(
            self, joins: bool = False, skip_from_token: bool = False
        ) -> t.Optional[exp.From]:
            """Parse a ``from <table>`` clause into a From expression.

            Returns None when the FROM token is required (``skip_from_token`` is
            False) but is not the next token.
            """
            if not skip_from_token and not self._match(TokenType.FROM):
                return None

            return self.expression(
                exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins)
            )