summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/prql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 16:13:01 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-05-04 16:13:01 +0000
commita7044b672667f2a0b48bd0b326b5a55b0815ef79 (patch)
tree4fb5238d47fb4709d47f766a74b8bbaa9c6f17d8 /sqlglot/dialects/prql.py
parentReleasing debian version 23.12.1-1. (diff)
downloadsqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.tar.xz
sqlglot-a7044b672667f2a0b48bd0b326b5a55b0815ef79.zip
Merging upstream version 23.13.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/prql.py')
-rw-r--r--sqlglot/dialects/prql.py40
1 files changed, 37 insertions, 3 deletions
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
index 028c309..ad0c647 100644
--- a/sqlglot/dialects/prql.py
+++ b/sqlglot/dialects/prql.py
@@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -53,6 +54,15 @@ class PRQL(Dialect):
_select_all(self._parse_table()), distinct=False, copy=False
),
"SORT": lambda self, query: self._parse_order_by(query),
+ "AGGREGATE": lambda self, query: self._parse_selection(
+ query, parse_method=self._parse_aggregate, append=False
+ ),
+ }
+
+ FUNCTIONS = {
+ **parser.Parser.FUNCTIONS,
+ "AVERAGE": exp.Avg.from_arg_list,
+ "SUM": lambda args: exp.func("COALESCE", exp.Sum(this=seq_get(args, 0)), 0),
}
def _parse_equality(self) -> t.Optional[exp.Expression]:
@@ -87,14 +97,20 @@ class PRQL(Dialect):
return query
- def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
+ def _parse_selection(
+ self,
+ query: exp.Query,
+ parse_method: t.Optional[t.Callable] = None,
+ append: bool = True,
+ ) -> exp.Query:
+ parse_method = parse_method if parse_method else self._parse_expression
if self._match(TokenType.L_BRACE):
- selects = self._parse_csv(self._parse_expression)
+ selects = self._parse_csv(parse_method)
if not self._match(TokenType.R_BRACE, expression=query):
self.raise_error("Expecting }")
else:
- expression = self._parse_expression()
+ expression = parse_method()
selects = [expression] if expression else []
projections = {
@@ -136,6 +152,24 @@ class PRQL(Dialect):
self.raise_error("Expecting }")
return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)
+ def _parse_aggregate(self) -> t.Optional[exp.Expression]:
+ alias = None
+ if self._next and self._next.token_type == TokenType.ALIAS:
+ alias = self._parse_id_var(any_token=True)
+ self._match(TokenType.ALIAS)
+
+ name = self._curr and self._curr.text.upper()
+ func_builder = self.FUNCTIONS.get(name)
+ if func_builder:
+ self._advance()
+ args = self._parse_column()
+ func = func_builder([args])
+ else:
+ self.raise_error(f"Unsupported aggregation function {name}")
+ if alias:
+ return self.expression(exp.Alias, this=func, alias=alias)
+ return func
+
def _parse_expression(self) -> t.Optional[exp.Expression]:
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)