summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/prql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-15 05:02:18 +0000
commit41f1f5740d2140bfd3b2a282ca1087a4b576679a (patch)
tree0b1eb5ba5c759d08b05d56e50675784b6170f955 /sqlglot/dialects/prql.py
parentReleasing debian version 23.7.0-1. (diff)
downloadsqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.tar.xz
sqlglot-41f1f5740d2140bfd3b2a282ca1087a4b576679a.zip
Merging upstream version 23.10.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/prql.py')
-rw-r--r--sqlglot/dialects/prql.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py
index 3005753..3ee91a8 100644
--- a/sqlglot/dialects/prql.py
+++ b/sqlglot/dialects/prql.py
@@ -7,7 +7,13 @@ from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType
+def _select_all(table: exp.Expression) -> t.Optional[exp.Select]:
+ return exp.select("*").from_(table, copy=False) if table else None
+
+
class PRQL(Dialect):
+ DPIPE_IS_STRING_CONCAT = False
+
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
QUOTES = ["'", '"']
@@ -26,10 +32,27 @@ class PRQL(Dialect):
}
class Parser(parser.Parser):
+ CONJUNCTION = {
+ **parser.Parser.CONJUNCTION,
+ TokenType.DAMP: exp.And,
+ TokenType.DPIPE: exp.Or,
+ }
+
TRANSFORM_PARSERS = {
"DERIVE": lambda self, query: self._parse_selection(query),
"SELECT": lambda self, query: self._parse_selection(query, append=False),
"TAKE": lambda self, query: self._parse_take(query),
+ "FILTER": lambda self, query: query.where(self._parse_conjunction()),
+ "APPEND": lambda self, query: query.union(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "REMOVE": lambda self, query: query.except_(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "INTERSECT": lambda self, query: query.intersect(
+ _select_all(self._parse_table()), distinct=False, copy=False
+ ),
+ "SORT": lambda self, query: self._parse_order_by(query),
}
def _parse_statement(self) -> t.Optional[exp.Expression]:
@@ -81,6 +104,24 @@ class PRQL(Dialect):
num = self._parse_number() # TODO: TAKE for ranges a..b
return query.limit(num) if num else None
+ def _parse_ordered(
+ self, parse_method: t.Optional[t.Callable] = None
+ ) -> t.Optional[exp.Ordered]:
+ asc = self._match(TokenType.PLUS)
+ desc = self._match(TokenType.DASH) or (asc and False)
+ term = term = super()._parse_ordered(parse_method=parse_method)
+ if term and desc:
+ term.set("desc", True)
+ term.set("nulls_first", False)
+ return term
+
+ def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
+ l_brace = self._match(TokenType.L_BRACE)
+ expressions = self._parse_csv(self._parse_ordered)
+ if l_brace and not self._match(TokenType.R_BRACE):
+ self.raise_error("Expecting }")
+ return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)
+
def _parse_expression(self) -> t.Optional[exp.Expression]:
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)