path: root/tests/dialects/test_presto.py
Diffstat:
 -rw-r--r--  tests/dialects/test_presto.py | 138
 1 file changed, 116 insertions(+), 22 deletions(-)
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index f1bbcc1..4c10a45 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -14,6 +14,23 @@ class TestPresto(Validator):
self.validate_identity("CAST(x AS HYPERLOGLOG)")
self.validate_all(
+ "CAST(x AS BOOLEAN)",
+ read={
+ "tsql": "CAST(x AS BIT)",
+ },
+ write={
+ "presto": "CAST(x AS BOOLEAN)",
+ "tsql": "CAST(x AS BIT)",
+ },
+ )
+ self.validate_all(
+ "SELECT FROM_ISO8601_TIMESTAMP('2020-05-11T11:15:05')",
+ write={
+ "duckdb": "SELECT CAST('2020-05-11T11:15:05' AS TIMESTAMPTZ)",
+ "presto": "SELECT FROM_ISO8601_TIMESTAMP('2020-05-11T11:15:05')",
+ },
+ )
+ self.validate_all(
"CAST(x AS INTERVAL YEAR TO MONTH)",
write={
"oracle": "CAST(x AS INTERVAL YEAR TO MONTH)",
@@ -151,8 +168,8 @@ class TestPresto(Validator):
             write={
                 "duckdb": "STR_SPLIT(x, 'a.')",
                 "presto": "SPLIT(x, 'a.')",
-                "hive": "SPLIT(x, CONCAT('\\\\Q', 'a.'))",
-                "spark": "SPLIT(x, CONCAT('\\\\Q', 'a.'))",
+                "hive": "SPLIT(x, CONCAT('\\\\Q', 'a.', '\\\\E'))",
+                "spark": "SPLIT(x, CONCAT('\\\\Q', 'a.', '\\\\E'))",
             },
         )
         self.validate_all(
@@ -269,10 +286,19 @@ class TestPresto(Validator):
         self.validate_all(
             "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
             write={
-                "duckdb": "STRPTIME(SUBSTR(x, 1, 10), '%Y-%m-%d')",
-                "presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')",
-                "hive": "CAST(SUBSTR(x, 1, 10) AS TIMESTAMP)",
-                "spark": "TO_TIMESTAMP(SUBSTR(x, 1, 10), 'yyyy-MM-dd')",
+                "duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
+                "presto": "DATE_PARSE(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
+                "hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
+                "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
+            },
+        )
+        self.validate_all(
+            "DATE_PARSE(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
+            write={
+                "duckdb": "STRPTIME(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
+                "presto": "DATE_PARSE(SUBSTRING(x, 1, 10), '%Y-%m-%d')",
+                "hive": "CAST(SUBSTRING(x, 1, 10) AS TIMESTAMP)",
+                "spark": "TO_TIMESTAMP(SUBSTRING(x, 1, 10), 'yyyy-MM-dd')",
             },
         )
         self.validate_all(
@@ -322,11 +348,20 @@ class TestPresto(Validator):
             },
         )
         self.validate_all(
-            "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')",
-            write={
+            "((DAY_OF_WEEK(CAST(TRY_CAST('2012-08-08 01:00:00' AS TIMESTAMP) AS DATE)) % 7) + 1)",
+            read={
                 "spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+            },
+        )
+        self.validate_all(
+            "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+            read={
+                "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+            },
+            write={
+                "spark": "((DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP)) % 7) + 1)",
                 "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
-                "duckdb": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+                "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
             },
         )
@@ -405,6 +440,27 @@ class TestPresto(Validator):
         )
         self.validate_identity("DATE_ADD('DAY', 1, y)")
+        self.validate_all(
+            "SELECT DATE_ADD('MINUTE', 30, col)",
+            write={
+                "presto": "SELECT DATE_ADD('MINUTE', 30, col)",
+                "trino": "SELECT DATE_ADD('MINUTE', 30, col)",
+            },
+        )
+
+        self.validate_identity("DATE_ADD('DAY', FLOOR(5), y)")
+        self.validate_identity(
+            """SELECT DATE_ADD('DAY', MOD(5, 2.5), y), DATE_ADD('DAY', CEIL(5.5), y)""",
+            """SELECT DATE_ADD('DAY', CAST(5 % 2.5 AS BIGINT), y), DATE_ADD('DAY', CAST(CEIL(5.5) AS BIGINT), y)""",
+        )
+
+        self.validate_all(
+            "DATE_ADD('MINUTE', CAST(FLOOR(CAST(EXTRACT(MINUTE FROM CURRENT_TIMESTAMP) AS DOUBLE) / NULLIF(30, 0)) * 30 AS BIGINT), col)",
+            read={
+                "spark": "TIMESTAMPADD(MINUTE, FLOOR(EXTRACT(MINUTE FROM CURRENT_TIMESTAMP)/30)*30, col)",
+            },
+        )
+
     def test_ddl(self):
         self.validate_all(
             "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1",
@@ -494,6 +550,9 @@ class TestPresto(Validator):
             },
         )
+        self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY DEFINER AS SELECT id FROM t""")
+        self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY INVOKER AS SELECT id FROM t""")
+
     def test_quotes(self):
         self.validate_all(
             "''''",
@@ -564,6 +623,7 @@ class TestPresto(Validator):
             self.validate_all(
                 f"{prefix}'Hello winter \\2603 !'",
                 write={
+                    "oracle": "U'Hello winter \\2603 !'",
                     "presto": "U&'Hello winter \\2603 !'",
                     "snowflake": "'Hello winter \\u2603 !'",
                     "spark": "'Hello winter \\u2603 !'",
@@ -572,6 +632,7 @@ class TestPresto(Validator):
             self.validate_all(
                 f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
                 write={
+                    "oracle": "U'Hello winter \\2603 !'",
                     "presto": "U&'Hello winter #2603 !' UESCAPE '#'",
                     "snowflake": "'Hello winter \\u2603 !'",
                     "spark": "'Hello winter \\u2603 !'",
@@ -579,6 +640,13 @@ class TestPresto(Validator):
             )
     def test_presto(self):
+        self.assertEqual(
+            exp.func("md5", exp.func("concat", exp.cast("x", "text"), exp.Literal.string("s"))).sql(
+                dialect="presto"
+            ),
+            "LOWER(TO_HEX(MD5(TO_UTF8(CONCAT(CAST(x AS VARCHAR), CAST('s' AS VARCHAR))))))",
+        )
+
         with self.assertLogs(helper_logger):
             self.validate_all(
                 "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table",
@@ -597,6 +665,7 @@ class TestPresto(Validator):
                 },
             )
+        self.validate_identity("SELECT a FROM t GROUP BY a, ROLLUP (b), ROLLUP (c), ROLLUP (d)")
         self.validate_identity("SELECT a FROM test TABLESAMPLE BERNOULLI (50)")
         self.validate_identity("SELECT a FROM test TABLESAMPLE SYSTEM (75)")
         self.validate_identity("string_agg(x, ',')", "ARRAY_JOIN(ARRAY_AGG(x), ',')")
@@ -678,9 +747,6 @@ class TestPresto(Validator):
         )
         self.validate_all(
             "SELECT ROW(1, 2)",
-            read={
-                "spark": "SELECT STRUCT(1, 2)",
-            },
             write={
                 "presto": "SELECT ROW(1, 2)",
                 "spark": "SELECT STRUCT(1, 2)",
@@ -799,12 +865,6 @@ class TestPresto(Validator):
             },
         )
         self.validate_all(
-            "SELECT a FROM t GROUP BY a, ROLLUP(b), ROLLUP(c), ROLLUP(d)",
-            write={
-                "presto": "SELECT a FROM t GROUP BY a, ROLLUP (b, c, d)",
-            },
-        )
-        self.validate_all(
             'SELECT a."b" FROM "foo"',
             write={
                 "duckdb": 'SELECT a."b" FROM "foo"',
@@ -925,8 +985,8 @@ class TestPresto(Validator):
             write={
                 "bigquery": "SELECT * FROM UNNEST(['7', '14'])",
                 "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x",
-                "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x",
-                "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x",
+                "hive": "SELECT * FROM EXPLODE(ARRAY('7', '14')) AS x",
+                "spark": "SELECT * FROM EXPLODE(ARRAY('7', '14')) AS x",
             },
         )
         self.validate_all(
@@ -934,8 +994,8 @@ class TestPresto(Validator):
             write={
                 "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS y",
                 "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)",
-                "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)",
-                "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)",
+                "hive": "SELECT * FROM EXPLODE(ARRAY('7', '14')) AS x(y)",
+                "spark": "SELECT * FROM EXPLODE(ARRAY('7', '14')) AS x(y)",
             },
         )
         self.validate_all(
@@ -993,6 +1053,25 @@ class TestPresto(Validator):
"spark": "SELECT REGEXP_EXTRACT(TO_JSON(FROM_JSON('[[1, 2, 3]]', SCHEMA_OF_JSON('[[1, 2, 3]]'))), '^.(.*).$', 1)",
},
)
+ self.validate_all(
+ "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ read={
+ "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')",
+ },
+ write={
+ "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')",
+ "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ },
+ )
def test_encode_decode(self):
self.validate_identity("FROM_UTF8(x, y)")
@@ -1190,3 +1269,18 @@ MATCH_RECOGNIZE (
"starrocks": "SIGN(x)",
},
)
+
+ def test_json_vs_row_extract(self):
+ for dialect in ("trino", "presto"):
+ s = parse_one('SELECT col:x:y."special string"', read="snowflake")
+
+ dialect_json_extract_setting = f"{dialect}, variant_extract_is_json_extract=True"
+ dialect_row_access_setting = f"{dialect}, variant_extract_is_json_extract=False"
+
+ # By default, Snowflake VARIANT will generate JSON_EXTRACT() in Presto/Trino
+ json_extract_result = """SELECT JSON_EXTRACT(col, '$.x.y["special string"]')"""
+ self.assertEqual(s.sql(dialect), json_extract_result)
+ self.assertEqual(s.sql(dialect_json_extract_setting), json_extract_result)
+
+ # If the setting is overriden to False, then generate ROW access (dot notation)
+ self.assertEqual(s.sql(dialect_row_access_setting), 'SELECT col.x.y."special string"')
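The read/write expectations asserted above can also be reproduced outside the test harness. Below is a minimal sketch (not part of the diff) that uses sqlglot's public transpile() helper; the input and output strings are taken directly from the test expectations in the hunks above, and the transpile() call is assumed to be available as in the installed sqlglot package.

# Minimal sketch: reproduce two of the transpilations asserted in this diff.
import sqlglot

# read={"tsql": "CAST(x AS BIT)"} should parse as a BOOLEAN cast in Presto.
assert sqlglot.transpile("CAST(x AS BIT)", read="tsql", write="presto") == [
    "CAST(x AS BOOLEAN)"
]

# write={"duckdb": ...}: Presto's FROM_ISO8601_TIMESTAMP becomes a TIMESTAMPTZ cast in DuckDB.
assert sqlglot.transpile(
    "SELECT FROM_ISO8601_TIMESTAMP('2020-05-11T11:15:05')",
    read="presto",
    write="duckdb",
) == ["SELECT CAST('2020-05-11T11:15:05' AS TIMESTAMPTZ)"]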