From 66af5c6fc22f6f11e9ea807b274e011a6f64efb7 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 19 Mar 2023 11:22:09 +0100 Subject: Merging upstream version 11.4.1. Signed-off-by: Daniel Baumann --- tests/dialects/test_dialect.py | 78 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 4 deletions(-) (limited to 'tests/dialects/test_dialect.py') diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 21f6be6..6214c43 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -519,7 +519,7 @@ class TestDialect(Validator): "duckdb": "x + INTERVAL 1 day", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "postgres": "x + INTERVAL '1' 'day'", + "postgres": "x + INTERVAL '1' day", "presto": "DATE_ADD('day', 1, x)", "snowflake": "DATEADD(day, 1, x)", "spark": "DATE_ADD(x, 1)", @@ -543,11 +543,48 @@ class TestDialect(Validator): ) self.validate_all( "DATE_TRUNC('day', x)", + read={ + "bigquery": "DATE_TRUNC(x, day)", + "duckdb": "DATE_TRUNC('day', x)", + "spark": "TRUNC(x, 'day')", + }, write={ + "bigquery": "DATE_TRUNC(x, day)", + "duckdb": "DATE_TRUNC('day', x)", "mysql": "DATE(x)", + "presto": "DATE_TRUNC('day', x)", + "postgres": "DATE_TRUNC('day', x)", "snowflake": "DATE_TRUNC('day', x)", + "starrocks": "DATE_TRUNC('day', x)", + "spark": "TRUNC(x, 'day')", + }, + ) + self.validate_all( + "TIMESTAMP_TRUNC(x, day)", + read={ + "bigquery": "TIMESTAMP_TRUNC(x, day)", + "presto": "DATE_TRUNC('day', x)", + "postgres": "DATE_TRUNC('day', x)", + "snowflake": "DATE_TRUNC('day', x)", + "starrocks": "DATE_TRUNC('day', x)", + "spark": "DATE_TRUNC('day', x)", + }, + ) + self.validate_all( + "DATE_TRUNC('day', CAST(x AS DATE))", + read={ + "presto": "DATE_TRUNC('day', x::DATE)", + "snowflake": "DATE_TRUNC('day', x::DATE)", }, ) + self.validate_all( + "TIMESTAMP_TRUNC(CAST(x AS DATE), day)", + read={ + "postgres": "DATE_TRUNC('day', x::DATE)", + "starrocks": "DATE_TRUNC('day', x::DATE)", + }, + ) + self.validate_all( "DATE_TRUNC('week', x)", write={ @@ -582,8 +619,6 @@ class TestDialect(Validator): "DATE_TRUNC('year', x)", read={ "bigquery": "DATE_TRUNC(x, year)", - "snowflake": "DATE_TRUNC(year, x)", - "starrocks": "DATE_TRUNC('year', x)", "spark": "TRUNC(x, 'year')", }, write={ @@ -599,7 +634,10 @@ class TestDialect(Validator): "TIMESTAMP_TRUNC(x, year)", read={ "bigquery": "TIMESTAMP_TRUNC(x, year)", + "postgres": "DATE_TRUNC(year, x)", "spark": "DATE_TRUNC('year', x)", + "snowflake": "DATE_TRUNC(year, x)", + "starrocks": "DATE_TRUNC('year', x)", }, write={ "bigquery": "TIMESTAMP_TRUNC(x, year)", @@ -752,7 +790,6 @@ class TestDialect(Validator): "trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", - "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", }, @@ -1455,3 +1492,36 @@ SELECT "postgres": "SUBSTRING('123456' FROM 2 FOR 3)", }, ) + + def test_count_if(self): + self.validate_identity("COUNT_IF(DISTINCT cond)") + + self.validate_all( + "SELECT COUNT_IF(cond) FILTER", write={"": "SELECT COUNT_IF(cond) AS FILTER"} + ) + self.validate_all( + "SELECT COUNT_IF(col % 2 = 0) FROM foo", + write={ + "": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "presto": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "snowflake": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + }, + ) + self.validate_all( + "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + read={ + "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + }, + write={ + "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "presto": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FILTER(WHERE col < 1000) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + }, + ) -- cgit v1.2.3