diff options
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r-- | sqlglot/dialects/mysql.py | 82 |
1 files changed, 56 insertions, 26 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 6ebae1e..1d53346 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import ( build_date_delta_with_interval, rename_func, strposition_to_locate_sql, + unit_to_var, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -109,14 +110,14 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql( +def date_add_sql( kind: str, -) -> t.Callable[[MySQL.Generator, exp.Expression], str]: - def func(self: MySQL.Generator, expression: exp.Expression) -> str: - this = self.sql(expression, "this") - unit = expression.text("unit").upper() or "DAY" - return ( - f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" +) -> t.Callable[[generator.Generator, exp.Expression], str]: + def func(self: generator.Generator, expression: exp.Expression) -> str: + return self.func( + f"DATE_{kind}", + expression.this, + exp.Interval(this=expression.expression, unit=unit_to_var(expression)), ) return func @@ -291,6 +292,7 @@ class MySQL(Dialect): "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)), + "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"), "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, "MAKETIME": exp.TimeFromParts.from_arg_list, @@ -319,11 +321,7 @@ class MySQL(Dialect): FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, "CHAR": lambda self: self._parse_chr(), - "GROUP_CONCAT": lambda self: self.expression( - exp.GroupConcat, - this=self._parse_lambda(), - separator=self._match(TokenType.SEPARATOR) and self._parse_field(), - ), + "GROUP_CONCAT": lambda self: self._parse_group_concat(), # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values "VALUES": lambda self: self.expression( exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()] @@ -412,6 +410,11 @@ class MySQL(Dialect): "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"), } + ALTER_PARSERS = { + **parser.Parser.ALTER_PARSERS, + "MODIFY": lambda self: self._parse_alter_table_alter(), + } + SCHEMA_UNNAMED_CONSTRAINTS = { *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS, "FULLTEXT", @@ -458,7 +461,7 @@ class MySQL(Dialect): this = self._parse_id_var(any_token=False) index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text - schema = self._parse_schema() + expressions = self._parse_wrapped_csv(self._parse_ordered) options = [] while True: @@ -478,9 +481,6 @@ class MySQL(Dialect): elif self._match_text_seq("ENGINE_ATTRIBUTE"): self._match(TokenType.EQ) opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) - elif self._match_text_seq("ENGINE_ATTRIBUTE"): - self._match(TokenType.EQ) - opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"): self._match(TokenType.EQ) opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string()) @@ -495,7 +495,7 @@ class MySQL(Dialect): return self.expression( exp.IndexColumnConstraint, this=this, - schema=schema, + expressions=expressions, kind=kind, index_type=index_type, options=options, @@ -617,6 +617,39 @@ class MySQL(Dialect): return self.expression(exp.Chr, **kwargs) + def _parse_group_concat(self) -> t.Optional[exp.Expression]: + def concat_exprs( + node: t.Optional[exp.Expression], exprs: t.List[exp.Expression] + ) -> exp.Expression: + if isinstance(node, exp.Distinct) and len(node.expressions) > 1: + concat_exprs = [ + self.expression(exp.Concat, expressions=node.expressions, safe=True) + ] + node.set("expressions", concat_exprs) + return node + if len(exprs) == 1: + return exprs[0] + return self.expression(exp.Concat, expressions=args, safe=True) + + args = self._parse_csv(self._parse_lambda) + + if args: + order = args[-1] if isinstance(args[-1], exp.Order) else None + + if order: + # Order By is the last (or only) expression in the list and has consumed the 'expr' before it, + # remove 'expr' from exp.Order and add it back to args + args[-1] = order.this + order.set("this", concat_exprs(order.this, args)) + + this = order or concat_exprs(args[0], args) + else: + this = None + + separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None + + return self.expression(exp.GroupConcat, this=this, separator=separator) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = None @@ -630,6 +663,7 @@ class MySQL(Dialect): JSON_TYPE_REQUIRED_FOR_EXTRACTION = True JSON_PATH_BRACKETED_KEY_SUPPORTED = False JSON_KEY_VALUE_PAIR_SEP = "," + SUPPORTS_TO_NUMBER = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -637,9 +671,9 @@ class MySQL(Dialect): exp.DateDiff: _remove_ts_or_ds_to_date( lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression") ), - exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")), + exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")), exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")), + exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")), exp.DateTrunc: _date_trunc_sql, exp.Day: _remove_ts_or_ds_to_date(), exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), @@ -672,7 +706,7 @@ class MySQL(Dialect): exp.TimeFromParts: rename_func("MAKETIME"), exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"), exp.TimestampDiff: lambda self, e: self.func( - "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this + "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this ), exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), @@ -682,9 +716,10 @@ class MySQL(Dialect): ), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: _date_add_sql("ADD"), + exp.TsOrDsAdd: date_add_sql("ADD"), exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToTime: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)), exp.Week: _remove_ts_or_ds_to_date(), exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), exp.Year: _remove_ts_or_ds_to_date(), @@ -751,11 +786,6 @@ class MySQL(Dialect): result = f"{result} UNSIGNED" return result - def xor_sql(self, expression: exp.Xor) -> str: - if expression.expressions: - return self.expressions(expression, sep=" XOR ") - return super().xor_sql(expression) - def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str: return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" |