diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/mysql.py | 329 |
1 files changed, 303 insertions, 26 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 524390f..e742640 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,4 +1,8 @@ -from sqlglot import exp +from __future__ import annotations + +import typing as t + +from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, no_ilike_sql, @@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser -from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.helper import seq_get +from sqlglot.tokens import TokenType + + +def _show_parser(*args, **kwargs): + def _parse(self): + return self._parse_show_mysql(*args, **kwargs) + + return _parse def _date_trunc_sql(self, expression): - unit = expression.text("unit").lower() + unit = expression.name.lower() - this = self.sql(expression.this) + expr = self.sql(expression.expression) if unit == "day": - return f"DATE({this})" + return f"DATE({expr})" if unit == "week": - concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" elif unit == "month": - concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" date_format = "%Y %c %e" elif unit == "quarter": - concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" date_format = "%Y %c %e" elif unit == "year": - concat = f"CONCAT(YEAR({this}), ' 1 1')" + concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: self.unsupported("Unexpected interval unit: {unit}") - return f"DATE({this})" + return f"DATE({expr})" return f"STR_TO_DATE({concat}, '{date_format}')" def _str_to_date(args): - date_format = MySQL.format_time(list_get(args, 1)) - return exp.StrToDate(this=list_get(args, 0), format=date_format) + date_format = MySQL.format_time(seq_get(args, 1)) + return exp.StrToDate(this=seq_get(args, 0), format=date_format) def _str_to_date_sql(self, expression): @@ -66,9 +75,9 @@ def _trim_sql(self, expression): def _date_add(expression_class): def func(args): - interval = list_get(args, 1) + interval = seq_get(args, 1) return expression_class( - this=list_get(args, 0), + this=seq_get(args, 0), expression=interval.this, unit=exp.Literal.string(interval.text("unit").lower()), ) @@ -101,15 +110,16 @@ class MySQL(Dialect): "%l": "%-I", } - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] COMMENTS = ["--", "#", ("/*", "*/")] IDENTIFIERS = ["`"] + ESCAPES = ["'", "\\"] BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] KEYWORDS = { - **Tokenizer.KEYWORDS, + **tokens.Tokenizer.KEYWORDS, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -156,20 +166,23 @@ class MySQL(Dialect): "_UTF32": TokenType.INTRODUCER, "_UTF8MB3": TokenType.INTRODUCER, "_UTF8MB4": TokenType.INTRODUCER, + "@@": TokenType.SESSION_PARAMETER, } - class Parser(Parser): + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} + + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -178,15 +191,212 @@ class MySQL(Dialect): } PROPERTY_PARSERS = { - **Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), } - class Generator(Generator): + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + TokenType.SET: lambda self: self._parse_set(), + } + + SHOW_PARSERS = { + "BINARY LOGS": _show_parser("BINARY LOGS"), + "MASTER LOGS": _show_parser("BINARY LOGS"), + "BINLOG EVENTS": _show_parser("BINLOG EVENTS"), + "CHARACTER SET": _show_parser("CHARACTER SET"), + "CHARSET": _show_parser("CHARACTER SET"), + "COLLATION": _show_parser("COLLATION"), + "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True), + "COLUMNS": _show_parser("COLUMNS", target="FROM"), + "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True), + "CREATE EVENT": _show_parser("CREATE EVENT", target=True), + "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True), + "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True), + "CREATE TABLE": _show_parser("CREATE TABLE", target=True), + "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True), + "CREATE VIEW": _show_parser("CREATE VIEW", target=True), + "DATABASES": _show_parser("DATABASES"), + "ENGINE": _show_parser("ENGINE", target=True), + "STORAGE ENGINES": _show_parser("ENGINES"), + "ENGINES": _show_parser("ENGINES"), + "ERRORS": _show_parser("ERRORS"), + "EVENTS": _show_parser("EVENTS"), + "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True), + "FUNCTION STATUS": _show_parser("FUNCTION STATUS"), + "GRANTS": _show_parser("GRANTS", target="FOR"), + "INDEX": _show_parser("INDEX", target="FROM"), + "MASTER STATUS": _show_parser("MASTER STATUS"), + "OPEN TABLES": _show_parser("OPEN TABLES"), + "PLUGINS": _show_parser("PLUGINS"), + "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True), + "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"), + "PRIVILEGES": _show_parser("PRIVILEGES"), + "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True), + "PROCESSLIST": _show_parser("PROCESSLIST"), + "PROFILE": _show_parser("PROFILE"), + "PROFILES": _show_parser("PROFILES"), + "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"), + "REPLICAS": _show_parser("REPLICAS"), + "SLAVE HOSTS": _show_parser("REPLICAS"), + "REPLICA STATUS": _show_parser("REPLICA STATUS"), + "SLAVE STATUS": _show_parser("REPLICA STATUS"), + "GLOBAL STATUS": _show_parser("STATUS", global_=True), + "SESSION STATUS": _show_parser("STATUS"), + "STATUS": _show_parser("STATUS"), + "TABLE STATUS": _show_parser("TABLE STATUS"), + "FULL TABLES": _show_parser("TABLES", full=True), + "TABLES": _show_parser("TABLES"), + "TRIGGERS": _show_parser("TRIGGERS"), + "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True), + "SESSION VARIABLES": _show_parser("VARIABLES"), + "VARIABLES": _show_parser("VARIABLES"), + "WARNINGS": _show_parser("WARNINGS"), + } + + SET_PARSERS = { + "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), + "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), + "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), + "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), + "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), + "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), + "NAMES": lambda self: self._parse_set_item_names(), + } + + PROFILE_TYPES = { + "ALL", + "BLOCK IO", + "CONTEXT SWITCHES", + "CPU", + "IPC", + "MEMORY", + "PAGE FAULTS", + "SOURCE", + "SWAPS", + } + + def _parse_show_mysql(self, this, target=False, full=None, global_=None): + if target: + if isinstance(target, str): + self._match_text(target) + target_id = self._parse_id_var() + else: + target_id = None + + log = self._parse_string() if self._match_text("IN") else None + + if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: + position = self._parse_number() if self._match_text("FROM") else None + db = None + else: + position = None + db = self._parse_id_var() if self._match_text("FROM") else None + + channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + + like = self._parse_string() if self._match_text("LIKE") else None + where = self._parse_where() + + if this == "PROFILE": + types = self._parse_csv(self._parse_show_profile_type) + query = self._parse_number() if self._match_text("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text("OFFSET") else None + limit = self._parse_number() if self._match_text("LIMIT") else None + else: + types, query = None, None + offset, limit = self._parse_oldstyle_limit() + + mutex = True if self._match_text("MUTEX") else None + mutex = False if self._match_text("STATUS") else mutex + + return self.expression( + exp.Show, + this=this, + target=target_id, + full=full, + log=log, + position=position, + db=db, + channel=channel, + like=like, + where=where, + types=types, + query=query, + offset=offset, + limit=limit, + mutex=mutex, + **{"global": global_}, + ) + + def _parse_show_profile_type(self): + for type_ in self.PROFILE_TYPES: + if self._match_text(*type_.split(" ")): + return exp.Var(this=type_) + return None + + def _parse_oldstyle_limit(self): + limit = None + offset = None + if self._match_text("LIMIT"): + parts = self._parse_csv(self._parse_number) + if len(parts) == 1: + limit = parts[0] + elif len(parts) == 2: + limit = parts[1] + offset = parts[0] + return offset, limit + + def _default_parse_set_item(self): + return self._parse_set_item_assignment(kind=None) + + def _parse_set_item_assignment(self, kind): + left = self._parse_primary() or self._parse_id_var() + if not self._match(TokenType.EQ): + self.raise_error("Expected =") + right = self._parse_statement() or self._parse_id_var() + + this = self.expression( + exp.EQ, + this=left, + expression=right, + ) + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_charset(self, kind): + this = self._parse_string() or self._parse_id_var() + + return self.expression( + exp.SetItem, + this=this, + kind=kind, + ) + + def _parse_set_item_names(self): + charset = self._parse_string() or self._parse_id_var() + if self._match_text("COLLATE"): + collate = self._parse_string() or self._parse_id_var() + else: + collate = None + return self.expression( + exp.SetItem, + this=charset, + collate=collate, + kind="NAMES", + ) + + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False TRANSFORMS = { - **Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, @@ -199,6 +409,8 @@ class MySQL(Dialect): exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, exp.Trim: _trim_sql, + exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), + exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), } ROOT_PROPERTIES = { @@ -209,4 +421,69 @@ class MySQL(Dialect): exp.SchemaCommentProperty, } - WITH_PROPERTIES = {} + WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set() + + def show_sql(self, expression): + this = f" {expression.name}" + full = " FULL" if expression.args.get("full") else "" + global_ = " GLOBAL" if expression.args.get("global") else "" + + target = self.sql(expression, "target") + target = f" {target}" if target else "" + if expression.name in {"COLUMNS", "INDEX"}: + target = f" FROM{target}" + elif expression.name == "GRANTS": + target = f" FOR{target}" + + db = self._prefixed_sql("FROM", expression, "db") + + like = self._prefixed_sql("LIKE", expression, "like") + where = self.sql(expression, "where") + + types = self.expressions(expression, key="types") + types = f" {types}" if types else types + query = self._prefixed_sql("FOR QUERY", expression, "query") + + if expression.name == "PROFILE": + offset = self._prefixed_sql("OFFSET", expression, "offset") + limit = self._prefixed_sql("LIMIT", expression, "limit") + else: + offset = "" + limit = self._oldstyle_limit_sql(expression) + + log = self._prefixed_sql("IN", expression, "log") + position = self._prefixed_sql("FROM", expression, "position") + + channel = self._prefixed_sql("FOR CHANNEL", expression, "channel") + + if expression.name == "ENGINE": + mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS" + else: + mutex_or_status = "" + + return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" + + def _prefixed_sql(self, prefix, expression, arg): + sql = self.sql(expression, arg) + if not sql: + return "" + return f" {prefix} {sql}" + + def _oldstyle_limit_sql(self, expression): + limit = self.sql(expression, "limit") + offset = self.sql(expression, "offset") + if limit: + limit_offset = f"{offset}, {limit}" if offset else limit + return f" LIMIT {limit_offset}" + return "" + + def setitem_sql(self, expression): + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + this = self.sql(expression, "this") + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + return f"{kind}{this}{collate}" + + def set_sql(self, expression): + return f"SET {self.expressions(expression)}" |