diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:39 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-02-08 05:38:39 +0000 |
commit | aedf35026379f52d7e2b4c1f957691410a758089 (patch) | |
tree | 86540364259b66741173d2333387b78d6f9c31e2 /tests | |
parent | Adding upstream version 20.11.0. (diff) | |
download | sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.tar.xz sqlglot-aedf35026379f52d7e2b4c1f957691410a758089.zip |
Adding upstream version 21.0.1.upstream/21.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 137 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 211 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 38 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 623 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 44 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 32 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 15 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 3 | ||||
-rw-r--r-- | tests/test_executor.py | 27 | ||||
-rw-r--r-- | tests/test_expressions.py | 8 | ||||
-rw-r--r-- | tests/test_jsonpath.py | 39 | ||||
-rw-r--r-- | tests/test_parser.py | 36 | ||||
-rw-r--r-- | tests/test_serde.py | 3 |
20 files changed, 889 insertions, 441 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index c845441..24904b7 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -528,7 +528,7 @@ class TestFunctions(unittest.TestCase): col = SF.first(SF.col("cola")) self.assertEqual("FIRST(cola)", col.sql()) ignore_nulls = SF.first("cola", True) - self.assertEqual("FIRST(cola, TRUE)", ignore_nulls.sql()) + self.assertEqual("FIRST(cola) IGNORE NULLS", ignore_nulls.sql()) def test_grouping_id(self): col_str = SF.grouping_id("cola", "colb") @@ -562,7 +562,7 @@ class TestFunctions(unittest.TestCase): col = SF.last(SF.col("cola")) self.assertEqual("LAST(cola)", col.sql()) ignore_nulls = SF.last("cola", True) - self.assertEqual("LAST(cola, TRUE)", ignore_nulls.sql()) + self.assertEqual("LAST(cola) IGNORE NULLS", ignore_nulls.sql()) def test_monotonically_increasing_id(self): col = SF.monotonically_increasing_id() @@ -713,8 +713,10 @@ class TestFunctions(unittest.TestCase): self.assertEqual("NTH_VALUE(cola, 3)", col.sql()) col_no_offset = SF.nth_value("cola") self.assertEqual("NTH_VALUE(cola)", col_no_offset.sql()) - with self.assertRaises(NotImplementedError): - SF.nth_value("cola", ignoreNulls=True) + + self.assertEqual( + "NTH_VALUE(cola) IGNORE NULLS", SF.nth_value("cola", ignoreNulls=True).sql() + ) def test_ntile(self): col = SF.ntile(2) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 2c8ac7b..340630c 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -18,6 +18,64 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_all( + "SELECT SUM(x IGNORE NULLS) AS x", + read={ + "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", + "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", + "postgres": "SELECT SUM(x) IGNORE NULLS AS x", + "spark": "SELECT SUM(x) IGNORE NULLS AS x", + "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", + }, + write={ + "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", + "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", + "postgres": "SELECT SUM(x) IGNORE NULLS AS x", + "spark": "SELECT SUM(x) IGNORE NULLS AS x", + "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", + }, + ) + self.validate_all( + "SELECT SUM(x RESPECT NULLS) AS x", + read={ + "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", + "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", + "postgres": "SELECT SUM(x) RESPECT NULLS AS x", + "spark": "SELECT SUM(x) RESPECT NULLS AS x", + "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", + }, + write={ + "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", + "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", + "postgres": "SELECT SUM(x) RESPECT NULLS AS x", + "spark": "SELECT SUM(x) RESPECT NULLS AS x", + "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", + }, + ) + self.validate_all( + "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", + write={ + "duckdb": "SELECT QUANTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", + "spark": "SELECT PERCENTILE_CONT(x, 0.5) RESPECT NULLS OVER ()", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", + write={ + "duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 10 IGNORE NULLS) AS x", + "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 10) IGNORE NULLS AS x", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", + write={ + "duckdb": "SELECT ARRAY_AGG(DISTINCT x ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10 IGNORE NULLS) AS x", + "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 1, 10) IGNORE NULLS AS x", + }, + ) + self.validate_identity("SELECT COUNT(x RESPECT NULLS)") + self.validate_identity("SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x") + self.validate_identity( "create or replace view test (tenant_id OPTIONS(description='Test description on table creation')) select 1 as tenant_id, 1 as customer_id;", "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", @@ -358,11 +416,26 @@ class TestBigQuery(Validator): "SELECT TIMESTAMP_DIFF(TIMESTAMP_SECONDS(60), TIMESTAMP_SECONDS(0), minute)", write={ "bigquery": "SELECT TIMESTAMP_DIFF(TIMESTAMP_SECONDS(60), TIMESTAMP_SECONDS(0), MINUTE)", + "databricks": "SELECT TIMESTAMPDIFF(MINUTE, CAST(FROM_UNIXTIME(0) AS TIMESTAMP), CAST(FROM_UNIXTIME(60) AS TIMESTAMP))", "duckdb": "SELECT DATE_DIFF('MINUTE', TO_TIMESTAMP(0), TO_TIMESTAMP(60))", "snowflake": "SELECT TIMESTAMPDIFF(MINUTE, TO_TIMESTAMP(0), TO_TIMESTAMP(60))", }, ) self.validate_all( + "TIMESTAMP_DIFF(a, b, MONTH)", + read={ + "bigquery": "TIMESTAMP_DIFF(a, b, month)", + "databricks": "TIMESTAMPDIFF(month, b, a)", + "mysql": "TIMESTAMPDIFF(month, b, a)", + }, + write={ + "databricks": "TIMESTAMPDIFF(MONTH, b, a)", + "mysql": "TIMESTAMPDIFF(MONTH, b, a)", + "snowflake": "TIMESTAMPDIFF(MONTH, b, a)", + }, + ) + + self.validate_all( "SELECT TIMESTAMP_MICROS(x)", read={ "duckdb": "SELECT MAKE_TIMESTAMP(x)", @@ -419,34 +492,42 @@ class TestBigQuery(Validator): "snowflake": "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d", }, ) - self.validate_all( - "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", - write={ - "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", - "databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')", - }, - ), - self.validate_all( - "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - write={ - "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')", - }, - ), - self.validate_all( - "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - write={ - "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')", - }, - ), - self.validate_all( - "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", - write={ - "bigquery": "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", - "databricks": "SELECT DATE_TRUNC('HOUR', '2023-01-01T01:01:01')", - }, - ), + ( + self.validate_all( + "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", + "databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')", + }, + ), + ) + ( + self.validate_all( + "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')", + }, + ), + ) + ( + self.validate_all( + "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')", + }, + ), + ) + ( + self.validate_all( + "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", + write={ + "bigquery": "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", + "databricks": "SELECT DATE_TRUNC('HOUR', '2023-01-01T01:01:01')", + }, + ), + ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index f36af41..d256fc5 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -74,6 +74,10 @@ class TestClickhouse(Validator): self.validate_identity("CAST(x AS DATETIME)") self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)") self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src") + self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""") + self.validate_identity( + """SELECT JSONExtractString('{"a": "hello", "b": [-100, 200.0, 300]}', 'a')""" + ) self.validate_identity( "ATTACH DATABASE DEFAULT ENGINE = ORDINARY", check_command_warning=True ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 22e7d49..fd9dbdb 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -378,6 +378,31 @@ class TestDialect(Validator): read={"postgres": "INET '127.0.0.1/32'"}, ) + def test_ddl(self): + self.validate_all( + "CREATE TABLE a LIKE b", + write={ + "": "CREATE TABLE a LIKE b", + "bigquery": "CREATE TABLE a LIKE b", + "clickhouse": "CREATE TABLE a AS b", + "databricks": "CREATE TABLE a LIKE b", + "doris": "CREATE TABLE a LIKE b", + "drill": "CREATE TABLE a AS SELECT * FROM b LIMIT 0", + "duckdb": "CREATE TABLE a AS SELECT * FROM b LIMIT 0", + "hive": "CREATE TABLE a LIKE b", + "mysql": "CREATE TABLE a LIKE b", + "oracle": "CREATE TABLE a LIKE b", + "postgres": "CREATE TABLE a (LIKE b)", + "presto": "CREATE TABLE a (LIKE b)", + "redshift": "CREATE TABLE a (LIKE b)", + "snowflake": "CREATE TABLE a LIKE b", + "spark": "CREATE TABLE a LIKE b", + "sqlite": "CREATE TABLE a AS SELECT * FROM b LIMIT 0", + "trino": "CREATE TABLE a (LIKE b)", + "tsql": "SELECT TOP 0 * INTO a FROM b AS temp", + }, + ) + def test_heredoc_strings(self): for dialect in ("clickhouse", "postgres", "redshift"): # Invalid matching tag @@ -1097,61 +1122,173 @@ class TestDialect(Validator): def test_json(self): self.validate_all( - "JSON_EXTRACT(x, 'y')", - read={ - "mysql": "JSON_EXTRACT(x, 'y')", - "postgres": "x->'y'", - "presto": "JSON_EXTRACT(x, 'y')", - "starrocks": "x -> 'y'", - "doris": "x -> 'y'", - }, + """JSON_EXTRACT(x, '$["a b"]')""", write={ - "mysql": "JSON_EXTRACT(x, 'y')", - "oracle": "JSON_EXTRACT(x, 'y')", - "postgres": "x -> 'y'", - "presto": "JSON_EXTRACT(x, 'y')", - "starrocks": "x -> 'y'", - "doris": "x -> 'y'", + "": """JSON_EXTRACT(x, '$["a b"]')""", + "bigquery": """JSON_EXTRACT(x, '$[\\'a b\\']')""", + "clickhouse": "JSONExtractString(x, 'a b')", + "duckdb": """x -> '$."a b"'""", + "mysql": """JSON_EXTRACT(x, '$."a b"')""", + "postgres": "JSON_EXTRACT_PATH(x, 'a b')", + "presto": """JSON_EXTRACT(x, '$["a b"]')""", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'a b')", + "snowflake": """GET_PATH(x, '["a b"]')""", + "spark": """GET_JSON_OBJECT(x, '$[\\'a b\\']')""", + "sqlite": """x -> '$."a b"'""", + "trino": """JSON_EXTRACT(x, '$["a b"]')""", + "tsql": """ISNULL(JSON_QUERY(x, '$."a b"'), JSON_VALUE(x, '$."a b"'))""", }, ) self.validate_all( - "JSON_EXTRACT_SCALAR(x, 'y')", + "JSON_EXTRACT(x, '$.y')", read={ - "postgres": "x ->> 'y'", - "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + "bigquery": "JSON_EXTRACT(x, '$.y')", + "duckdb": "x -> 'y'", + "doris": "x -> '$.y'", + "mysql": "JSON_EXTRACT(x, '$.y')", + "postgres": "x->'y'", + "presto": "JSON_EXTRACT(x, '$.y')", + "snowflake": "GET_PATH(x, 'y')", + "sqlite": "x -> '$.y'", + "starrocks": "x -> '$.y'", }, write={ - "postgres": "x ->> 'y'", - "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + "bigquery": "JSON_EXTRACT(x, '$.y')", + "clickhouse": "JSONExtractString(x, 'y')", + "doris": "x -> '$.y'", + "duckdb": "x -> '$.y'", + "mysql": "JSON_EXTRACT(x, '$.y')", + "oracle": "JSON_EXTRACT(x, '$.y')", + "postgres": "JSON_EXTRACT_PATH(x, 'y')", + "presto": "JSON_EXTRACT(x, '$.y')", + "snowflake": "GET_PATH(x, 'y')", + "spark": "GET_JSON_OBJECT(x, '$.y')", + "sqlite": "x -> '$.y'", + "starrocks": "x -> '$.y'", + "tsql": "ISNULL(JSON_QUERY(x, '$.y'), JSON_VALUE(x, '$.y'))", }, ) self.validate_all( - "JSON_EXTRACT_SCALAR(stream_data, '$.data.results')", + "JSON_EXTRACT_SCALAR(x, '$.y')", read={ - "hive": "GET_JSON_OBJECT(stream_data, '$.data.results')", - "mysql": "stream_data ->> '$.data.results'", + "bigquery": "JSON_EXTRACT_SCALAR(x, '$.y')", + "clickhouse": "JSONExtractString(x, 'y')", + "duckdb": "x ->> 'y'", + "postgres": "x ->> 'y'", + "presto": "JSON_EXTRACT_SCALAR(x, '$.y')", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "spark": "GET_JSON_OBJECT(x, '$.y')", + "snowflake": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "sqlite": "x ->> '$.y'", }, write={ - "hive": "GET_JSON_OBJECT(stream_data, '$.data.results')", - "mysql": "stream_data ->> '$.data.results'", + "bigquery": "JSON_EXTRACT_SCALAR(x, '$.y')", + "clickhouse": "JSONExtractString(x, 'y')", + "duckdb": "x ->> '$.y'", + "postgres": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "presto": "JSON_EXTRACT_SCALAR(x, '$.y')", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "snowflake": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "spark": "GET_JSON_OBJECT(x, '$.y')", + "sqlite": "x ->> '$.y'", + "tsql": "ISNULL(JSON_QUERY(x, '$.y'), JSON_VALUE(x, '$.y'))", }, ) self.validate_all( - "JSONB_EXTRACT(x, 'y')", + "JSON_EXTRACT(x, '$.y[0].z')", read={ - "postgres": "x#>'y'", - }, - write={ - "postgres": "x #> 'y'", - }, - ) - self.validate_all( - "JSONB_EXTRACT_SCALAR(x, 'y')", + "bigquery": "JSON_EXTRACT(x, '$.y[0].z')", + "duckdb": "x -> '$.y[0].z'", + "doris": "x -> '$.y[0].z'", + "mysql": "JSON_EXTRACT(x, '$.y[0].z')", + "presto": "JSON_EXTRACT(x, '$.y[0].z')", + "snowflake": "GET_PATH(x, 'y[0].z')", + "sqlite": "x -> '$.y[0].z'", + "starrocks": "x -> '$.y[0].z'", + }, + write={ + "bigquery": "JSON_EXTRACT(x, '$.y[0].z')", + "clickhouse": "JSONExtractString(x, 'y', 1, 'z')", + "doris": "x -> '$.y[0].z'", + "duckdb": "x -> '$.y[0].z'", + "mysql": "JSON_EXTRACT(x, '$.y[0].z')", + "oracle": "JSON_EXTRACT(x, '$.y[0].z')", + "postgres": "JSON_EXTRACT_PATH(x, 'y', '0', 'z')", + "presto": "JSON_EXTRACT(x, '$.y[0].z')", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'y', '0', 'z')", + "snowflake": "GET_PATH(x, 'y[0].z')", + "spark": "GET_JSON_OBJECT(x, '$.y[0].z')", + "sqlite": "x -> '$.y[0].z'", + "starrocks": "x -> '$.y[0].z'", + "tsql": "ISNULL(JSON_QUERY(x, '$.y[0].z'), JSON_VALUE(x, '$.y[0].z'))", + }, + ) + self.validate_all( + "JSON_EXTRACT_SCALAR(x, '$.y[0].z')", read={ - "postgres": "x#>>'y'", - }, - write={ - "postgres": "x #>> 'y'", + "bigquery": "JSON_EXTRACT_SCALAR(x, '$.y[0].z')", + "clickhouse": "JSONExtractString(x, 'y', 1, 'z')", + "duckdb": "x ->> '$.y[0].z'", + "presto": "JSON_EXTRACT_SCALAR(x, '$.y[0].z')", + "snowflake": "JSON_EXTRACT_PATH_TEXT(x, 'y[0].z')", + "spark": 'GET_JSON_OBJECT(x, "$.y[0].z")', + "sqlite": "x ->> '$.y[0].z'", + }, + write={ + "bigquery": "JSON_EXTRACT_SCALAR(x, '$.y[0].z')", + "clickhouse": "JSONExtractString(x, 'y', 1, 'z')", + "duckdb": "x ->> '$.y[0].z'", + "postgres": "JSON_EXTRACT_PATH_TEXT(x, 'y', '0', 'z')", + "presto": "JSON_EXTRACT_SCALAR(x, '$.y[0].z')", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'y', '0', 'z')", + "snowflake": "JSON_EXTRACT_PATH_TEXT(x, 'y[0].z')", + "spark": "GET_JSON_OBJECT(x, '$.y[0].z')", + "sqlite": "x ->> '$.y[0].z'", + "tsql": "ISNULL(JSON_QUERY(x, '$.y[0].z'), JSON_VALUE(x, '$.y[0].z'))", + }, + ) + self.validate_all( + "JSON_EXTRACT(x, '$.y[*]')", + write={ + "bigquery": UnsupportedError, + "clickhouse": UnsupportedError, + "duckdb": "x -> '$.y[*]'", + "mysql": "JSON_EXTRACT(x, '$.y[*]')", + "postgres": UnsupportedError, + "presto": "JSON_EXTRACT(x, '$.y[*]')", + "redshift": UnsupportedError, + "snowflake": UnsupportedError, + "spark": "GET_JSON_OBJECT(x, '$.y[*]')", + "sqlite": UnsupportedError, + "tsql": UnsupportedError, + }, + ) + self.validate_all( + "JSON_EXTRACT(x, '$.y[*]')", + write={ + "bigquery": "JSON_EXTRACT(x, '$.y')", + "clickhouse": "JSONExtractString(x, 'y')", + "postgres": "JSON_EXTRACT_PATH(x, 'y')", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'y')", + "snowflake": "GET_PATH(x, 'y')", + "sqlite": "x -> '$.y'", + "tsql": "ISNULL(JSON_QUERY(x, '$.y'), JSON_VALUE(x, '$.y'))", + }, + ) + self.validate_all( + "JSON_EXTRACT(x, '$.y.*')", + write={ + "bigquery": UnsupportedError, + "clickhouse": UnsupportedError, + "duckdb": "x -> '$.y.*'", + "mysql": "JSON_EXTRACT(x, '$.y.*')", + "postgres": UnsupportedError, + "presto": "JSON_EXTRACT(x, '$.y.*')", + "redshift": UnsupportedError, + "snowflake": UnsupportedError, + "spark": UnsupportedError, + "sqlite": UnsupportedError, + "tsql": UnsupportedError, }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f3b41b4..9c48f69 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -41,6 +41,7 @@ class TestDuckDB(Validator): ) self.validate_identity("SELECT 1 WHERE x > $1") self.validate_identity("SELECT 1 WHERE x > $name") + self.validate_identity("""SELECT '{"x": 1}' -> c FROM t""") self.assertEqual( parse_one("select * from t limit (select 5)").sql(dialect="duckdb"), @@ -89,18 +90,26 @@ class TestDuckDB(Validator): }, ) + self.validate_identity("""SELECT '{"duck": [1, 2, 3]}' -> '$.duck[#-1]'""") + self.validate_all( + """SELECT JSON_EXTRACT('{"duck": [1, 2, 3]}', '/duck/0')""", + write={ + "": """SELECT JSON_EXTRACT('{"duck": [1, 2, 3]}', '/duck/0')""", + "duckdb": """SELECT '{"duck": [1, 2, 3]}' -> '/duck/0'""", + }, + ) self.validate_all( """SELECT JSON('{"fruit":"banana"}') -> 'fruit'""", write={ - "duckdb": """SELECT JSON('{"fruit":"banana"}') -> 'fruit'""", - "snowflake": """SELECT PARSE_JSON('{"fruit":"banana"}')['fruit']""", + "duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""", + "snowflake": """SELECT GET_PATH(PARSE_JSON('{"fruit":"banana"}'), 'fruit')""", }, ) self.validate_all( """SELECT JSON('{"fruit": {"foo": "banana"}}') -> 'fruit' -> 'foo'""", write={ - "duckdb": """SELECT JSON('{"fruit": {"foo": "banana"}}') -> 'fruit' -> 'foo'""", - "snowflake": """SELECT PARSE_JSON('{"fruit": {"foo": "banana"}}')['fruit']['foo']""", + "duckdb": """SELECT JSON('{"fruit": {"foo": "banana"}}') -> '$.fruit' -> '$.foo'""", + "snowflake": """SELECT GET_PATH(GET_PATH(PARSE_JSON('{"fruit": {"foo": "banana"}}'), 'fruit'), 'foo')""", }, ) self.validate_all( @@ -199,6 +208,27 @@ class TestDuckDB(Validator): self.validate_identity("FROM x SELECT x UNION SELECT 1", "SELECT x FROM x UNION SELECT 1") self.validate_identity("FROM (FROM tbl)", "SELECT * FROM (SELECT * FROM tbl)") self.validate_identity("FROM tbl", "SELECT * FROM tbl") + self.validate_identity("x -> '$.family'") + self.validate_identity( + """SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""", + """SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""", + ) + self.validate_identity( + "JSON_EXTRACT(x, '$.family')", + "x -> '$.family'", + ) + self.validate_identity( + "JSON_EXTRACT_PATH(x, '$.family')", + "x -> '$.family'", + ) + self.validate_identity( + "JSON_EXTRACT_STRING(x, '$.family')", + "x ->> '$.family'", + ) + self.validate_identity( + "JSON_EXTRACT_PATH_TEXT(x, '$.family')", + "x ->> '$.family'", + ) self.validate_identity( "ATTACH DATABASE ':memory:' AS new_database", check_command_warning=True ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 3a3e49e..fd27a1e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -386,7 +386,6 @@ class TestMySQL(Validator): "snowflake": "SELECT 11", "spark": "SELECT 11", "sqlite": "SELECT 11", - "mysql": "SELECT b'1011'", "tableau": "SELECT 11", "teradata": "SELECT 11", "trino": "SELECT 11", @@ -591,6 +590,26 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')", + read={ + "sqlite": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')", + }, + write={ + "mysql": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')", + "sqlite": "SELECT '[10, 20, [30, 40]]' -> '$[1]'", + }, + ) + self.validate_all( + "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]', '$[0]')", + read={ + "sqlite": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]', '$[0]')", + }, + write={ + "mysql": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]', '$[0]')", + "sqlite": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]', '$[0]')", + }, + ) + self.validate_all( "SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION SELECT * FROM x RIGHT JOIN y ON x.id = y.id LIMIT 0", read={ "postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id LIMIT 0", @@ -790,6 +809,7 @@ COMMENT='客户账户表'""" ("CHARACTER SET", "CHARACTER SET"), ("COLLATION", "COLLATION"), ("DATABASES", "DATABASES"), + ("SCHEMAS", "DATABASES"), ("FUNCTION STATUS", "FUNCTION STATUS"), ("PROCEDURE STATUS", "PROCEDURE STATUS"), ("GLOBAL STATUS", "GLOBAL STATUS"), @@ -850,7 +870,7 @@ COMMENT='客户账户表'""" self.assertEqual(show.text("target"), "foo") def test_show_grants(self): - show = self.validate_identity(f"SHOW GRANTS FOR foo") + show = self.validate_identity("SHOW GRANTS FOR foo") self.assertIsInstance(show, exp.Show) self.assertEqual(show.name, "GRANTS") self.assertEqual(show.text("target"), "foo") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index dc00c85..9c4246e 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,10 +8,6 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): - self.validate_identity("SELECT CURRENT_USER") - self.validate_identity("CAST(1 AS DECIMAL) / CAST(2 AS DECIMAL) * -100") - self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True) - expr = parse_one( "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres" ) @@ -37,33 +33,6 @@ class TestPostgres(Validator): self.assertIsInstance(expr, exp.AlterTable) self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) - self.validate_identity( - "SELECT c.oid, n.nspname, c.relname " - "FROM pg_catalog.pg_class AS c " - "LEFT JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace " - "WHERE c.relname OPERATOR(pg_catalog.~) '^(courses)$' COLLATE pg_catalog.default AND " - "pg_catalog.PG_TABLE_IS_VISIBLE(c.oid) " - "ORDER BY 2, 3" - ) - self.validate_identity( - "SELECT ARRAY[]::INT[] AS foo", - "SELECT CAST(ARRAY[] AS INT[]) AS foo", - ) - self.validate_identity( - """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE CASCADE""" - ) - self.validate_identity( - """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE RESTRICT""" - ) - self.validate_identity( - "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY" - ) - self.validate_identity( - "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY AS kv_json" - ) - self.validate_identity( - "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY AS kv_json(a, b)" - ) self.validate_identity("SELECT * FROM t TABLESAMPLE SYSTEM (50) REPEATABLE (55)") self.validate_identity("x @@ y") self.validate_identity("CAST(x AS MONEY)") @@ -79,12 +48,15 @@ class TestPostgres(Validator): self.validate_identity("CAST(x AS TSTZMULTIRANGE)") self.validate_identity("CAST(x AS DATERANGE)") self.validate_identity("CAST(x AS DATEMULTIRANGE)") - self.validate_identity( - """LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 DAY' PRECEDING AND '1 month' FOLLOWING)""" - ) self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]") - self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]") + self.validate_all( + "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]", + write={ + "": "SELECT ARRAY_OVERLAPS(ARRAY(1, 2, 3), ARRAY(1, 2))", + "postgres": "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]", + }, + ) self.validate_identity("x$") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") @@ -103,6 +75,28 @@ class TestPostgres(Validator): self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") self.validate_identity("x ~ 'y'") self.validate_identity("x ~* 'y'") + self.validate_identity("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)") + self.validate_identity("CAST(1 AS DECIMAL) / CAST(2 AS DECIMAL) * -100") + self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True) + self.validate_identity("SELECT CURRENT_USER") + self.validate_identity( + """LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 DAY' PRECEDING AND '1 month' FOLLOWING)""" + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE CASCADE""" + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE RESTRICT""" + ) + self.validate_identity( + "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY" + ) + self.validate_identity( + "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY AS kv_json" + ) + self.validate_identity( + "SELECT * FROM JSON_ARRAY_ELEMENTS('[1,true, [2,false]]') WITH ORDINALITY AS kv_json(a, b)" + ) self.validate_identity( "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" ) @@ -127,7 +121,279 @@ class TestPostgres(Validator): self.validate_identity( "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" ) + self.validate_identity( + "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss" + ) + self.validate_identity( + "SELECT c.oid, n.nspname, c.relname " + "FROM pg_catalog.pg_class AS c " + "LEFT JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace " + "WHERE c.relname OPERATOR(pg_catalog.~) '^(courses)$' COLLATE pg_catalog.default AND " + "pg_catalog.PG_TABLE_IS_VISIBLE(c.oid) " + "ORDER BY 2, 3" + ) + self.validate_identity( + "SELECT ARRAY[]::INT[] AS foo", + "SELECT CAST(ARRAY[] AS INT[]) AS foo", + ) + self.validate_identity( + "SELECT DATE_PART('isodow'::varchar(6), current_date)", + "SELECT EXTRACT(CAST('isodow' AS VARCHAR(6)) FROM CURRENT_DATE)", + ) + self.validate_identity( + "END WORK AND NO CHAIN", + "COMMIT AND NO CHAIN", + ) + self.validate_identity( + "END AND CHAIN", + "COMMIT AND CHAIN", + ) + self.validate_identity( + """x ? 'x'""", + "x ? 'x'", + ) + self.validate_identity( + "SELECT $$a$$", + "SELECT 'a'", + ) + self.validate_identity( + "SELECT $$Dianne's horse$$", + "SELECT 'Dianne''s horse'", + ) + self.validate_identity( + "UPDATE MYTABLE T1 SET T1.COL = 13", + "UPDATE MYTABLE AS T1 SET T1.COL = 13", + ) + self.validate_identity( + "x !~ 'y'", + "NOT x ~ 'y'", + ) + self.validate_identity( + "x !~* 'y'", + "NOT x ~* 'y'", + ) + self.validate_identity( + "x ~~ 'y'", + "x LIKE 'y'", + ) + self.validate_identity( + "x ~~* 'y'", + "x ILIKE 'y'", + ) + self.validate_identity( + "x !~~ 'y'", + "NOT x LIKE 'y'", + ) + self.validate_identity( + "x !~~* 'y'", + "NOT x ILIKE 'y'", + ) + self.validate_identity( + "'45 days'::interval day", + "CAST('45 days' AS INTERVAL DAY)", + ) + self.validate_identity( + "'x' 'y' 'z'", + "CONCAT('x', 'y', 'z')", + ) + self.validate_identity( + "x::cstring", + "CAST(x AS CSTRING)", + ) + self.validate_identity( + "x::oid", + "CAST(x AS OID)", + ) + self.validate_identity( + "x::regclass", + "CAST(x AS REGCLASS)", + ) + self.validate_identity( + "x::regcollation", + "CAST(x AS REGCOLLATION)", + ) + self.validate_identity( + "x::regconfig", + "CAST(x AS REGCONFIG)", + ) + self.validate_identity( + "x::regdictionary", + "CAST(x AS REGDICTIONARY)", + ) + self.validate_identity( + "x::regnamespace", + "CAST(x AS REGNAMESPACE)", + ) + self.validate_identity( + "x::regoper", + "CAST(x AS REGOPER)", + ) + self.validate_identity( + "x::regoperator", + "CAST(x AS REGOPERATOR)", + ) + self.validate_identity( + "x::regproc", + "CAST(x AS REGPROC)", + ) + self.validate_identity( + "x::regprocedure", + "CAST(x AS REGPROCEDURE)", + ) + self.validate_identity( + "x::regrole", + "CAST(x AS REGROLE)", + ) + self.validate_identity( + "x::regtype", + "CAST(x AS REGTYPE)", + ) + self.validate_identity( + "123::CHARACTER VARYING", + "CAST(123 AS VARCHAR)", + ) + self.validate_identity( + "TO_TIMESTAMP(123::DOUBLE PRECISION)", + "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))", + ) + self.validate_identity( + "SELECT to_timestamp(123)::time without time zone", + "SELECT CAST(TO_TIMESTAMP(123) AS TIME)", + ) + self.validate_identity( + "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", + "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)", + ) + self.validate_identity( + "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)", + "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)", + ) + self.validate_identity( + "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) pname ON TRUE WHERE pname IS NULL", + "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", + ) + self.validate_identity( + "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", + "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) AS v1, LATERAL VERTICES(p2.poly) AS v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", + ) + self.validate_identity( + "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL", + "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", + ) + self.validate_identity( + "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL", + "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", + ) + self.validate_identity( + """'{"x": {"y": 1}}'::json->'x'->'y'""", + """JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(CAST('{"x": {"y": 1}}' AS JSON), 'x'), 'y')""", + ) + self.validate_identity( + """'[1,2,3]'::json->>2""", + "JSON_EXTRACT_PATH_TEXT(CAST('[1,2,3]' AS JSON), '2')", + ) + self.validate_identity( + """'{"a":1,"b":2}'::json->>'b'""", + """JSON_EXTRACT_PATH_TEXT(CAST('{"a":1,"b":2}' AS JSON), 'b')""", + ) + self.validate_identity( + """'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""", + """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #> '{a,2}'""", + ) + self.validate_identity( + """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", + """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #>> '{a,2}'""", + ) + self.validate_identity( + "'[1,2,3]'::json->2", + "JSON_EXTRACT_PATH(CAST('[1,2,3]' AS JSON), '2')", + ) + self.validate_identity( + """SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""", + """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT_PATH(foo, 'sections')) AS JSON)) AS sections""", + ) + self.validate_identity( + "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", + "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", + ) + + self.validate_all( + "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", + read={ + "clickhouse": "SELECT JSONExtractString(x, k1, k2, k3) FROM t", + "redshift": "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", + }, + write={ + "clickhouse": "SELECT JSONExtractString(x, k1, k2, k3) FROM t", + "postgres": "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", + "redshift": "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", + }, + ) + self.validate_all( + "x #> 'y'", + read={ + "": "JSONB_EXTRACT(x, 'y')", + }, + write={ + "": "JSONB_EXTRACT(x, 'y')", + "postgres": "x #> 'y'", + }, + ) + self.validate_all( + "x #>> 'y'", + read={ + "": "JSONB_EXTRACT_SCALAR(x, 'y')", + }, + write={ + "": "JSONB_EXTRACT_SCALAR(x, 'y')", + "postgres": "x #>> 'y'", + }, + ) + self.validate_all( + "x -> 'y' -> 0 -> 'z'", + write={ + "": "JSON_EXTRACT(JSON_EXTRACT(JSON_EXTRACT(x, '$.y'), '$[0]'), '$.z')", + "postgres": "JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(JSON_EXTRACT_PATH(x, 'y'), '0'), 'z')", + }, + ) + self.validate_all( + """JSON_EXTRACT_PATH('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}','f4')""", + write={ + "bigquery": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""", + "duckdb": """'{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}' -> '$.f4'""", + "mysql": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""", + "postgres": """JSON_EXTRACT_PATH('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', 'f4')""", + "presto": """JSON_EXTRACT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""", + "redshift": """JSON_EXTRACT_PATH_TEXT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', 'f4')""", + "spark": """GET_JSON_OBJECT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4')""", + "sqlite": """'{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}' -> '$.f4'""", + "tsql": """ISNULL(JSON_QUERY('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4'), JSON_VALUE('{"f2":{"f3":1},"f4":{"f5":99,"f6":"foo"}}', '$.f4'))""", + }, + ) + self.validate_all( + """JSON_EXTRACT_PATH_TEXT('{"farm": ["a", "b", "c"]}', 'farm', '0')""", + read={ + "duckdb": """'{"farm": ["a", "b", "c"]}' ->> '$.farm[0]'""", + "redshift": """JSON_EXTRACT_PATH_TEXT('{"farm": ["a", "b", "c"]}', 'farm', '0')""", + }, + write={ + "duckdb": """'{"farm": ["a", "b", "c"]}' ->> '$.farm[0]'""", + "postgres": """JSON_EXTRACT_PATH_TEXT('{"farm": ["a", "b", "c"]}', 'farm', '0')""", + "redshift": """JSON_EXTRACT_PATH_TEXT('{"farm": ["a", "b", "c"]}', 'farm', '0')""", + }, + ) + self.validate_all( + "JSON_EXTRACT_PATH(x, 'x', 'y', 'z')", + read={ + "duckdb": "x -> '$.x.y.z'", + "postgres": "JSON_EXTRACT_PATH(x, 'x', 'y', 'z')", + }, + write={ + "duckdb": "x -> '$.x.y.z'", + "redshift": "JSON_EXTRACT_PATH_TEXT(x, 'x', 'y', 'z')", + }, + ) self.validate_all( "SELECT * FROM t TABLESAMPLE SYSTEM (50)", write={ @@ -152,12 +418,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT DATE_PART('isodow'::varchar(6), current_date)", - write={ - "postgres": "SELECT EXTRACT(CAST('isodow' AS VARCHAR(6)) FROM CURRENT_DATE)", - }, - ) - self.validate_all( "SELECT DATE_PART('minute', timestamp '2023-01-04 04:05:06.789')", write={ "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", @@ -247,14 +507,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "END WORK AND NO CHAIN", - write={"postgres": "COMMIT AND NO CHAIN"}, - ) - self.validate_all( - "END AND CHAIN", - write={"postgres": "COMMIT AND CHAIN"}, - ) - self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ "duckdb": "CREATE TABLE x (a UUID, b BLOB)", @@ -264,24 +516,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "123::CHARACTER VARYING", - write={"postgres": "CAST(123 AS VARCHAR)"}, - ) - self.validate_all( - "TO_TIMESTAMP(123::DOUBLE PRECISION)", - write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"}, - ) - self.validate_all( - "SELECT to_timestamp(123)::time without time zone", - write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"}, - ) - self.validate_all( - "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", - write={ - "postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)", - }, - ) - self.validate_all( "SELECT * FROM x FETCH 1 ROW", write={ "postgres": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", @@ -315,12 +549,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT SUBSTRING(CAST(2022 AS CHAR(4)) || LPAD(CAST(3 AS CHAR(2)), 2, '0') FROM 3 FOR 4)", - read={ - "postgres": "SELECT SUBSTRING(2022::CHAR(4) || LPAD(3::CHAR(2), 2, '0') FROM 3 FOR 4)", - }, - ) - self.validate_all( "SELECT TRIM(BOTH ' XXX ')", write={ "mysql": "SELECT TRIM(' XXX ')", @@ -347,185 +575,13 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", - read={ - "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss" - }, - ) - self.validate_all( - "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) pname ON TRUE WHERE pname IS NULL", - write={ - "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", - }, - ) - self.validate_all( - "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", - write={ - "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) AS v1, LATERAL VERTICES(p2.poly) AS v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", - }, - ) - self.validate_all( - "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", - write={ - "postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", - }, - ) - self.validate_all( - "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={ - "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL" - }, - ) - self.validate_all( - "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={ - "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL" - }, - ) - self.validate_all( - "'[1,2,3]'::json->2", - write={"postgres": "CAST('[1,2,3]' AS JSON) -> 2"}, - ) - self.validate_all( """'{"a":1,"b":2}'::json->'b'""", write={ - "postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'""", - "redshift": """CAST('{"a":1,"b":2}' AS JSON)."b\"""", - }, - ) - self.validate_all( - """'{"x": {"y": 1}}'::json->'x'->'y'""", - write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON) -> 'x' -> 'y'"""}, - ) - self.validate_all( - """'{"x": {"y": 1}}'::json->'x'::json->'y'""", - write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON) -> 'x' AS JSON) -> 'y'"""}, - ) - self.validate_all( - """'[1,2,3]'::json->>2""", - write={"postgres": "CAST('[1,2,3]' AS JSON) ->> 2"}, - ) - self.validate_all( - """'{"a":1,"b":2}'::json->>'b'""", - write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) ->> 'b'"""}, - ) - self.validate_all( - """'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""", - write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #> '{a,2}'"""}, - ) - self.validate_all( - """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", - write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #>> '{a,2}'"""}, - ) - self.validate_all( - """SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""", - write={ - "postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""", - "presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", + "postgres": """JSON_EXTRACT_PATH(CAST('{"a":1,"b":2}' AS JSON), 'b')""", + "redshift": """JSON_EXTRACT_PATH_TEXT('{"a":1,"b":2}', 'b')""", }, ) self.validate_all( - """x ? 'x'""", - write={"postgres": "x ? 'x'"}, - ) - self.validate_all( - "SELECT $$a$$", - write={"postgres": "SELECT 'a'"}, - ) - self.validate_all( - "SELECT $$Dianne's horse$$", - write={"postgres": "SELECT 'Dianne''s horse'"}, - ) - self.validate_all( - "UPDATE MYTABLE T1 SET T1.COL = 13", - write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"}, - ) - self.validate_all( - "x !~ 'y'", - write={"postgres": "NOT x ~ 'y'"}, - ) - self.validate_all( - "x !~* 'y'", - write={"postgres": "NOT x ~* 'y'"}, - ) - - self.validate_all( - "x ~~ 'y'", - write={"postgres": "x LIKE 'y'"}, - ) - self.validate_all( - "x ~~* 'y'", - write={"postgres": "x ILIKE 'y'"}, - ) - self.validate_all( - "x !~~ 'y'", - write={"postgres": "NOT x LIKE 'y'"}, - ) - self.validate_all( - "x !~~* 'y'", - write={"postgres": "NOT x ILIKE 'y'"}, - ) - self.validate_all( - "'45 days'::interval day", - write={"postgres": "CAST('45 days' AS INTERVAL DAY)"}, - ) - self.validate_all( - "'x' 'y' 'z'", - write={"postgres": "CONCAT('x', 'y', 'z')"}, - ) - self.validate_all( - "x::cstring", - write={"postgres": "CAST(x AS CSTRING)"}, - ) - self.validate_all( - "x::oid", - write={"postgres": "CAST(x AS OID)"}, - ) - self.validate_all( - "x::regclass", - write={"postgres": "CAST(x AS REGCLASS)"}, - ) - self.validate_all( - "x::regcollation", - write={"postgres": "CAST(x AS REGCOLLATION)"}, - ) - self.validate_all( - "x::regconfig", - write={"postgres": "CAST(x AS REGCONFIG)"}, - ) - self.validate_all( - "x::regdictionary", - write={"postgres": "CAST(x AS REGDICTIONARY)"}, - ) - self.validate_all( - "x::regnamespace", - write={"postgres": "CAST(x AS REGNAMESPACE)"}, - ) - self.validate_all( - "x::regoper", - write={"postgres": "CAST(x AS REGOPER)"}, - ) - self.validate_all( - "x::regoperator", - write={"postgres": "CAST(x AS REGOPERATOR)"}, - ) - self.validate_all( - "x::regproc", - write={"postgres": "CAST(x AS REGPROC)"}, - ) - self.validate_all( - "x::regprocedure", - write={"postgres": "CAST(x AS REGPROCEDURE)"}, - ) - self.validate_all( - "x::regrole", - write={"postgres": "CAST(x AS REGROLE)"}, - ) - self.validate_all( - "x::regtype", - write={"postgres": "CAST(x AS REGTYPE)"}, - ) - self.validate_all( "TRIM(BOTH 'as' FROM 'as string as')", write={ "postgres": "TRIM(BOTH 'as' FROM 'as string as')", @@ -562,13 +618,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", - write={ - "postgres": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", - }, - ) - - self.validate_all( "x / y ^ z", write={ "": "x / POWER(y, z)", @@ -685,48 +734,30 @@ class TestPostgres(Validator): "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", "CREATE TABLE test (x TIMESTAMP[][])", ) - - self.validate_all( + self.validate_identity( "CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)", - write={ - "postgres": "CREATE OR REPLACE FUNCTION function_name(input_a VARCHAR DEFAULT CAST(NULL AS VARCHAR))", - }, + "CREATE OR REPLACE FUNCTION function_name(input_a VARCHAR DEFAULT CAST(NULL AS VARCHAR))", ) - self.validate_all( + self.validate_identity( + "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" - }, ) - self.validate_all( + self.validate_identity( + "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", - write={ - "postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)" - }, ) - self.validate_all( + self.validate_identity( + "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", - write={ - "postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))" - }, ) - self.validate_all( + self.validate_identity( "CREATE TABLE products (" "product_no INT UNIQUE," " name TEXT," " price DECIMAL CHECK (price > 0)," " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," " CHECK (product_no > 1)," - " CONSTRAINT valid_discount CHECK (price > discounted_price))", - write={ - "postgres": "CREATE TABLE products (" - "product_no INT UNIQUE," - " name TEXT," - " price DECIMAL CHECK (price > 0)," - " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," - " CHECK (product_no > 1)," - " CONSTRAINT valid_discount CHECK (price > discounted_price))" - }, + " CONSTRAINT valid_discount CHECK (price > discounted_price))" ) self.validate_identity( """ @@ -819,9 +850,9 @@ class TestPostgres(Validator): self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") def test_bool_or(self): - self.validate_all( + self.validate_identity( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", - write={"postgres": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, + "SELECT a, BOOL_OR(b) FROM table GROUP BY a", ) def test_string_concat(self): @@ -849,11 +880,27 @@ class TestPostgres(Validator): ) def test_variance(self): - self.validate_all("VAR_SAMP(x)", write={"postgres": "VAR_SAMP(x)"}) - self.validate_all("VAR_POP(x)", write={"postgres": "VAR_POP(x)"}) - self.validate_all("VARIANCE(x)", write={"postgres": "VAR_SAMP(x)"}) + self.validate_identity( + "VAR_SAMP(x)", + "VAR_SAMP(x)", + ) + self.validate_identity( + "VAR_POP(x)", + "VAR_POP(x)", + ) + self.validate_identity( + "VARIANCE(x)", + "VAR_SAMP(x)", + ) + self.validate_all( - "VAR_POP(x)", read={"": "VARIANCE_POP(x)"}, write={"postgres": "VAR_POP(x)"} + "VAR_POP(x)", + read={ + "": "VARIANCE_POP(x)", + }, + write={ + "postgres": "VAR_POP(x)", + }, ) def test_regexp_binary(self): diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 387b0e0..36006d2 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -343,16 +343,16 @@ class TestPresto(Validator): ) self.validate_all( - "SELECT timestamp '2012-10-31 00:00' AT TIME ZONE 'America/Sao_Paulo'", + "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", write={ "spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')", - "presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", + "presto": "SELECT AT_TIMEZONE(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')", }, ) self.validate_all( - "CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", + "SELECT AT_TIMEZONE(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')", read={ - "spark": "FROM_UTC_TIMESTAMP('2012-10-31 00:00', 'America/Sao_Paulo')", + "spark": "SELECT FROM_UTC_TIMESTAMP(TIMESTAMP '2012-10-31 00:00', 'America/Sao_Paulo')", }, ) self.validate_all( @@ -368,7 +368,7 @@ class TestPresto(Validator): "TIMESTAMP(x, 'America/Los_Angeles')", write={ "duckdb": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", - "presto": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", + "presto": "AT_TIMEZONE(CAST(x AS TIMESTAMP), 'America/Los_Angeles')", }, ) # this case isn't really correct, but it's a fall back for mysql's version @@ -564,7 +564,7 @@ class TestPresto(Validator): ) def test_presto(self): - with self.assertLogs(helper_logger) as cm: + with self.assertLogs(helper_logger): self.validate_all( "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", write={ diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 9ccd955..b6b6ccc 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,7 +6,29 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): - self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1") + self.validate_all( + "GETDATE()", + read={ + "duckdb": "CURRENT_TIMESTAMP", + }, + write={ + "duckdb": "CURRENT_TIMESTAMP", + "redshift": "GETDATE()", + }, + ) + self.validate_all( + """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", + write={ + "bigquery": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "databricks": """SELECT GET_JSON_OBJECT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "duckdb": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}' ->> '$.farm.barn.color'""", + "postgres": """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", + "presto": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "redshift": """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", + "spark": """SELECT GET_JSON_OBJECT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "sqlite": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}' ->> '$.farm.barn.color'""", + }, + ) self.validate_all( "LISTAGG(sellerid, ', ')", read={ @@ -271,6 +293,7 @@ class TestRedshift(Validator): ) def test_identity(self): + self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1") self.validate_identity("SELECT DATEADD(DAY, 1, 'today')") self.validate_identity("SELECT * FROM #x") self.validate_identity("SELECT INTERVAL '5 DAY'") @@ -283,6 +306,9 @@ class TestRedshift(Validator): self.validate_identity("SELECT APPROXIMATE AS y") self.validate_identity("CREATE TABLE t (c BIGINT IDENTITY(0, 1))") self.validate_identity( + """SELECT JSON_EXTRACT_PATH_TEXT('{"f2":{"f3":1},"f4":{"f5":99,"f6":"star"}', 'f4', 'f6', TRUE)""" + ) + self.validate_identity( "SELECT CONCAT('abc', 'def')", "SELECT 'abc' || 'def'", ) @@ -458,16 +484,26 @@ FROM ( ) def test_create_table_like(self): + self.validate_identity( + "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL" + ) + self.validate_all( - "CREATE TABLE t1 LIKE t2", + "CREATE TABLE t1 (LIKE t2)", write={ + "postgres": "CREATE TABLE t1 (LIKE t2)", + "presto": "CREATE TABLE t1 (LIKE t2)", "redshift": "CREATE TABLE t1 (LIKE t2)", + "trino": "CREATE TABLE t1 (LIKE t2)", }, ) self.validate_all( - "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + "CREATE TABLE t1 (col VARCHAR, LIKE t2)", write={ - "redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL", + "postgres": "CREATE TABLE t1 (col VARCHAR, LIKE t2)", + "presto": "CREATE TABLE t1 (col VARCHAR, LIKE t2)", + "redshift": "CREATE TABLE t1 (col VARCHAR, LIKE t2)", + "trino": "CREATE TABLE t1 (col VARCHAR, LIKE t2)", }, ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 7e41fd4..7a821f6 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -85,6 +85,10 @@ WHERE "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)" ) self.validate_identity( + """SELECT GET_PATH(PARSE_JSON('{"y": [{"z": 1}]}'), 'y[0]:z')""", + """SELECT GET_PATH(PARSE_JSON('{"y": [{"z": 1}]}'), 'y[0].z')""", + ) + self.validate_identity( "SELECT p FROM t WHERE p:val NOT IN ('2')", "SELECT p FROM t WHERE NOT GET_PATH(p, 'val') IN ('2')", ) @@ -118,7 +122,7 @@ WHERE ) self.validate_identity( 'SELECT v:"fruit" FROM vartab', - """SELECT GET_PATH(v, '"fruit"') FROM vartab""", + """SELECT GET_PATH(v, 'fruit') FROM vartab""", ) self.validate_identity( "v:attr[0]:name", @@ -249,7 +253,7 @@ WHERE "mysql": """WITH vartab(v) AS (SELECT '[{"attr": [{"name": "banana"}]}]') SELECT JSON_EXTRACT(v, '$[0].attr[0].name') FROM vartab""", "presto": """WITH vartab(v) AS (SELECT JSON_PARSE('[{"attr": [{"name": "banana"}]}]')) SELECT JSON_EXTRACT(v, '$[0].attr[0].name') FROM vartab""", "snowflake": """WITH vartab(v) AS (SELECT PARSE_JSON('[{"attr": [{"name": "banana"}]}]')) SELECT GET_PATH(v, '[0].attr[0].name') FROM vartab""", - "tsql": """WITH vartab(v) AS (SELECT '[{"attr": [{"name": "banana"}]}]') SELECT JSON_VALUE(v, '$[0].attr[0].name') FROM vartab""", + "tsql": """WITH vartab(v) AS (SELECT '[{"attr": [{"name": "banana"}]}]') SELECT ISNULL(JSON_QUERY(v, '$[0].attr[0].name'), JSON_VALUE(v, '$[0].attr[0].name')) FROM vartab""", }, ) self.validate_all( @@ -260,7 +264,7 @@ WHERE "mysql": """WITH vartab(v) AS (SELECT '{"attr": [{"name": "banana"}]}') SELECT JSON_EXTRACT(v, '$.attr[0].name') FROM vartab""", "presto": """WITH vartab(v) AS (SELECT JSON_PARSE('{"attr": [{"name": "banana"}]}')) SELECT JSON_EXTRACT(v, '$.attr[0].name') FROM vartab""", "snowflake": """WITH vartab(v) AS (SELECT PARSE_JSON('{"attr": [{"name": "banana"}]}')) SELECT GET_PATH(v, 'attr[0].name') FROM vartab""", - "tsql": """WITH vartab(v) AS (SELECT '{"attr": [{"name": "banana"}]}') SELECT JSON_VALUE(v, '$.attr[0].name') FROM vartab""", + "tsql": """WITH vartab(v) AS (SELECT '{"attr": [{"name": "banana"}]}') SELECT ISNULL(JSON_QUERY(v, '$.attr[0].name'), JSON_VALUE(v, '$.attr[0].name')) FROM vartab""", }, ) self.validate_all( @@ -271,7 +275,7 @@ WHERE "mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""", "presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""", "snowflake": """SELECT GET_PATH(PARSE_JSON('{"fruit":"banana"}'), 'fruit')""", - "tsql": """SELECT JSON_VALUE('{"fruit":"banana"}', '$.fruit')""", + "tsql": """SELECT ISNULL(JSON_QUERY('{"fruit":"banana"}', '$.fruit'), JSON_VALUE('{"fruit":"banana"}', '$.fruit'))""", }, ) self.validate_all( @@ -550,7 +554,7 @@ WHERE write={ "duckdb": """SELECT JSON('{"a": {"b c": "foo"}}') -> '$.a' -> '$."b c"'""", "mysql": """SELECT JSON_EXTRACT(JSON_EXTRACT('{"a": {"b c": "foo"}}', '$.a'), '$."b c"')""", - "snowflake": """SELECT GET_PATH(GET_PATH(PARSE_JSON('{"a": {"b c": "foo"}}'), 'a'), '"b c"')""", + "snowflake": """SELECT GET_PATH(GET_PATH(PARSE_JSON('{"a": {"b c": "foo"}}'), 'a'), '["b c"]')""", }, ) self.validate_all( @@ -744,7 +748,7 @@ WHERE self.validate_all( r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", write={ - "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) RESPECT NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) self.validate_all( @@ -756,7 +760,7 @@ WHERE self.validate_all( r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", write={ - "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 IGNORE NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) IGNORE NULLS OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) self.validate_all( @@ -1454,12 +1458,6 @@ MATCH_RECOGNIZE ( ) def test_show(self): - # Parsed as Command - self.validate_identity( - "SHOW TABLES LIKE 'line%' IN tpch.public", check_command_warning=True - ) - self.validate_identity("SHOW TABLES HISTORY IN tpch.public", check_command_warning=True) - # Parsed as Show self.validate_identity("SHOW PRIMARY KEYS") self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT") @@ -1487,6 +1485,22 @@ MATCH_RECOGNIZE ( "show terse objects in db1.schema1 starts with 'a' limit 10 from 'b'", "SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", ) + self.validate_identity( + "SHOW TABLES LIKE 'line%' IN tpch.public", + "SHOW TABLES LIKE 'line%' IN SCHEMA tpch.public", + ) + self.validate_identity( + "SHOW TABLES HISTORY IN tpch.public", + "SHOW TABLES HISTORY IN SCHEMA tpch.public", + ) + self.validate_identity( + "show terse tables in schema db1.schema1 starts with 'a' limit 10 from 'b'", + "SHOW TERSE TABLES IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", + ) + self.validate_identity( + "show terse tables in db1.schema1 starts with 'a' limit 10 from 'b'", + "SHOW TERSE TABLES IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", + ) ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake") table = ast.find(exp.Table) @@ -1517,6 +1531,11 @@ MATCH_RECOGNIZE ( table = ast.find(exp.Table) self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") + ast = parse_one("SHOW TABLES IN db1.schema1", read="snowflake") + self.assertEqual(ast.args.get("scope_kind"), "SCHEMA") + table = ast.find(exp.Table) + self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") + def test_swap(self): ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") assert isinstance(ast, exp.AlterTable) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 6044037..a02a735 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -227,9 +227,11 @@ TBLPROPERTIES ( ) def test_spark(self): - expr = parse_one("any_value(col, true)", read="spark") - self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) - self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS") + self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS") + self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS") + self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS") + self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS") self.assertEqual( parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"), @@ -290,7 +292,7 @@ TBLPROPERTIES ( self.validate_all( "SELECT FROM_UTC_TIMESTAMP('2016-08-31', 'Asia/Seoul')", write={ - "presto": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul'", + "presto": "SELECT AT_TIMEZONE(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')", "spark": "SELECT FROM_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')", }, ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 3df74c8..f7a3dd7 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -10,12 +10,10 @@ class TestSQLite(Validator): self.validate_identity("INSERT OR IGNORE INTO foo (x, y) VALUES (1, 2)") self.validate_identity("INSERT OR REPLACE INTO foo (x, y) VALUES (1, 2)") self.validate_identity("INSERT OR ROLLBACK INTO foo (x, y) VALUES (1, 2)") + self.validate_identity("CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)") + self.validate_identity("CREATE TEMPORARY TABLE foo (id INTEGER)") self.validate_all( - "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", - write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, - ) - self.validate_all( """ CREATE TABLE "Track" ( @@ -73,6 +71,9 @@ class TestSQLite(Validator): self.validate_identity("SELECT UNIXEPOCH('now', 'subsec')") self.validate_identity("SELECT TIMEDIFF('now', '1809-02-12')") self.validate_identity( + "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[2]', '$[0]', '$[1]')", + ) + self.validate_identity( """SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0""" ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 101d356..c8c0d82 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -886,7 +886,7 @@ WHERE "END", ] - with self.assertLogs(parser_logger) as cm: + with self.assertLogs(parser_logger): for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): self.assertEqual(expr.sql(dialect="tsql"), expected_sql) @@ -907,7 +907,7 @@ WHERE "CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)", ] - with self.assertLogs(parser_logger) as cm: + with self.assertLogs(parser_logger): for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): self.assertEqual(expr.sql(dialect="tsql"), expected_sql) @@ -941,6 +941,16 @@ WHERE self.validate_all( "LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(CAST(x AS STRING))"} ) + self.validate_all( + "RIGHT(x, 1)", + read={"": "RIGHT(CAST(x AS STRING), 1)"}, + write={"spark": "RIGHT(CAST(x AS STRING), 1)"}, + ) + self.validate_all( + "LEFT(x, 1)", + read={"": "LEFT(CAST(x AS STRING), 1)"}, + write={"spark": "LEFT(CAST(x AS STRING), 1)"}, + ) self.validate_all("LEN(1)", write={"tsql": "LEN(1)", "spark": "LENGTH(CAST(1 AS STRING))"}) self.validate_all("LEN('x')", write={"tsql": "LEN('x')", "spark": "LENGTH('x')"}) @@ -950,10 +960,20 @@ WHERE def test_isnull(self): self.validate_all("ISNULL(x, y)", write={"spark": "COALESCE(x, y)"}) - def test_jsonvalue(self): + def test_json(self): + self.validate_all( + "JSON_QUERY(r.JSON, '$.Attr_INT')", + write={ + "spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')", + "tsql": "ISNULL(JSON_QUERY(r.JSON, '$.Attr_INT'), JSON_VALUE(r.JSON, '$.Attr_INT'))", + }, + ) self.validate_all( "JSON_VALUE(r.JSON, '$.Attr_INT')", - write={"spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')"}, + write={ + "spark": "GET_JSON_OBJECT(r.JSON, '$.Attr_INT')", + "tsql": "ISNULL(JSON_QUERY(r.JSON, '$.Attr_INT'), JSON_VALUE(r.JSON, '$.Attr_INT'))", + }, ) def test_datefromparts(self): @@ -1438,7 +1458,7 @@ WHERE "mysql": "LAST_DAY(DATE(CURRENT_TIMESTAMP()))", "postgres": "CAST(DATE_TRUNC('MONTH', CAST(CURRENT_TIMESTAMP AS DATE)) + INTERVAL '1 MONTH' - INTERVAL '1 DAY' AS DATE)", "presto": "LAST_DAY_OF_MONTH(CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE))", - "redshift": "LAST_DAY(CAST(SYSDATE AS DATE))", + "redshift": "LAST_DAY(CAST(GETDATE() AS DATE))", "snowflake": "LAST_DAY(CAST(CURRENT_TIMESTAMP() AS DATE))", "spark": "LAST_DAY(TO_DATE(CURRENT_TIMESTAMP()))", "tsql": "EOMONTH(CAST(GETDATE() AS DATE))", @@ -1453,7 +1473,7 @@ WHERE "mysql": "LAST_DAY(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL -1 MONTH))", "postgres": "CAST(DATE_TRUNC('MONTH', CAST(CURRENT_TIMESTAMP AS DATE) + INTERVAL '-1 MONTH') + INTERVAL '1 MONTH' - INTERVAL '1 DAY' AS DATE)", "presto": "LAST_DAY_OF_MONTH(DATE_ADD('MONTH', CAST(-1 AS BIGINT), CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE)))", - "redshift": "LAST_DAY(DATEADD(MONTH, -1, CAST(SYSDATE AS DATE)))", + "redshift": "LAST_DAY(DATEADD(MONTH, -1, CAST(GETDATE() AS DATE)))", "snowflake": "LAST_DAY(DATEADD(MONTH, -1, CAST(CURRENT_TIMESTAMP() AS DATE)))", "spark": "LAST_DAY(ADD_MONTHS(TO_DATE(CURRENT_TIMESTAMP()), -1))", "tsql": "EOMONTH(DATEADD(MONTH, -1, CAST(GETDATE() AS DATE)))", diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 54d41b4..366b79e 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -158,7 +158,7 @@ CAST(x AS UUID) FILTER(a, x -> x.a.b.c.d.e.f.g) FILTER(a, x -> FOO(x.a.b.c.d.e.f.g) + x.a.b.c.d.e.f.g) TIMESTAMP_FROM_PARTS(2019, 1, 10, 2, 3, 4, 123456789, 'America/Los_Angeles') -TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) +TIMESTAMPDIFF(CURRENT_TIMESTAMP(), 1, DAY) DATETIME_DIFF(CURRENT_DATE, 1, DAY) QUANTILE(x, 0.5) REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2])) @@ -237,13 +237,8 @@ SELECT AGGREGATE(a, (a, b) -> a + b) AS x SELECT COUNT(DISTINCT a, b) SELECT COUNT(DISTINCT a, b + 1) SELECT SUM(DISTINCT x) -SELECT SUM(x IGNORE NULLS) AS x -SELECT COUNT(x RESPECT NULLS) SELECT TRUNCATE(a, b) -SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x -SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x -SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x SELECT LAG(x) OVER (ORDER BY y) AS x SELECT LEAD(a) OVER (ORDER BY b) AS a SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x @@ -361,9 +356,10 @@ SELECT COUNT(DISTINCT a) FROM test SELECT EXP(a) FROM test SELECT FLOOR(a) FROM test SELECT FLOOR(a, b) FROM test -SELECT FIRST(a) FROM test +SELECT FIRST_VALUE(a) FROM test SELECT GREATEST(a, b, c) FROM test -SELECT LAST(a) FROM test +SELECT LAST_VALUE(a) FROM test +SELECT LAST_VALUE(a) IGNORE NULLS OVER () + 1 SELECT LN(a) FROM test SELECT LOG10(a) FROM test SELECT MAX(a) FROM test @@ -825,8 +821,6 @@ SELECT if.x SELECT NEXT VALUE FOR db.schema.sequence_name SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col SELECT PERCENTILE_CONT(x, 0.5) OVER () -SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER () -SELECT PERCENTILE_CONT(x, 0.5 IGNORE NULLS) OVER () WITH my_cte AS (SELECT 'a' AS desc) SELECT desc AS description FROM my_cte WITH my_cte AS (SELECT 'a' AS asc) SELECT asc AS description FROM my_cte SELECT * FROM case @@ -852,3 +846,4 @@ SELECT x FROM t1 UNION ALL SELECT x FROM t2 UNION ALL SELECT x FROM t3 LIMIT 1 WITH use(use) AS (SELECT 1) SELECT use FROM use SELECT recursive FROM t SELECT (ROW_NUMBER() OVER (PARTITION BY user ORDER BY date ASC) - ROW_NUMBER() OVER (PARTITION BY user, segment ORDER BY date ASC)) AS group_id FROM example_table +CAST(foo AS BPCHAR) diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index a80be17..da9f26d 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -756,6 +756,9 @@ CAST(CAST('2023-01-01' AS TIMESTAMP) AS DATE); COALESCE(CAST(NULL AS DATE), x); COALESCE(CAST(NULL AS DATE), x); +NOT COALESCE(x, 1) = 2 AND y = 3; +(x <> 2 OR x IS NULL) AND y = 3; + -------------------------------------- -- CONCAT -------------------------------------- diff --git a/tests/test_executor.py b/tests/test_executor.py index 78d037a..35935b9 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -94,11 +94,10 @@ class TestExecutor(unittest.TestCase): with self.subTest(f"tpch-h {i + 1}"): sql, _ = self.sqls[i] a = self.cached_execute(sql) - b = pd.DataFrame(table.rows, columns=table.columns) - - # The executor represents NULL values as None, whereas DuckDB represents them as NaN, - # and so the following is done to silence Pandas' "Mismatched null-like values" warnings - b = b.fillna(value=np.nan) + b = pd.DataFrame( + ((np.nan if c is None else c for c in r) for r in table.rows), + columns=table.columns, + ) assert_frame_equal(a, b, check_dtype=False, check_index_type=False) @@ -778,14 +777,24 @@ class TestExecutor(unittest.TestCase): self.assertEqual(result.rows, expected) def test_dict_values(self): - tables = { - "foo": [{"raw": {"name": "Hello, World"}}], - } - result = execute("SELECT raw:name AS name FROM foo", read="snowflake", tables=tables) + tables = {"foo": [{"raw": {"name": "Hello, World", "a": [{"b": 1}]}}]} + result = execute("SELECT raw:name AS name FROM foo", read="snowflake", tables=tables) self.assertEqual(result.columns, ("NAME",)) self.assertEqual(result.rows, [("Hello, World",)]) + result = execute("SELECT raw:a[0].b AS b FROM foo", read="snowflake", tables=tables) + self.assertEqual(result.columns, ("B",)) + self.assertEqual(result.rows, [(1,)]) + + result = execute("SELECT raw:a[1].b AS b FROM foo", read="snowflake", tables=tables) + self.assertEqual(result.columns, ("B",)) + self.assertEqual(result.rows, [(None,)]) + + result = execute("SELECT raw:a[0].c AS c FROM foo", read="snowflake", tables=tables) + self.assertEqual(result.columns, ("C",)) + self.assertEqual(result.rows, [(None,)]) + tables = { '"ITEM"': [ {"id": 1, "attributes": {"flavor": "cherry", "taste": "sweet"}}, diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 4641233..f415ff6 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1031,3 +1031,11 @@ FROM foo""", query = parse_one("SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */") self.assertEqual(query.find(exp.Table).meta, {"x": "1", "y": "a", "z": True}) self.assertEqual(query.sql(), "SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */") + + def test_assert_is(self): + parse_one("x").assert_is(exp.Column) + + with self.assertRaisesRegex( + AssertionError, "x is not <class 'sqlglot.expressions.Identifier'>\." + ): + parse_one("x").assert_is(exp.Identifier) diff --git a/tests/test_jsonpath.py b/tests/test_jsonpath.py index 01cd899..4daf3c1 100644 --- a/tests/test_jsonpath.py +++ b/tests/test_jsonpath.py @@ -2,7 +2,7 @@ import json import os import unittest -from sqlglot import jsonpath +from sqlglot import exp, jsonpath from sqlglot.errors import ParseError, TokenError from tests.helpers import FIXTURES_DIR @@ -11,24 +11,22 @@ class TestJsonpath(unittest.TestCase): maxDiff = None def test_jsonpath(self): + expected_expressions = [ + exp.JSONPathRoot(), + exp.JSONPathKey(this=exp.JSONPathWildcard()), + exp.JSONPathKey(this="a"), + exp.JSONPathSubscript(this=0), + exp.JSONPathKey(this="x"), + exp.JSONPathUnion(expressions=[exp.JSONPathWildcard(), "y", 1]), + exp.JSONPathKey(this="z"), + exp.JSONPathSelector(this=exp.JSONPathFilter(this="(@.a == 'b'), 1:")), + exp.JSONPathSubscript(this=exp.JSONPathSlice(start=1, end=5, step=None)), + exp.JSONPathUnion(expressions=[1, exp.JSONPathFilter(this="@.a")]), + exp.JSONPathSelector(this=exp.JSONPathScript(this="@.x)")), + ] self.assertEqual( jsonpath.parse("$.*.a[0]['x'][*, 'y', 1].z[?(@.a == 'b'), 1:][1:5][1,?@.a][(@.x)]"), - [ - {"kind": "root"}, - {"kind": "child", "value": "*"}, - {"kind": "child", "value": "a"}, - {"kind": "subscript", "value": 0}, - {"kind": "key", "value": "x"}, - {"kind": "union", "value": [{"kind": "wildcard"}, "y", 1]}, - {"kind": "child", "value": "z"}, - {"kind": "selector", "value": {"kind": "filter", "value": "(@.a == 'b'), 1:"}}, - { - "kind": "subscript", - "value": {"end": 5, "kind": "slice", "start": 1, "step": None}, - }, - {"kind": "union", "value": [1, {"kind": "filter", "value": "@.a"}]}, - {"kind": "selector", "value": {"kind": "script", "value": "@.x)"}}, - ], + exp.JSONPath(expressions=expected_expressions), ) def test_identity(self): @@ -38,7 +36,7 @@ class TestJsonpath(unittest.TestCase): ("$[((@.length-1))]", "$[((@.length-1))]"), ): with self.subTest(f"{selector} -> {expected}"): - self.assertEqual(jsonpath.generate(jsonpath.parse(selector)), expected) + self.assertEqual(jsonpath.parse(selector).sql(), f"'{expected}'") def test_cts_file(self): with open(os.path.join(FIXTURES_DIR, "jsonpath", "cts.json")) as file: @@ -46,6 +44,7 @@ class TestJsonpath(unittest.TestCase): # sqlglot json path generator rewrites to a normal form overrides = { + "$.☺": '$["☺"]', """$['a',1]""": """$["a",1]""", """$[*,'a']""": """$[*,"a"]""", """$..['a','d']""": """$..["a","d"]""", @@ -136,5 +135,5 @@ class TestJsonpath(unittest.TestCase): except (ParseError, TokenError): pass else: - nodes = jsonpath.parse(selector) - self.assertEqual(jsonpath.generate(nodes), overrides.get(selector, selector)) + path = jsonpath.parse(selector) + self.assertEqual(path.sql(), f"'{overrides.get(selector, selector)}'") diff --git a/tests/test_parser.py b/tests/test_parser.py index 91fd4c6..c7e1dbe 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -10,7 +10,7 @@ from tests.helpers import assert_logger_contains class TestParser(unittest.TestCase): def test_parse_empty(self): - with self.assertRaises(ParseError) as ctx: + with self.assertRaises(ParseError): parse_one("") def test_parse_into(self): @@ -805,3 +805,37 @@ class TestParser(unittest.TestCase): error_level=ErrorLevel.IGNORE, ) self.assertEqual(ast[0].sql(), "CONCAT_WS()") + + def test_parse_drop_schema(self): + for dialect in [None, "bigquery", "snowflake"]: + with self.subTest(dialect): + ast = parse_one("DROP SCHEMA catalog.schema", dialect=dialect) + self.assertEqual( + ast, + exp.Drop( + this=exp.Table( + this=None, + db=exp.Identifier(this="schema", quoted=False), + catalog=exp.Identifier(this="catalog", quoted=False), + ), + kind="SCHEMA", + ), + ) + self.assertEqual(ast.sql(dialect=dialect), "DROP SCHEMA catalog.schema") + + def test_parse_create_schema(self): + for dialect in [None, "bigquery", "snowflake"]: + with self.subTest(dialect): + ast = parse_one("CREATE SCHEMA catalog.schema", dialect=dialect) + self.assertEqual( + ast, + exp.Create( + this=exp.Table( + this=None, + db=exp.Identifier(this="schema", quoted=False), + catalog=exp.Identifier(this="catalog", quoted=False), + ), + kind="SCHEMA", + ), + ) + self.assertEqual(ast.sql(dialect=dialect), "CREATE SCHEMA catalog.schema") diff --git a/tests/test_serde.py b/tests/test_serde.py index 40d6134..1043fcf 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -6,7 +6,8 @@ from sqlglot.optimizer.annotate_types import annotate_types from tests.helpers import load_sql_fixtures -class CustomExpression(exp.Expression): ... +class CustomExpression(exp.Expression): + ... class TestSerDe(unittest.TestCase): |