summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--tests/dialects/test_dialect.py78
1 files changed, 74 insertions, 4 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 21f6be6..6214c43 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -519,7 +519,7 @@ class TestDialect(Validator):
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
- "postgres": "x + INTERVAL '1' 'day'",
+ "postgres": "x + INTERVAL '1' day",
"presto": "DATE_ADD('day', 1, x)",
"snowflake": "DATEADD(day, 1, x)",
"spark": "DATE_ADD(x, 1)",
@@ -543,12 +543,49 @@ class TestDialect(Validator):
)
self.validate_all(
"DATE_TRUNC('day', x)",
+ read={
+ "bigquery": "DATE_TRUNC(x, day)",
+ "duckdb": "DATE_TRUNC('day', x)",
+ "spark": "TRUNC(x, 'day')",
+ },
write={
+ "bigquery": "DATE_TRUNC(x, day)",
+ "duckdb": "DATE_TRUNC('day', x)",
"mysql": "DATE(x)",
+ "presto": "DATE_TRUNC('day', x)",
+ "postgres": "DATE_TRUNC('day', x)",
"snowflake": "DATE_TRUNC('day', x)",
+ "starrocks": "DATE_TRUNC('day', x)",
+ "spark": "TRUNC(x, 'day')",
+ },
+ )
+ self.validate_all(
+ "TIMESTAMP_TRUNC(x, day)",
+ read={
+ "bigquery": "TIMESTAMP_TRUNC(x, day)",
+ "presto": "DATE_TRUNC('day', x)",
+ "postgres": "DATE_TRUNC('day', x)",
+ "snowflake": "DATE_TRUNC('day', x)",
+ "starrocks": "DATE_TRUNC('day', x)",
+ "spark": "DATE_TRUNC('day', x)",
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC('day', CAST(x AS DATE))",
+ read={
+ "presto": "DATE_TRUNC('day', x::DATE)",
+ "snowflake": "DATE_TRUNC('day', x::DATE)",
},
)
self.validate_all(
+ "TIMESTAMP_TRUNC(CAST(x AS DATE), day)",
+ read={
+ "postgres": "DATE_TRUNC('day', x::DATE)",
+ "starrocks": "DATE_TRUNC('day', x::DATE)",
+ },
+ )
+
+ self.validate_all(
"DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
@@ -582,8 +619,6 @@ class TestDialect(Validator):
"DATE_TRUNC('year', x)",
read={
"bigquery": "DATE_TRUNC(x, year)",
- "snowflake": "DATE_TRUNC(year, x)",
- "starrocks": "DATE_TRUNC('year', x)",
"spark": "TRUNC(x, 'year')",
},
write={
@@ -599,7 +634,10 @@ class TestDialect(Validator):
"TIMESTAMP_TRUNC(x, year)",
read={
"bigquery": "TIMESTAMP_TRUNC(x, year)",
+ "postgres": "DATE_TRUNC(year, x)",
"spark": "DATE_TRUNC('year', x)",
+ "snowflake": "DATE_TRUNC(year, x)",
+ "starrocks": "DATE_TRUNC('year', x)",
},
write={
"bigquery": "TIMESTAMP_TRUNC(x, year)",
@@ -752,7 +790,6 @@ class TestDialect(Validator):
"trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
- "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)",
"presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)",
},
@@ -1455,3 +1492,36 @@ SELECT
"postgres": "SUBSTRING('123456' FROM 2 FOR 3)",
},
)
+
+ def test_count_if(self):
+ self.validate_identity("COUNT_IF(DISTINCT cond)")
+
+ self.validate_all(
+ "SELECT COUNT_IF(cond) FILTER", write={"": "SELECT COUNT_IF(cond) AS FILTER"}
+ )
+ self.validate_all(
+ "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ write={
+ "": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ "databricks": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ "presto": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ "snowflake": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo",
+ "tsql": "SELECT COUNT_IF(col % 2 = 0) FROM foo",
+ },
+ )
+ self.validate_all(
+ "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ read={
+ "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ },
+ write={
+ "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ "presto": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FILTER(WHERE col < 1000) FROM foo",
+ "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
+ },
+ )