diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 156 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 294 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 117 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_starrocks.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 29 |
11 files changed, 607 insertions, 65 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 050d41e..a0ebc45 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -157,6 +157,14 @@ class TestBigQuery(Validator): }, ) + self.validate_all( + "DIV(x, y)", + write={ + "bigquery": "DIV(x, y)", + "duckdb": "CAST(x / y AS INT)", + }, + ) + self.validate_identity( "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" ) @@ -284,4 +292,6 @@ class TestBigQuery(Validator): "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") - self.validate_identity("CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t") + self.validate_identity( + "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 715bf10..efb41bb 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -18,7 +18,6 @@ class TestClickhouse(Validator): "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", }, ) - self.validate_all( "CAST(1 AS NULLABLE(Int64))", write={ @@ -31,3 +30,7 @@ class TestClickhouse(Validator): "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", }, ) + self.validate_all( + "SELECT x #! comment", + write={"": "SELECT x /* comment */"}, + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index e242e73..2168f55 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -22,7 +22,8 @@ class TestDatabricks(Validator): }, ) self.validate_all( - "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"} + "SELECT DATEDIFF('end', 'start')", + write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}, ) self.validate_all( "SELECT DATE_ADD('2020-01-01', 1)", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3b837df..1913f53 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,20 +1,18 @@ import unittest -from sqlglot import ( - Dialect, - Dialects, - ErrorLevel, - UnsupportedError, - parse_one, - transpile, -) +from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one class Validator(unittest.TestCase): dialect = None - def validate_identity(self, sql): - self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) + def parse_one(self, sql): + return parse_one(sql, read=self.dialect) + + def validate_identity(self, sql, write_sql=None): + expression = self.parse_one(sql) + self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) + return expression def validate_all(self, sql, read=None, write=None, pretty=False): """ @@ -28,12 +26,14 @@ class Validator(unittest.TestCase): read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL """ - expression = parse_one(sql, read=self.dialect) + expression = self.parse_one(sql) for read_dialect, read_sql in (read or {}).items(): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( - parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE), + parse_one(read_sql, read_dialect).sql( + self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + ), sql, ) @@ -83,10 +83,6 @@ class TestDialect(Validator): ) self.validate_all( "CAST(a AS BINARY(4))", - read={ - "presto": "CAST(a AS VARBINARY(4))", - "sqlite": "CAST(a AS VARBINARY(4))", - }, write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", @@ -104,6 +100,24 @@ class TestDialect(Validator): }, ) self.validate_all( + "CAST(a AS VARBINARY(4))", + write={ + "bigquery": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS VARBINARY(4))", + "duckdb": "CAST(a AS VARBINARY(4))", + "mysql": "CAST(a AS VARBINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS VARBINARY(4))", + }, + ) + self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", @@ -472,45 +486,57 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_TRUNC(x, 'day')", + "DATE_TRUNC('day', x)", write={ "mysql": "DATE(x)", - "starrocks": "DATE(x)", }, ) self.validate_all( - "DATE_TRUNC(x, 'week')", + "DATE_TRUNC('week', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", }, ) self.validate_all( - "DATE_TRUNC(x, 'month')", + "DATE_TRUNC('month', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'quarter')", + "DATE_TRUNC('quarter', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'year')", + "DATE_TRUNC('year', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'millenium')", + "DATE_TRUNC('millenium', x)", write={ "mysql": UnsupportedError, - "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "DATE_TRUNC('year', x)", + read={ + "starrocks": "DATE_TRUNC('year', x)", + }, + write={ + "starrocks": "DATE_TRUNC('year', x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, year)", + read={ + "bigquery": "DATE_TRUNC(x, year)", + }, + write={ + "bigquery": "DATE_TRUNC(x, year)", }, ) self.validate_all( @@ -564,6 +590,22 @@ class TestDialect(Validator): "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", }, ) + self.validate_all( + "TIMESTAMP '2022-01-01'", + write={ + "mysql": "CAST('2022-01-01' AS TIMESTAMP)", + "starrocks": "CAST('2022-01-01' AS DATETIME)", + "hive": "CAST('2022-01-01' AS TIMESTAMP)", + }, + ) + self.validate_all( + "TIMESTAMP('2022-01-01')", + write={ + "mysql": "TIMESTAMP('2022-01-01')", + "starrocks": "TIMESTAMP('2022-01-01')", + "hive": "TIMESTAMP('2022-01-01')", + }, + ) for unit in ("DAY", "MONTH", "YEAR"): self.validate_all( @@ -1002,7 +1044,10 @@ class TestDialect(Validator): ) def test_limit(self): - self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) + self.validate_all( + "SELECT * FROM data LIMIT 10, 20", + write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}, + ) self.validate_all( "SELECT x FROM y LIMIT 10", write={ @@ -1132,3 +1177,56 @@ class TestDialect(Validator): "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", }, ) + + def test_nullsafe_eq(self): + self.validate_all( + "SELECT a IS NOT DISTINCT FROM b", + read={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + write={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + ) + + def test_nullsafe_neq(self): + self.validate_all( + "SELECT a IS DISTINCT FROM b", + read={ + "postgres": "SELECT a IS DISTINCT FROM b", + }, + write={ + "mysql": "SELECT NOT a <=> b", + "postgres": "SELECT a IS DISTINCT FROM b", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 /* arbitrary content,,, until end-of-line */", + read={ + "mysql": "SELECT 1 # arbitrary content,,, until end-of-line", + "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line", + "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line", + }, + ) + self.validate_all( + """/* comment1 */ +SELECT + x, -- comment2 + y -- comment3""", + read={ + "mysql": """SELECT # comment1 + x, # comment2 + y # comment3""", + "bigquery": """SELECT # comment1 + x, # comment2 + y # comment3""", + "clickhouse": """SELECT # comment1 + x, # comment2 + y # comment3""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index a25871c..1ba118b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1,3 +1,4 @@ +from sqlglot import expressions as exp from tests.dialects.test_dialect import Validator @@ -20,6 +21,52 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") + self.validate_identity("@@GLOBAL.max_connections") + + # SET Commands + self.validate_identity("SET @var_name = expr") + self.validate_identity("SET @name = 43") + self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)") + self.validate_identity("SET GLOBAL max_connections = 1000") + self.validate_identity("SET @@GLOBAL.max_connections = 1000") + self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'") + self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@sql_mode = 'TRADITIONAL'") + self.validate_identity("SET sql_mode = 'TRADITIONAL'") + self.validate_identity("SET PERSIST max_connections = 1000") + self.validate_identity("SET @@PERSIST.max_connections = 1000") + self.validate_identity("SET PERSIST_ONLY back_log = 100") + self.validate_identity("SET @@PERSIST_ONLY.back_log = 100") + self.validate_identity("SET @@SESSION.max_join_size = DEFAULT") + self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size") + self.validate_identity("SET @x = 1, SESSION sql_mode = ''") + self.validate_identity( + "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" + ) + self.validate_identity( + "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" + ) + self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000") + self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000") + self.validate_identity("SET CHARACTER SET 'utf8'") + self.validate_identity("SET CHARACTER SET utf8") + self.validate_identity("SET CHARACTER SET DEFAULT") + self.validate_identity("SET NAMES 'utf8'") + self.validate_identity("SET NAMES DEFAULT") + self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") + self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") + self.validate_identity("SET autocommit = ON") + + def test_escape(self): + self.validate_all( + r"'a \' b '' '", + write={ + "mysql": r"'a '' b '' '", + "spark": r"'a \' b \' '", + }, + ) def test_introducers(self): self.validate_all( @@ -115,14 +162,6 @@ class TestMySQL(Validator): }, ) - def test_hash_comments(self): - self.validate_all( - "SELECT 1 # arbitrary content,,, until end-of-line", - write={ - "mysql": "SELECT 1", - }, - ) - def test_mysql(self): self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", @@ -174,3 +213,242 @@ COMMENT='客户账户表'""" }, pretty=True, ) + + def test_show_simple(self): + for key, write_key in [ + ("BINARY LOGS", "BINARY LOGS"), + ("MASTER LOGS", "BINARY LOGS"), + ("STORAGE ENGINES", "ENGINES"), + ("ENGINES", "ENGINES"), + ("EVENTS", "EVENTS"), + ("MASTER STATUS", "MASTER STATUS"), + ("PLUGINS", "PLUGINS"), + ("PRIVILEGES", "PRIVILEGES"), + ("PROFILES", "PROFILES"), + ("REPLICAS", "REPLICAS"), + ("SLAVE HOSTS", "REPLICAS"), + ]: + show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, write_key) + + def test_show_events(self): + for key in ["BINLOG", "RELAYLOG"]: + show = self.validate_identity(f"SHOW {key} EVENTS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, f"{key} EVENTS") + + show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3") + self.assertEqual(show.text("log"), "log") + self.assertEqual(show.text("position"), "1") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1") + self.assertEqual(show.text("limit"), "1") + self.assertIsNone(show.args.get("offset")) + + def test_show_like_or_where(self): + for key, write_key in [ + ("CHARSET", "CHARACTER SET"), + ("CHARACTER SET", "CHARACTER SET"), + ("COLLATION", "COLLATION"), + ("DATABASES", "DATABASES"), + ("FUNCTION STATUS", "FUNCTION STATUS"), + ("PROCEDURE STATUS", "PROCEDURE STATUS"), + ("GLOBAL STATUS", "GLOBAL STATUS"), + ("SESSION STATUS", "STATUS"), + ("STATUS", "STATUS"), + ("GLOBAL VARIABLES", "GLOBAL VARIABLES"), + ("SESSION VARIABLES", "VARIABLES"), + ("VARIABLES", "VARIABLES"), + ]: + expected_name = write_key.strip("GLOBAL").strip() + template = "SHOW {}" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, expected_name) + + template = "SHOW {} LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + template = "SHOW {} WHERE Column_name LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_columns(self): + show = self.validate_identity("SHOW COLUMNS FROM tbl_name") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "COLUMNS") + self.assertEqual(show.text("target"), "tbl_name") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.text("target"), "tbl_name") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_show_name(self): + for key in [ + "CREATE DATABASE", + "CREATE EVENT", + "CREATE FUNCTION", + "CREATE PROCEDURE", + "CREATE TABLE", + "CREATE TRIGGER", + "CREATE VIEW", + "FUNCTION CODE", + "PROCEDURE CODE", + ]: + show = self.validate_identity(f"SHOW {key} foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + self.assertEqual(show.text("target"), "foo") + + def test_show_grants(self): + show = self.validate_identity(f"SHOW GRANTS FOR foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "GRANTS") + self.assertEqual(show.text("target"), "foo") + + def test_show_engine(self): + show = self.validate_identity("SHOW ENGINE foo STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertFalse(show.args["mutex"]) + + show = self.validate_identity("SHOW ENGINE foo MUTEX") + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertTrue(show.args["mutex"]) + + def test_show_errors(self): + for key in ["ERRORS", "WARNINGS"]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} LIMIT 2, 3") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + def test_show_index(self): + show = self.validate_identity("SHOW INDEX FROM foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "INDEX") + self.assertEqual(show.text("target"), "foo") + + show = self.validate_identity("SHOW INDEX FROM foo FROM bar") + self.assertEqual(show.text("db"), "bar") + + def test_show_db_like_or_where_sql(self): + for key in [ + "OPEN TABLES", + "TABLE STATUS", + "TRIGGERS", + ]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} FROM db_name") + self.assertEqual(show.name, key) + self.assertEqual(show.text("db"), "db_name") + + show = self.validate_identity(f"SHOW {key} LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_processlist(self): + show = self.validate_identity("SHOW PROCESSLIST") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROCESSLIST") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL PROCESSLIST") + self.assertEqual(show.name, "PROCESSLIST") + self.assertTrue(show.args["full"]) + + def test_show_profile(self): + show = self.validate_identity("SHOW PROFILE") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROFILE") + + show = self.validate_identity("SHOW PROFILE BLOCK IO") + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + + show = self.validate_identity( + "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3" + ) + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + self.assertEqual(show.args["types"][1].name, "PAGE FAULTS") + self.assertEqual(show.text("query"), "1") + self.assertEqual(show.text("offset"), "2") + self.assertEqual(show.text("limit"), "3") + + def test_show_replica_status(self): + show = self.validate_identity("SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name") + self.assertEqual(show.text("channel"), "channel_name") + + def test_show_tables(self): + show = self.validate_identity("SHOW TABLES") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "TABLES") + + show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_set_variable(self): + cmd = self.parse_one("SET SESSION x = 1") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "SESSION") + self.assertIsInstance(item.this, exp.EQ) + self.assertEqual(item.this.left.name, "x") + self.assertEqual(item.this.right.name, "1") + + cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "") + self.assertIsInstance(item.this, exp.EQ) + self.assertIsInstance(item.this.left, exp.SessionParameter) + self.assertIsInstance(item.this.right, exp.SessionParameter) + + cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "NAMES") + self.assertEqual(item.name, "charset_name") + self.assertEqual(item.text("collate"), "collation_name") + + cmd = self.parse_one("SET CHARSET DEFAULT") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "CHARACTER SET") + self.assertEqual(item.this.name, "DEFAULT") + + cmd = self.parse_one("SET x = 1, y = 2") + self.assertEqual(len(cmd.expressions), 2) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 35141e2..8294eea 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,7 +8,9 @@ class TestPostgres(Validator): def test_ddl(self): self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", - write={"postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)"}, + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, ) self.validate_all( "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", @@ -59,15 +61,27 @@ class TestPostgres(Validator): def test_postgres(self): self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END") - self.validate_identity('SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')') + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1 FOR 2) IN ('ab') THEN 1 ELSE 0 END" + ) + self.validate_identity( + 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')' + ) self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") - self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')") - self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))") + self.validate_identity( + "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')" + ) + self.validate_identity( + "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))" + ) self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") - self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')") + self.validate_identity( + "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" + ) self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") @@ -75,7 +89,7 @@ class TestPostgres(Validator): self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ - "duckdb": "CREATE TABLE x (a UUID, b BINARY)", + "duckdb": "CREATE TABLE x (a UUID, b VARBINARY)", "presto": "CREATE TABLE x (a UUID, b VARBINARY)", "hive": "CREATE TABLE x (a UUID, b BINARY)", "spark": "CREATE TABLE x (a UUID, b BINARY)", @@ -153,7 +167,9 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss", - read={"postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss"}, + read={ + "postgres": "SELECT * FROM foo, LATERAL (SELECT * FROM bar WHERE bar.id = foo.bar_id) AS ss" + }, ) self.validate_all( "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", @@ -169,11 +185,15 @@ class TestPostgres(Validator): ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE deleted NOTNULL" + }, ) self.validate_all( "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted IS NULL", - read={"postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL"}, + read={ + "postgres": "SELECT id, email, CAST(deleted AS TEXT) FROM users WHERE NOT deleted ISNULL" + }, ) self.validate_all( "'[1,2,3]'::json->2", @@ -184,7 +204,8 @@ class TestPostgres(Validator): write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""}, ) self.validate_all( - """'{"x": {"y": 1}}'::json->'x'->'y'""", write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""} + """'{"x": {"y": 1}}'::json->'x'->'y'""", + write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'::json->'y'""", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 1ed2bb6..5309a34 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -61,4 +61,6 @@ class TestRedshift(Validator): "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'" ) self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)") - self.validate_identity("SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'") + self.validate_identity( + "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index fea2311..1846b17 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -336,7 +336,8 @@ class TestSnowflake(Validator): def test_table_literal(self): # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html self.validate_all( - r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""} + r"""SELECT * FROM TABLE('MYTABLE')""", + write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}, ) self.validate_all( @@ -352,15 +353,123 @@ class TestSnowflake(Validator): write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""}, ) - self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}) + self.validate_all( + r"""SELECT * FROM TABLE($MYVAR)""", + write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""}, + ) - self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""}) + self.validate_all( + r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""} + ) self.validate_all( - r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""} + r"""SELECT * FROM TABLE(:BINDING)""", + write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}, ) self.validate_all( r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""}, ) + + def test_flatten(self): + self.validate_all( + """ + select + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + f.value::varchar as operator + from cs.telescope.dag_report, + table(flatten(input=>split(operators, ','))) f + """, + write={ + "snowflake": """SELECT + dag_report.acct_id, + dag_report.report_date, + dag_report.report_uuid, + dag_report.airflow_name, + dag_report.dag_id, + CAST(f.value AS VARCHAR) AS operator +FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f""" + }, + pretty=True, + ) + + # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax + self.validate_all( + "SELECT * FROM TABLE(FLATTEN(input => parse_json('[1, ,77]'))) f", + write={ + "snowflake": "SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[1, ,77]'))) AS f" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88]}'), path => 'b')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88]}'), path => 'b')) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'))) f""", + write={"snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'))) AS f"""}, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('[]'), outer => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('[]'), outer => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'))) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true)) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE)) AS f""" + }, + ) + + self.validate_all( + """SELECT * FROM TABLE(FLATTEN(input => parse_json('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => true, mode => 'object')) f""", + write={ + "snowflake": """SELECT * FROM TABLE(FLATTEN(input => PARSE_JSON('{"a":1, "b":[77,88], "c": {"d":"X"}}'), recursive => TRUE, mode => 'object')) AS f""" + }, + ) + + self.validate_all( + """ + SELECT id as "ID", + f.value AS "Contact", + f1.value:type AS "Type", + f1.value:content AS "Details" + FROM persons p, + lateral flatten(input => p.c, path => 'contact') f, + lateral flatten(input => f.value:business) f1 + """, + write={ + "snowflake": """SELECT + id AS "ID", + f.value AS "Contact", + f1.value['type'] AS "Type", + f1.value['content'] AS "Details" +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL FLATTEN(input => f.value['business']) f1""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 8605bd1..4470722 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -284,4 +284,6 @@ TBLPROPERTIES ( ) def test_iif(self): - self.validate_all("SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}) + self.validate_all( + "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"} + ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 1fe1a57..35d8b45 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -6,3 +6,6 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + + def test_time(self): + self.validate_identity("TIMESTAMP('2022-01-01')") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d22a9c2..a60f48d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -278,12 +278,19 @@ class TestTSQL(Validator): def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") self.validate_all( - "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} + "SELECT DATEADD(year, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}, + ) + self.validate_all( + "SELECT DATEADD(qq, 1, '2017/08/25')", + write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}, ) - self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) self.validate_all( "SELECT DATEADD(wk, 1, '2017/08/25')", - write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"}, + write={ + "spark": "SELECT DATE_ADD('2017/08/25', 7)", + "databricks": "SELECT DATEADD(week, 1, '2017/08/25')", + }, ) def test_date_diff(self): @@ -370,13 +377,21 @@ class TestTSQL(Validator): "SELECT FORMAT(1000000.01,'###,###.###')", write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, ) - self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}) + self.validate_all( + "SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"} + ) self.validate_all( "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"}, ) self.validate_all( - "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"} + "SELECT FORMAT(date_col, 'dd.mm.yyyy')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}, + ) + self.validate_all( + "SELECT FORMAT(date_col, 'm')", + write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}, + ) + self.validate_all( + "SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"} ) - self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}) - self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}) |