diff options
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r-- | sqlglot/dialects/mysql.py | 76 |
1 files changed, 56 insertions, 20 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index e742640..93a60f4 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -5,10 +5,12 @@ import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + locate_to_strposition, no_ilike_sql, no_paren_current_date_sql, no_tablesample_sql, no_trycast_sql, + strposition_to_local_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -120,6 +122,7 @@ class MySQL(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "START": TokenType.BEGIN, "SEPARATOR": TokenType.SEPARATOR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, @@ -172,13 +175,18 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - STRICT_CAST = False + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, + "LOCATE": locate_to_strposition, + "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), + "LEFT": lambda args: exp.Substring( + this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1) + ), } FUNCTION_PARSERS = { @@ -264,6 +272,7 @@ class MySQL(Dialect): "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), "NAMES": lambda self: self._parse_set_item_names(), + "TRANSACTION": lambda self: self._parse_set_transaction(), } PROFILE_TYPES = { @@ -278,39 +287,48 @@ class MySQL(Dialect): "SWAPS", } + TRANSACTION_CHARACTERISTICS = { + "ISOLATION LEVEL REPEATABLE READ", + "ISOLATION LEVEL READ COMMITTED", + "ISOLATION LEVEL READ UNCOMMITTED", + "ISOLATION LEVEL SERIALIZABLE", + "READ WRITE", + "READ ONLY", + } + def _parse_show_mysql(self, this, target=False, full=None, global_=None): if target: if isinstance(target, str): - self._match_text(target) + self._match_text_seq(target) target_id = self._parse_id_var() else: target_id = None - log = self._parse_string() if self._match_text("IN") else None + log = self._parse_string() if self._match_text_seq("IN") else None if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}: - position = self._parse_number() if self._match_text("FROM") else None + position = self._parse_number() if self._match_text_seq("FROM") else None db = None else: position = None - db = self._parse_id_var() if self._match_text("FROM") else None + db = self._parse_id_var() if self._match_text_seq("FROM") else None - channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None + channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None - like = self._parse_string() if self._match_text("LIKE") else None + like = self._parse_string() if self._match_text_seq("LIKE") else None where = self._parse_where() if this == "PROFILE": - types = self._parse_csv(self._parse_show_profile_type) - query = self._parse_number() if self._match_text("FOR", "QUERY") else None - offset = self._parse_number() if self._match_text("OFFSET") else None - limit = self._parse_number() if self._match_text("LIMIT") else None + types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES)) + query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None + offset = self._parse_number() if self._match_text_seq("OFFSET") else None + limit = self._parse_number() if self._match_text_seq("LIMIT") else None else: types, query = None, None offset, limit = self._parse_oldstyle_limit() - mutex = True if self._match_text("MUTEX") else None - mutex = False if self._match_text("STATUS") else mutex + mutex = True if self._match_text_seq("MUTEX") else None + mutex = False if self._match_text_seq("STATUS") else mutex return self.expression( exp.Show, @@ -331,16 +349,16 @@ class MySQL(Dialect): **{"global": global_}, ) - def _parse_show_profile_type(self): - for type_ in self.PROFILE_TYPES: - if self._match_text(*type_.split(" ")): - return exp.Var(this=type_) + def _parse_var_from_options(self, options): + for option in options: + if self._match_text_seq(*option.split(" ")): + return exp.Var(this=option) return None def _parse_oldstyle_limit(self): limit = None offset = None - if self._match_text("LIMIT"): + if self._match_text_seq("LIMIT"): parts = self._parse_csv(self._parse_number) if len(parts) == 1: limit = parts[0] @@ -353,6 +371,9 @@ class MySQL(Dialect): return self._parse_set_item_assignment(kind=None) def _parse_set_item_assignment(self, kind): + if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"): + return self._parse_set_transaction(global_=kind == "GLOBAL") + left = self._parse_primary() or self._parse_id_var() if not self._match(TokenType.EQ): self.raise_error("Expected =") @@ -381,7 +402,7 @@ class MySQL(Dialect): def _parse_set_item_names(self): charset = self._parse_string() or self._parse_id_var() - if self._match_text("COLLATE"): + if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None @@ -392,6 +413,18 @@ class MySQL(Dialect): kind="NAMES", ) + def _parse_set_transaction(self, global_=False): + self._match_text_seq("TRANSACTION") + characteristics = self._parse_csv( + lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) + ) + return self.expression( + exp.SetItem, + expressions=characteristics, + kind="TRANSACTION", + **{"global": global_}, + ) + class Generator(generator.Generator): NULL_ORDERING_SUPPORTED = False @@ -411,6 +444,7 @@ class MySQL(Dialect): exp.Trim: _trim_sql, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.StrPosition: strposition_to_local_sql, } ROOT_PROPERTIES = { @@ -481,9 +515,11 @@ class MySQL(Dialect): kind = self.sql(expression, "kind") kind = f"{kind} " if kind else "" this = self.sql(expression, "this") + expressions = self.expressions(expression) collate = self.sql(expression, "collate") collate = f" COLLATE {collate}" if collate else "" - return f"{kind}{this}{collate}" + global_ = "GLOBAL " if expression.args.get("global") else "" + return f"{global_}{kind}{this}{expressions}{collate}" def set_sql(self, expression): return f"SET {self.expressions(expression)}" |