diff options
Diffstat (limited to 'tests/dialects/test_mysql.py')
-rw-r--r-- | tests/dialects/test_mysql.py | 294 |
1 files changed, 286 insertions, 8 deletions
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index a25871c..1ba118b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1,3 +1,4 @@ +from sqlglot import expressions as exp from tests.dialects.test_dialect import Validator @@ -20,6 +21,52 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") + self.validate_identity("@@GLOBAL.max_connections") + + # SET Commands + self.validate_identity("SET @var_name = expr") + self.validate_identity("SET @name = 43") + self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)") + self.validate_identity("SET GLOBAL max_connections = 1000") + self.validate_identity("SET @@GLOBAL.max_connections = 1000") + self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'") + self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'") + self.validate_identity("SET @@sql_mode = 'TRADITIONAL'") + self.validate_identity("SET sql_mode = 'TRADITIONAL'") + self.validate_identity("SET PERSIST max_connections = 1000") + self.validate_identity("SET @@PERSIST.max_connections = 1000") + self.validate_identity("SET PERSIST_ONLY back_log = 100") + self.validate_identity("SET @@PERSIST_ONLY.back_log = 100") + self.validate_identity("SET @@SESSION.max_join_size = DEFAULT") + self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size") + self.validate_identity("SET @x = 1, SESSION sql_mode = ''") + self.validate_identity( + "SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000" + ) + self.validate_identity( + "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" + ) + self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000") + self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000") + self.validate_identity("SET CHARACTER SET 'utf8'") + self.validate_identity("SET CHARACTER SET utf8") + self.validate_identity("SET CHARACTER SET DEFAULT") + self.validate_identity("SET NAMES 'utf8'") + self.validate_identity("SET NAMES DEFAULT") + self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") + self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") + self.validate_identity("SET autocommit = ON") + + def test_escape(self): + self.validate_all( + r"'a \' b '' '", + write={ + "mysql": r"'a '' b '' '", + "spark": r"'a \' b \' '", + }, + ) def test_introducers(self): self.validate_all( @@ -115,14 +162,6 @@ class TestMySQL(Validator): }, ) - def test_hash_comments(self): - self.validate_all( - "SELECT 1 # arbitrary content,,, until end-of-line", - write={ - "mysql": "SELECT 1", - }, - ) - def test_mysql(self): self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", @@ -174,3 +213,242 @@ COMMENT='客户账户表'""" }, pretty=True, ) + + def test_show_simple(self): + for key, write_key in [ + ("BINARY LOGS", "BINARY LOGS"), + ("MASTER LOGS", "BINARY LOGS"), + ("STORAGE ENGINES", "ENGINES"), + ("ENGINES", "ENGINES"), + ("EVENTS", "EVENTS"), + ("MASTER STATUS", "MASTER STATUS"), + ("PLUGINS", "PLUGINS"), + ("PRIVILEGES", "PRIVILEGES"), + ("PROFILES", "PROFILES"), + ("REPLICAS", "REPLICAS"), + ("SLAVE HOSTS", "REPLICAS"), + ]: + show = self.validate_identity(f"SHOW {key}", f"SHOW {write_key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, write_key) + + def test_show_events(self): + for key in ["BINLOG", "RELAYLOG"]: + show = self.validate_identity(f"SHOW {key} EVENTS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, f"{key} EVENTS") + + show = self.validate_identity(f"SHOW {key} EVENTS IN 'log' FROM 1 LIMIT 2, 3") + self.assertEqual(show.text("log"), "log") + self.assertEqual(show.text("position"), "1") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + show = self.validate_identity(f"SHOW {key} EVENTS LIMIT 1") + self.assertEqual(show.text("limit"), "1") + self.assertIsNone(show.args.get("offset")) + + def test_show_like_or_where(self): + for key, write_key in [ + ("CHARSET", "CHARACTER SET"), + ("CHARACTER SET", "CHARACTER SET"), + ("COLLATION", "COLLATION"), + ("DATABASES", "DATABASES"), + ("FUNCTION STATUS", "FUNCTION STATUS"), + ("PROCEDURE STATUS", "PROCEDURE STATUS"), + ("GLOBAL STATUS", "GLOBAL STATUS"), + ("SESSION STATUS", "STATUS"), + ("STATUS", "STATUS"), + ("GLOBAL VARIABLES", "GLOBAL VARIABLES"), + ("SESSION VARIABLES", "VARIABLES"), + ("VARIABLES", "VARIABLES"), + ]: + expected_name = write_key.strip("GLOBAL").strip() + template = "SHOW {}" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, expected_name) + + template = "SHOW {} LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + template = "SHOW {} WHERE Column_name LIKE '%foo%'" + show = self.validate_identity(template.format(key), template.format(write_key)) + self.assertIsInstance(show, exp.Show) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_columns(self): + show = self.validate_identity("SHOW COLUMNS FROM tbl_name") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "COLUMNS") + self.assertEqual(show.text("target"), "tbl_name") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL COLUMNS FROM tbl_name FROM db_name LIKE '%foo%'") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.text("target"), "tbl_name") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_show_name(self): + for key in [ + "CREATE DATABASE", + "CREATE EVENT", + "CREATE FUNCTION", + "CREATE PROCEDURE", + "CREATE TABLE", + "CREATE TRIGGER", + "CREATE VIEW", + "FUNCTION CODE", + "PROCEDURE CODE", + ]: + show = self.validate_identity(f"SHOW {key} foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + self.assertEqual(show.text("target"), "foo") + + def test_show_grants(self): + show = self.validate_identity(f"SHOW GRANTS FOR foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "GRANTS") + self.assertEqual(show.text("target"), "foo") + + def test_show_engine(self): + show = self.validate_identity("SHOW ENGINE foo STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertFalse(show.args["mutex"]) + + show = self.validate_identity("SHOW ENGINE foo MUTEX") + self.assertEqual(show.name, "ENGINE") + self.assertEqual(show.text("target"), "foo") + self.assertTrue(show.args["mutex"]) + + def test_show_errors(self): + for key in ["ERRORS", "WARNINGS"]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} LIMIT 2, 3") + self.assertEqual(show.text("limit"), "3") + self.assertEqual(show.text("offset"), "2") + + def test_show_index(self): + show = self.validate_identity("SHOW INDEX FROM foo") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "INDEX") + self.assertEqual(show.text("target"), "foo") + + show = self.validate_identity("SHOW INDEX FROM foo FROM bar") + self.assertEqual(show.text("db"), "bar") + + def test_show_db_like_or_where_sql(self): + for key in [ + "OPEN TABLES", + "TABLE STATUS", + "TRIGGERS", + ]: + show = self.validate_identity(f"SHOW {key}") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, key) + + show = self.validate_identity(f"SHOW {key} FROM db_name") + self.assertEqual(show.name, key) + self.assertEqual(show.text("db"), "db_name") + + show = self.validate_identity(f"SHOW {key} LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + show = self.validate_identity(f"SHOW {key} WHERE Column_name LIKE '%foo%'") + self.assertEqual(show.name, key) + self.assertIsInstance(show.args["where"], exp.Where) + self.assertEqual(show.args["where"].sql(), "WHERE Column_name LIKE '%foo%'") + + def test_show_processlist(self): + show = self.validate_identity("SHOW PROCESSLIST") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROCESSLIST") + self.assertFalse(show.args["full"]) + + show = self.validate_identity("SHOW FULL PROCESSLIST") + self.assertEqual(show.name, "PROCESSLIST") + self.assertTrue(show.args["full"]) + + def test_show_profile(self): + show = self.validate_identity("SHOW PROFILE") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "PROFILE") + + show = self.validate_identity("SHOW PROFILE BLOCK IO") + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + + show = self.validate_identity( + "SHOW PROFILE BLOCK IO, PAGE FAULTS FOR QUERY 1 OFFSET 2 LIMIT 3" + ) + self.assertEqual(show.args["types"][0].name, "BLOCK IO") + self.assertEqual(show.args["types"][1].name, "PAGE FAULTS") + self.assertEqual(show.text("query"), "1") + self.assertEqual(show.text("offset"), "2") + self.assertEqual(show.text("limit"), "3") + + def test_show_replica_status(self): + show = self.validate_identity("SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW SLAVE STATUS", "SHOW REPLICA STATUS") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "REPLICA STATUS") + + show = self.validate_identity("SHOW REPLICA STATUS FOR CHANNEL channel_name") + self.assertEqual(show.text("channel"), "channel_name") + + def test_show_tables(self): + show = self.validate_identity("SHOW TABLES") + self.assertIsInstance(show, exp.Show) + self.assertEqual(show.name, "TABLES") + + show = self.validate_identity("SHOW FULL TABLES FROM db_name LIKE '%foo%'") + self.assertTrue(show.args["full"]) + self.assertEqual(show.text("db"), "db_name") + self.assertIsInstance(show.args["like"], exp.Literal) + self.assertEqual(show.text("like"), "%foo%") + + def test_set_variable(self): + cmd = self.parse_one("SET SESSION x = 1") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "SESSION") + self.assertIsInstance(item.this, exp.EQ) + self.assertEqual(item.this.left.name, "x") + self.assertEqual(item.this.right.name, "1") + + cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "") + self.assertIsInstance(item.this, exp.EQ) + self.assertIsInstance(item.this.left, exp.SessionParameter) + self.assertIsInstance(item.this.right, exp.SessionParameter) + + cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "NAMES") + self.assertEqual(item.name, "charset_name") + self.assertEqual(item.text("collate"), "collation_name") + + cmd = self.parse_one("SET CHARSET DEFAULT") + item = cmd.expressions[0] + self.assertEqual(item.text("kind"), "CHARACTER SET") + self.assertEqual(item.this.name, "DEFAULT") + + cmd = self.parse_one("SET x = 1, y = 2") + self.assertEqual(len(cmd.expressions), 2) |