diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:43 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:43 +0000 |
commit | 341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87 (patch) | |
tree | 61fb7eca2238fb5d41d3906f4af41de03abd25ea /tests/dialects/test_dialect.py | |
parent | Adding upstream version 17.12.0. (diff) | |
download | sqlglot-341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87.tar.xz sqlglot-341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87.zip |
Adding upstream version 18.2.0.upstream/18.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r-- | tests/dialects/test_dialect.py | 116 |
1 files changed, 97 insertions, 19 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 63f789f..6a41218 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,6 +1,13 @@ import unittest -from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one +from sqlglot import ( + Dialect, + Dialects, + ErrorLevel, + ParseError, + UnsupportedError, + parse_one, +) from sqlglot.dialects import Hive @@ -23,9 +30,10 @@ class Validator(unittest.TestCase): Args: sql (str): Main SQL expression - dialect (str): dialect of `sql` read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL + pretty (bool): prettify both read and write + identify (bool): quote identifiers in both read and write """ expression = self.parse_one(sql) @@ -78,7 +86,7 @@ class TestDialect(Validator): "CAST(a AS TEXT)", write={ "bigquery": "CAST(a AS STRING)", - "clickhouse": "CAST(a AS TEXT)", + "clickhouse": "CAST(a AS String)", "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS CHAR)", @@ -116,7 +124,7 @@ class TestDialect(Validator): "CAST(a AS VARBINARY(4))", write={ "bigquery": "CAST(a AS BYTES)", - "clickhouse": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS String)", "duckdb": "CAST(a AS BLOB(4))", "mysql": "CAST(a AS VARBINARY(4))", "hive": "CAST(a AS BINARY(4))", @@ -133,7 +141,7 @@ class TestDialect(Validator): self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ - "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", + "clickhouse": "CAST(map('a', '1') AS Map(String, String))", }, ) self.validate_all( @@ -367,6 +375,60 @@ class TestDialect(Validator): }, ) + def test_nvl2(self): + self.validate_all( + "SELECT NVL2(a, b, c)", + write={ + "": "SELECT NVL2(a, b, c)", + "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "databricks": "SELECT NVL2(a, b, c)", + "doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "hive": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "oracle": "SELECT NVL2(a, b, c)", + "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "presto": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "redshift": "SELECT NVL2(a, b, c)", + "snowflake": "SELECT NVL2(a, b, c)", + "spark": "SELECT NVL2(a, b, c)", + "spark2": "SELECT NVL2(a, b, c)", + "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "teradata": "SELECT NVL2(a, b, c)", + "trino": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + }, + ) + self.validate_all( + "SELECT NVL2(a, b)", + write={ + "": "SELECT NVL2(a, b)", + "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "databricks": "SELECT NVL2(a, b)", + "doris": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "drill": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "hive": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "oracle": "SELECT NVL2(a, b)", + "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "presto": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "redshift": "SELECT NVL2(a, b)", + "snowflake": "SELECT NVL2(a, b)", + "spark": "SELECT NVL2(a, b)", + "spark2": "SELECT NVL2(a, b)", + "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "teradata": "SELECT NVL2(a, b)", + "trino": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b END", + }, + ) + def test_time(self): self.validate_all( "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", @@ -860,7 +922,7 @@ class TestDialect(Validator): "ARRAY(0, 1, 2)", write={ "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "spark": "ARRAY(0, 1, 2)", }, @@ -879,7 +941,7 @@ class TestDialect(Validator): "ARRAY_SUM(ARRAY(1, 2))", write={ "trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)", - "duckdb": "LIST_SUM(LIST_VALUE(1, 2))", + "duckdb": "LIST_SUM([1, 2])", "hive": "ARRAY_SUM(ARRAY(1, 2))", "presto": "ARRAY_SUM(ARRAY[1, 2])", "spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)", @@ -1403,27 +1465,27 @@ class TestDialect(Validator): }, ) self.validate_all( - "CREATE INDEX my_idx ON tbl (a, b)", + "CREATE INDEX my_idx ON tbl(a, b)", read={ - "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", - "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl(a, b)", }, write={ - "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", - "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)", + "postgres": "CREATE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)", + "sqlite": "CREATE INDEX my_idx ON tbl(a, b)", }, ) self.validate_all( - "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", read={ - "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", - "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", }, write={ - "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", - "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)", + "postgres": "CREATE UNIQUE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", }, ) self.validate_all( @@ -1710,3 +1772,19 @@ SELECT "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", }, ) + + def test_cast_to_user_defined_type(self): + self.validate_all( + "CAST(x AS some_udt)", + write={ + "": "CAST(x AS some_udt)", + "oracle": "CAST(x AS some_udt)", + "postgres": "CAST(x AS some_udt)", + "presto": "CAST(x AS some_udt)", + "teradata": "CAST(x AS some_udt)", + "tsql": "CAST(x AS some_udt)", + }, + ) + + with self.assertRaises(ParseError): + parse_one("CAST(x AS some_udt)", read="bigquery") |