From 8fe30fd23dc37ec3516e530a86d1c4b604e71241 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 10 Dec 2023 11:46:01 +0100 Subject: Merging upstream version 20.1.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_dialect.py | 219 +++++++++++++++++++++++++++++++++++------ 1 file changed, 190 insertions(+), 29 deletions(-) (limited to 'tests/dialects/test_dialect.py') diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 2546c98..49afc62 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -7,9 +7,10 @@ from sqlglot import ( ParseError, TokenError, UnsupportedError, + exp, parse_one, ) -from sqlglot.dialects import Hive +from sqlglot.dialects import BigQuery, Hive, Snowflake class Validator(unittest.TestCase): @@ -78,9 +79,56 @@ class TestDialect(Validator): self.assertIsNotNone(Dialect[dialect.value]) def test_get_or_raise(self): - self.assertEqual(Dialect.get_or_raise(Hive), Hive) - self.assertEqual(Dialect.get_or_raise(Hive()), Hive) - self.assertEqual(Dialect.get_or_raise("hive"), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive()), Hive) + self.assertIsInstance(Dialect.get_or_raise("hive"), Hive) + + with self.assertRaises(ValueError): + Dialect.get_or_raise(1) + + default_mysql = Dialect.get_or_raise("mysql") + self.assertEqual(default_mysql.normalization_strategy, "CASE_SENSITIVE") + + lowercase_mysql = Dialect.get_or_raise("mysql,normalization_strategy=lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy, "LOWERCASE") + + lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE") + + with self.assertRaises(ValueError) as cm: + Dialect.get_or_raise("mysql, normalization_strategy") + + self.assertEqual( + str(cm.exception), + "Invalid dialect format: 'mysql, normalization_strategy'. " + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.", + ) + + def test_compare_dialects(self): + bigquery_class = Dialect["bigquery"] + bigquery_object = BigQuery() + bigquery_string = "bigquery" + + snowflake_class = Dialect["snowflake"] + snowflake_object = Snowflake() + snowflake_string = "snowflake" + + self.assertEqual(snowflake_class, snowflake_class) + self.assertEqual(snowflake_class, snowflake_object) + self.assertEqual(snowflake_class, snowflake_string) + self.assertEqual(snowflake_object, snowflake_object) + self.assertEqual(snowflake_object, snowflake_string) + + self.assertNotEqual(snowflake_class, bigquery_class) + self.assertNotEqual(snowflake_class, bigquery_object) + self.assertNotEqual(snowflake_class, bigquery_string) + self.assertNotEqual(snowflake_object, bigquery_object) + self.assertNotEqual(snowflake_object, bigquery_string) + + self.assertTrue(snowflake_class in {"snowflake", "bigquery"}) + self.assertTrue(snowflake_object in {"snowflake", "bigquery"}) + self.assertFalse(snowflake_class in {"bigquery", "redshift"}) + self.assertFalse(snowflake_object in {"bigquery", "redshift"}) def test_cast(self): self.validate_all( @@ -561,6 +609,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_STR(x, '%Y-%m-%d')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d', x)", "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", @@ -866,9 +915,9 @@ class TestDialect(Validator): write={ "drill": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", - "hive": "TO_DATE(x)", - "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", - "spark": "TO_DATE(x)", + "hive": "CAST(x AS DATE)", + "presto": "CAST(x AS DATE)", + "spark": "CAST(x AS DATE)", "sqlite": "x", }, ) @@ -893,7 +942,7 @@ class TestDialect(Validator): self.validate_all( "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')", write={ - "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", + "presto": "DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD(CURRENT_DATE, 1)", }, ) @@ -1268,13 +1317,6 @@ class TestDialect(Validator): "doris": "LOWER(x) LIKE '%y'", }, ) - self.validate_all( - "SELECT * FROM a ORDER BY col_a NULLS LAST", - write={ - "mysql": UnsupportedError, - "starrocks": UnsupportedError, - }, - ) self.validate_all( "POSITION(needle in haystack)", write={ @@ -1315,35 +1357,37 @@ class TestDialect(Validator): self.validate_all( "CONCAT_WS('-', 'a', 'b')", write={ + "clickhouse": "CONCAT_WS('-', 'a', 'b')", "duckdb": "CONCAT_WS('-', 'a', 'b')", - "presto": "CONCAT_WS('-', 'a', 'b')", + "presto": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", "hive": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')", - "trino": "CONCAT_WS('-', 'a', 'b')", + "trino": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", }, ) self.validate_all( "CONCAT_WS('-', x)", write={ + "clickhouse": "CONCAT_WS('-', x)", "duckdb": "CONCAT_WS('-', x)", "hive": "CONCAT_WS('-', x)", - "presto": "CONCAT_WS('-', x)", + "presto": "CONCAT_WS('-', CAST(x AS VARCHAR))", "spark": "CONCAT_WS('-', x)", - "trino": "CONCAT_WS('-', x)", + "trino": "CONCAT_WS('-', CAST(x AS VARCHAR))", }, ) self.validate_all( "CONCAT(a)", write={ - "clickhouse": "a", - "presto": "a", - "trino": "a", + "clickhouse": "CONCAT(a)", + "presto": "CAST(a AS VARCHAR)", + "trino": "CAST(a AS VARCHAR)", "tsql": "a", }, ) self.validate_all( - "COALESCE(CAST(a AS TEXT), '')", + "CONCAT(COALESCE(a, ''))", read={ "drill": "CONCAT(a)", "duckdb": "CONCAT(a)", @@ -1442,6 +1486,76 @@ class TestDialect(Validator): "spark": "FILTER(the_array, x -> x > 0)", }, ) + self.validate_all( + "a / b", + write={ + "bigquery": "a / b", + "clickhouse": "a / b", + "databricks": "a / b", + "duckdb": "a / b", + "hive": "a / b", + "mysql": "a / b", + "oracle": "a / b", + "snowflake": "a / b", + "spark": "a / b", + "starrocks": "a / b", + "drill": "CAST(a AS DOUBLE) / b", + "postgres": "CAST(a AS DOUBLE PRECISION) / b", + "presto": "CAST(a AS DOUBLE) / b", + "redshift": "CAST(a AS DOUBLE PRECISION) / b", + "sqlite": "CAST(a AS REAL) / b", + "teradata": "CAST(a AS DOUBLE) / b", + "trino": "CAST(a AS DOUBLE) / b", + "tsql": "CAST(a AS FLOAT) / b", + }, + ) + + def test_typeddiv(self): + typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + typed_div_dialect = "presto" + div_dialect = "hive" + INT = exp.DataType.Type.INT + FLOAT = exp.DataType.Type.FLOAT + + for expression, types, dialect, expected in [ + (typed_div, (None, None), typed_div_dialect, "a / b"), + (typed_div, (None, None), div_dialect, "a / b"), + (div, (None, None), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (None, None), div_dialect, "a / b"), + (typed_div, (INT, INT), typed_div_dialect, "a / b"), + (typed_div, (INT, INT), div_dialect, "CAST(a / b AS BIGINT)"), + (div, (INT, INT), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (INT, INT), div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), div_dialect, "a / b"), + (div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (div, (FLOAT, FLOAT), div_dialect, "a / b"), + (typed_div, (INT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (INT, FLOAT), div_dialect, "a / b"), + (div, (INT, FLOAT), typed_div_dialect, "a / b"), + (div, (INT, FLOAT), div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {types} {dialect} -> {expected}"): + expression = expression.copy() + expression.left.type = types[0] + expression.right.type = types[1] + self.assertEqual(expected, expression.sql(dialect=dialect)) + + def test_safediv(self): + safe_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), safe=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + safe_div_dialect = "mysql" + div_dialect = "snowflake" + + for expression, dialect, expected in [ + (safe_div, safe_div_dialect, "a / b"), + (safe_div, div_dialect, "a / NULLIF(b, 0)"), + (div, safe_div_dialect, "a / b"), + (div, div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"): + self.assertEqual(expected, expression.sql(dialect=dialect)) def test_limit(self): self.validate_all( @@ -1547,7 +1661,7 @@ class TestDialect(Validator): "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))", write={ "duckdb": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", - "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))", + "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 VARCHAR(1024))", "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", @@ -1864,7 +1978,7 @@ SELECT write={ "bigquery": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "clickhouse": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "databricks": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "databricks": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "duckdb": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "mysql": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", @@ -1872,11 +1986,11 @@ SELECT "presto": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "redshift": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "snowflake": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "spark": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "spark": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "spark2": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "sqlite": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "trino": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq", }, ) self.validate_all( @@ -1885,13 +1999,60 @@ SELECT "bigquery": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "duckdb": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c AS c FROM t) AS subq1) AS subq2", }, ) self.validate_all( "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", write={ "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", - "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y FROM t2) AS subq", + "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq", }, ) + + def test_unsupported_null_ordering(self): + # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which + # both treat NULLs as small values, so the expected output queries should be equivalent + with_last_nulls = "duckdb" + with_small_nulls = "spark" + with_large_nulls = "postgres" + + sql = "SELECT * FROM t ORDER BY c" + sql_nulls_last = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END, c" + sql_nulls_first = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c" + + for read_dialect, desc, nulls_first, expected_sql in ( + (with_last_nulls, False, None, sql_nulls_last), + (with_last_nulls, True, None, sql), + (with_last_nulls, False, True, sql), + (with_last_nulls, True, True, sql_nulls_first), + (with_last_nulls, False, False, sql_nulls_last), + (with_last_nulls, True, False, sql), + (with_small_nulls, False, None, sql), + (with_small_nulls, True, None, sql), + (with_small_nulls, False, True, sql), + (with_small_nulls, True, True, sql_nulls_first), + (with_small_nulls, False, False, sql_nulls_last), + (with_small_nulls, True, False, sql), + (with_large_nulls, False, None, sql_nulls_last), + (with_large_nulls, True, None, sql_nulls_first), + (with_large_nulls, False, True, sql), + (with_large_nulls, True, True, sql_nulls_first), + (with_large_nulls, False, False, sql_nulls_last), + (with_large_nulls, True, False, sql), + ): + with self.subTest( + f"read: {read_dialect}, descending: {desc}, nulls first: {nulls_first}" + ): + sort_order = " DESC" if desc else "" + null_order = ( + " NULLS FIRST" + if nulls_first + else (" NULLS LAST" if nulls_first is not None else "") + ) + + expected_sql = f"{expected_sql}{sort_order}" + expression = parse_one(f"{sql}{sort_order}{null_order}", read=read_dialect) + + self.assertEqual(expression.sql(dialect="mysql"), expected_sql) + self.assertEqual(expression.sql(dialect="tsql"), expected_sql) -- cgit v1.2.3