summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--tests/dialects/test_dialect.py219
1 files changed, 190 insertions, 29 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 2546c98..49afc62 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -7,9 +7,10 @@ from sqlglot import (
ParseError,
TokenError,
UnsupportedError,
+ exp,
parse_one,
)
-from sqlglot.dialects import Hive
+from sqlglot.dialects import BigQuery, Hive, Snowflake
class Validator(unittest.TestCase):
@@ -78,9 +79,56 @@ class TestDialect(Validator):
self.assertIsNotNone(Dialect[dialect.value])
def test_get_or_raise(self):
- self.assertEqual(Dialect.get_or_raise(Hive), Hive)
- self.assertEqual(Dialect.get_or_raise(Hive()), Hive)
- self.assertEqual(Dialect.get_or_raise("hive"), Hive)
+ self.assertIsInstance(Dialect.get_or_raise(Hive), Hive)
+ self.assertIsInstance(Dialect.get_or_raise(Hive()), Hive)
+ self.assertIsInstance(Dialect.get_or_raise("hive"), Hive)
+
+ with self.assertRaises(ValueError):
+ Dialect.get_or_raise(1)
+
+ default_mysql = Dialect.get_or_raise("mysql")
+ self.assertEqual(default_mysql.normalization_strategy, "CASE_SENSITIVE")
+
+ lowercase_mysql = Dialect.get_or_raise("mysql,normalization_strategy=lowercase")
+ self.assertEqual(lowercase_mysql.normalization_strategy, "LOWERCASE")
+
+ lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
+ self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE")
+
+ with self.assertRaises(ValueError) as cm:
+ Dialect.get_or_raise("mysql, normalization_strategy")
+
+ self.assertEqual(
+ str(cm.exception),
+ "Invalid dialect format: 'mysql, normalization_strategy'. "
+ "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.",
+ )
+
+ def test_compare_dialects(self):
+ bigquery_class = Dialect["bigquery"]
+ bigquery_object = BigQuery()
+ bigquery_string = "bigquery"
+
+ snowflake_class = Dialect["snowflake"]
+ snowflake_object = Snowflake()
+ snowflake_string = "snowflake"
+
+ self.assertEqual(snowflake_class, snowflake_class)
+ self.assertEqual(snowflake_class, snowflake_object)
+ self.assertEqual(snowflake_class, snowflake_string)
+ self.assertEqual(snowflake_object, snowflake_object)
+ self.assertEqual(snowflake_object, snowflake_string)
+
+ self.assertNotEqual(snowflake_class, bigquery_class)
+ self.assertNotEqual(snowflake_class, bigquery_object)
+ self.assertNotEqual(snowflake_class, bigquery_string)
+ self.assertNotEqual(snowflake_object, bigquery_object)
+ self.assertNotEqual(snowflake_object, bigquery_string)
+
+ self.assertTrue(snowflake_class in {"snowflake", "bigquery"})
+ self.assertTrue(snowflake_object in {"snowflake", "bigquery"})
+ self.assertFalse(snowflake_class in {"bigquery", "redshift"})
+ self.assertFalse(snowflake_object in {"bigquery", "redshift"})
def test_cast(self):
self.validate_all(
@@ -561,6 +609,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')",
write={
+ "bigquery": "FORMAT_DATE('%Y-%m-%d', x)",
"drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
@@ -866,9 +915,9 @@ class TestDialect(Validator):
write={
"drill": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
- "hive": "TO_DATE(x)",
- "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
- "spark": "TO_DATE(x)",
+ "hive": "CAST(x AS DATE)",
+ "presto": "CAST(x AS DATE)",
+ "spark": "CAST(x AS DATE)",
"sqlite": "x",
},
)
@@ -893,7 +942,7 @@ class TestDialect(Validator):
self.validate_all(
"TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')",
write={
- "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)",
+ "presto": "DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))",
"hive": "DATE_ADD(CURRENT_DATE, 1)",
},
)
@@ -1269,13 +1318,6 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "SELECT * FROM a ORDER BY col_a NULLS LAST",
- write={
- "mysql": UnsupportedError,
- "starrocks": UnsupportedError,
- },
- )
- self.validate_all(
"POSITION(needle in haystack)",
write={
"drill": "STRPOS(haystack, needle)",
@@ -1315,35 +1357,37 @@ class TestDialect(Validator):
self.validate_all(
"CONCAT_WS('-', 'a', 'b')",
write={
+ "clickhouse": "CONCAT_WS('-', 'a', 'b')",
"duckdb": "CONCAT_WS('-', 'a', 'b')",
- "presto": "CONCAT_WS('-', 'a', 'b')",
+ "presto": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))",
"hive": "CONCAT_WS('-', 'a', 'b')",
"spark": "CONCAT_WS('-', 'a', 'b')",
- "trino": "CONCAT_WS('-', 'a', 'b')",
+ "trino": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))",
},
)
self.validate_all(
"CONCAT_WS('-', x)",
write={
+ "clickhouse": "CONCAT_WS('-', x)",
"duckdb": "CONCAT_WS('-', x)",
"hive": "CONCAT_WS('-', x)",
- "presto": "CONCAT_WS('-', x)",
+ "presto": "CONCAT_WS('-', CAST(x AS VARCHAR))",
"spark": "CONCAT_WS('-', x)",
- "trino": "CONCAT_WS('-', x)",
+ "trino": "CONCAT_WS('-', CAST(x AS VARCHAR))",
},
)
self.validate_all(
"CONCAT(a)",
write={
- "clickhouse": "a",
- "presto": "a",
- "trino": "a",
+ "clickhouse": "CONCAT(a)",
+ "presto": "CAST(a AS VARCHAR)",
+ "trino": "CAST(a AS VARCHAR)",
"tsql": "a",
},
)
self.validate_all(
- "COALESCE(CAST(a AS TEXT), '')",
+ "CONCAT(COALESCE(a, ''))",
read={
"drill": "CONCAT(a)",
"duckdb": "CONCAT(a)",
@@ -1442,6 +1486,76 @@ class TestDialect(Validator):
"spark": "FILTER(the_array, x -> x > 0)",
},
)
+ self.validate_all(
+ "a / b",
+ write={
+ "bigquery": "a / b",
+ "clickhouse": "a / b",
+ "databricks": "a / b",
+ "duckdb": "a / b",
+ "hive": "a / b",
+ "mysql": "a / b",
+ "oracle": "a / b",
+ "snowflake": "a / b",
+ "spark": "a / b",
+ "starrocks": "a / b",
+ "drill": "CAST(a AS DOUBLE) / b",
+ "postgres": "CAST(a AS DOUBLE PRECISION) / b",
+ "presto": "CAST(a AS DOUBLE) / b",
+ "redshift": "CAST(a AS DOUBLE PRECISION) / b",
+ "sqlite": "CAST(a AS REAL) / b",
+ "teradata": "CAST(a AS DOUBLE) / b",
+ "trino": "CAST(a AS DOUBLE) / b",
+ "tsql": "CAST(a AS FLOAT) / b",
+ },
+ )
+
+ def test_typeddiv(self):
+ typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True)
+ div = exp.Div(this=exp.column("a"), expression=exp.column("b"))
+ typed_div_dialect = "presto"
+ div_dialect = "hive"
+ INT = exp.DataType.Type.INT
+ FLOAT = exp.DataType.Type.FLOAT
+
+ for expression, types, dialect, expected in [
+ (typed_div, (None, None), typed_div_dialect, "a / b"),
+ (typed_div, (None, None), div_dialect, "a / b"),
+ (div, (None, None), typed_div_dialect, "CAST(a AS DOUBLE) / b"),
+ (div, (None, None), div_dialect, "a / b"),
+ (typed_div, (INT, INT), typed_div_dialect, "a / b"),
+ (typed_div, (INT, INT), div_dialect, "CAST(a / b AS BIGINT)"),
+ (div, (INT, INT), typed_div_dialect, "CAST(a AS DOUBLE) / b"),
+ (div, (INT, INT), div_dialect, "a / b"),
+ (typed_div, (FLOAT, FLOAT), typed_div_dialect, "a / b"),
+ (typed_div, (FLOAT, FLOAT), div_dialect, "a / b"),
+ (div, (FLOAT, FLOAT), typed_div_dialect, "a / b"),
+ (div, (FLOAT, FLOAT), div_dialect, "a / b"),
+ (typed_div, (INT, FLOAT), typed_div_dialect, "a / b"),
+ (typed_div, (INT, FLOAT), div_dialect, "a / b"),
+ (div, (INT, FLOAT), typed_div_dialect, "a / b"),
+ (div, (INT, FLOAT), div_dialect, "a / b"),
+ ]:
+ with self.subTest(f"{expression.__class__.__name__} {types} {dialect} -> {expected}"):
+ expression = expression.copy()
+ expression.left.type = types[0]
+ expression.right.type = types[1]
+ self.assertEqual(expected, expression.sql(dialect=dialect))
+
+ def test_safediv(self):
+ safe_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), safe=True)
+ div = exp.Div(this=exp.column("a"), expression=exp.column("b"))
+ safe_div_dialect = "mysql"
+ div_dialect = "snowflake"
+
+ for expression, dialect, expected in [
+ (safe_div, safe_div_dialect, "a / b"),
+ (safe_div, div_dialect, "a / NULLIF(b, 0)"),
+ (div, safe_div_dialect, "a / b"),
+ (div, div_dialect, "a / b"),
+ ]:
+ with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"):
+ self.assertEqual(expected, expression.sql(dialect=dialect))
def test_limit(self):
self.validate_all(
@@ -1547,7 +1661,7 @@ class TestDialect(Validator):
"CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))",
write={
"duckdb": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
- "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))",
+ "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 VARCHAR(1024))",
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
@@ -1864,7 +1978,7 @@ SELECT
write={
"bigquery": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
"clickhouse": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
- "databricks": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
+ "databricks": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
"duckdb": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
"hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
"mysql": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
@@ -1872,11 +1986,11 @@ SELECT
"presto": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
"redshift": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
"snowflake": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
- "spark": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
+ "spark": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
"spark2": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
"sqlite": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
"trino": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq",
- "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq",
+ "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq",
},
)
self.validate_all(
@@ -1885,13 +1999,60 @@ SELECT
"bigquery": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2",
"duckdb": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2",
"hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2",
- "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2",
+ "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c AS c FROM t) AS subq1) AS subq2",
},
)
self.validate_all(
"WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
write={
"duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
- "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y FROM t2) AS subq",
+ "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
},
)
+
+ def test_unsupported_null_ordering(self):
+ # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which
+ # both treat NULLs as small values, so the expected output queries should be equivalent
+ with_last_nulls = "duckdb"
+ with_small_nulls = "spark"
+ with_large_nulls = "postgres"
+
+ sql = "SELECT * FROM t ORDER BY c"
+ sql_nulls_last = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END, c"
+ sql_nulls_first = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c"
+
+ for read_dialect, desc, nulls_first, expected_sql in (
+ (with_last_nulls, False, None, sql_nulls_last),
+ (with_last_nulls, True, None, sql),
+ (with_last_nulls, False, True, sql),
+ (with_last_nulls, True, True, sql_nulls_first),
+ (with_last_nulls, False, False, sql_nulls_last),
+ (with_last_nulls, True, False, sql),
+ (with_small_nulls, False, None, sql),
+ (with_small_nulls, True, None, sql),
+ (with_small_nulls, False, True, sql),
+ (with_small_nulls, True, True, sql_nulls_first),
+ (with_small_nulls, False, False, sql_nulls_last),
+ (with_small_nulls, True, False, sql),
+ (with_large_nulls, False, None, sql_nulls_last),
+ (with_large_nulls, True, None, sql_nulls_first),
+ (with_large_nulls, False, True, sql),
+ (with_large_nulls, True, True, sql_nulls_first),
+ (with_large_nulls, False, False, sql_nulls_last),
+ (with_large_nulls, True, False, sql),
+ ):
+ with self.subTest(
+ f"read: {read_dialect}, descending: {desc}, nulls first: {nulls_first}"
+ ):
+ sort_order = " DESC" if desc else ""
+ null_order = (
+ " NULLS FIRST"
+ if nulls_first
+ else (" NULLS LAST" if nulls_first is not None else "")
+ )
+
+ expected_sql = f"{expected_sql}{sort_order}"
+ expression = parse_one(f"{sql}{sort_order}{null_order}", read=read_dialect)
+
+ self.assertEqual(expression.sql(dialect="mysql"), expected_sql)
+ self.assertEqual(expression.sql(dialect="tsql"), expected_sql)