summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-10-04 12:14:40 +0000
commitd7f0758e21b5111b5327f3839c5c9f49a04d272b (patch)
treea425f4ebcc159d6bd9443fe4e0e2f9eb20151027 /tests/dialects/test_dialect.py
parentAdding upstream version 18.7.0. (diff)
downloadsqlglot-d7f0758e21b5111b5327f3839c5c9f49a04d272b.tar.xz
sqlglot-d7f0758e21b5111b5327f3839c5c9f49a04d272b.zip
Adding upstream version 18.11.2.upstream/18.11.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--tests/dialects/test_dialect.py77
1 files changed, 72 insertions, 5 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 3e0ffd5..91eba17 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -5,6 +5,7 @@ from sqlglot import (
Dialects,
ErrorLevel,
ParseError,
+ TokenError,
UnsupportedError,
parse_one,
)
@@ -308,6 +309,44 @@ class TestDialect(Validator):
read={"postgres": "INET '127.0.0.1/32'"},
)
+ def test_heredoc_strings(self):
+ for dialect in ("clickhouse", "postgres", "redshift"):
+ # Invalid matching tag
+ with self.assertRaises(TokenError):
+ parse_one("SELECT $tag1$invalid heredoc string$tag2$", dialect=dialect)
+
+ # Unmatched tag
+ with self.assertRaises(TokenError):
+ parse_one("SELECT $tag1$invalid heredoc string", dialect=dialect)
+
+ # Without tag
+ self.validate_all(
+ "SELECT 'this is a heredoc string'",
+ read={
+ dialect: "SELECT $$this is a heredoc string$$",
+ },
+ )
+ self.validate_all(
+ "SELECT ''",
+ read={
+ dialect: "SELECT $$$$",
+ },
+ )
+
+ # With tag
+ self.validate_all(
+ "SELECT 'this is also a heredoc string'",
+ read={
+ dialect: "SELECT $foo$this is also a heredoc string$foo$",
+ },
+ )
+ self.validate_all(
+ "SELECT ''",
+ read={
+ dialect: "SELECT $foo$$foo$",
+ },
+ )
+
def test_decode(self):
self.validate_identity("DECODE(bin, charset)")
@@ -568,6 +607,7 @@ class TestDialect(Validator):
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"snowflake": "CAST(x AS DATE)",
"doris": "TO_DATE(x)",
+ "mysql": "DATE(x)",
},
)
self.validate_all(
@@ -648,9 +688,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(x, 1, 'DAY')",
read={
- "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"snowflake": "DATEADD('DAY', 1, x)",
- "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -842,6 +880,7 @@ class TestDialect(Validator):
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, CAST(CAST('2021-02-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2021-02-01', 1)",
+ "mysql": "DATE_ADD('2021-02-01', INTERVAL 1 DAY)",
},
)
self.validate_all(
@@ -897,10 +936,7 @@ class TestDialect(Validator):
"bigquery",
"drill",
"duckdb",
- "mysql",
"presto",
- "starrocks",
- "doris",
)
},
write={
@@ -913,8 +949,25 @@ class TestDialect(Validator):
"presto",
"hive",
"spark",
+ )
+ },
+ )
+ self.validate_all(
+ f"{unit}(TS_OR_DS_TO_DATE(x))",
+ read={
+ dialect: f"{unit}(x)"
+ for dialect in (
+ "mysql",
+ "doris",
"starrocks",
+ )
+ },
+ write={
+ dialect: f"{unit}(x)"
+ for dialect in (
+ "mysql",
"doris",
+ "starrocks",
)
},
)
@@ -1790,3 +1843,17 @@ SELECT
with self.assertRaises(ParseError):
parse_one("CAST(x AS some_udt)", read="bigquery")
+
+ def test_qualify(self):
+ self.validate_all(
+ "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ write={
+ "duckdb": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ "snowflake": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ "clickhouse": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "mysql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "oracle": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1",
+ "postgres": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ },
+ )