diff options
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r-- | sqlglot/dialects/mysql.py | 64 |
1 files changed, 34 insertions, 30 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 5342624..2b41860 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -11,6 +13,7 @@ from sqlglot.dialects.dialect import ( min_or_least, no_ilike_sql, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, parse_date_delta_with_interval, @@ -21,14 +24,14 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _show_parser(*args, **kwargs): - def _parse(self): +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]: + def _parse(self: MySQL.Parser) -> exp.Show: return self._parse_show_mysql(*args, **kwargs) return _parse -def _date_trunc_sql(self, expression): +def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit") @@ -54,17 +57,17 @@ def _date_trunc_sql(self, expression): return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args): +def _str_to_date(args: t.List) -> exp.StrToDate: date_format = MySQL.format_time(seq_get(args, 1)) return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self, expression): +def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" -def _trim_sql(self, expression): +def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") @@ -79,8 +82,8 @@ def _trim_sql(self, expression): return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return ( @@ -175,10 +178,10 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} # type: ignore + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA, TokenType.DATABASE} FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -191,7 +194,7 @@ class MySQL(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore + **parser.Parser.FUNCTION_PARSERS, "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -199,13 +202,8 @@ class MySQL(Dialect): ), } - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, # type: ignore - "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), - } - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.SHOW: lambda self: self._parse_show(), } @@ -286,7 +284,13 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True - def _parse_show_mysql(self, this, target=False, full=None, global_=None): + def _parse_show_mysql( + self, + this: str, + target: bool | str = False, + full: t.Optional[bool] = None, + global_: t.Optional[bool] = None, + ) -> exp.Show: if target: if isinstance(target, str): self._match_text_seq(target) @@ -342,10 +346,12 @@ class MySQL(Dialect): offset=offset, limit=limit, mutex=mutex, - **{"global": global_}, + **{"global": global_}, # type: ignore ) - def _parse_oldstyle_limit(self): + def _parse_oldstyle_limit( + self, + ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]: limit = None offset = None if self._match_text_seq("LIMIT"): @@ -355,23 +361,20 @@ class MySQL(Dialect): elif len(parts) == 2: limit = parts[1] offset = parts[0] + return offset, limit - def _parse_set_item_charset(self, kind): + def _parse_set_item_charset(self, kind: str) -> exp.Expression: this = self._parse_string() or self._parse_id_var() + return self.expression(exp.SetItem, this=this, kind=kind) - return self.expression( - exp.SetItem, - this=this, - kind=kind, - ) - - def _parse_set_item_names(self): + def _parse_set_item_names(self) -> exp.Expression: charset = self._parse_string() or self._parse_id_var() if self._match_text_seq("COLLATE"): collate = self._parse_string() or self._parse_id_var() else: collate = None + return self.expression( exp.SetItem, this=charset, @@ -386,7 +389,7 @@ class MySQL(Dialect): TABLE_HINTS = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.DateAdd: _date_add_sql("ADD"), @@ -403,6 +406,7 @@ class MySQL(Dialect): exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -422,7 +426,7 @@ class MySQL(Dialect): TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } |