summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_presto.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_presto.py')
-rw-r--r--tests/dialects/test_presto.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 15962cc..1f5953c 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -438,6 +438,36 @@ class TestPresto(Validator):
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"})
self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"})
self.validate_all(
+ "WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
+ read={
+ "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t",
+ },
+ )
+ self.validate_all(
+ "WITH RECURSIVE t(n, k) AS (SELECT 1 AS n, 2 AS k) SELECT SUM(n) FROM t",
+ read={
+ "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, 2 as k) SELECT SUM(n) FROM t",
+ },
+ )
+ self.validate_all(
+ "WITH RECURSIVE t1(n) AS (SELECT 1 AS n), t2(n) AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2",
+ read={
+ "postgres": "WITH RECURSIVE t1 AS (SELECT 1 AS n), t2 AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2",
+ },
+ )
+ self.validate_all(
+ "WITH RECURSIVE t(n, _c_0) AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t",
+ read={
+ "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t",
+ },
+ )
+ self.validate_all(
+ 'WITH RECURSIVE t(n, "1") AS (SELECT n, 1 FROM tbl) SELECT * FROM t',
+ read={
+ "postgres": "WITH RECURSIVE t AS (SELECT n, 1 FROM tbl) SELECT * FROM t",
+ },
+ )
+ self.validate_all(
"SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)",
write={
"presto": "SELECT JSON_OBJECT('key1': 1, 'key2': TRUE)",
@@ -757,3 +787,27 @@ class TestPresto(Validator):
"SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)",
read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"},
)
+
+ def test_match_recognize(self):
+ self.validate_identity(
+ """SELECT
+ *
+FROM orders
+MATCH_RECOGNIZE (
+ PARTITION BY custkey
+ ORDER BY
+ orderdate
+ MEASURES
+ A.totalprice AS starting_price,
+ LAST(B.totalprice) AS bottom_price,
+ LAST(C.totalprice) AS top_price
+ ONE ROW PER MATCH
+ AFTER MATCH SKIP PAST LAST ROW
+ PATTERN (A B+ C+ D+)
+ DEFINE
+ B AS totalprice < PREV(totalprice),
+ C AS totalprice > PREV(totalprice) AND totalprice <= A.totalprice,
+ D AS totalprice > PREV(totalprice)
+)""",
+ pretty=True,
+ )