diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/mysql.py | 107 |
1 files changed, 94 insertions, 13 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 554241d..59a0a2a 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate: return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: +def _str_to_date_sql( + self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" @@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql( + kind: str, +) -> t.Callable[[MySQL.Generator, exp.Expression], str]: + def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D return func +def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: + time_format = expression.args.get("format") + if time_format: + return _str_to_date_sql(self, expression) + return f"DATE({self.sql(expression, 'this')})" + + +def _remove_ts_or_ds_to_date( + to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None, + args: t.Tuple[str, ...] = ("this",), +) -> t.Callable[[MySQL.Generator, exp.Func], str]: + def func(self: MySQL.Generator, expression: exp.Func) -> str: + expression = expression.copy() + + for arg_key in args: + arg = expression.args.get(arg_key) + if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): + expression.set(arg_key, arg.this) + + return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression) + + return func + + class MySQL(Dialect): # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html IDENTIFIERS_CAN_START_WITH_DIGIT = True @@ -233,6 +261,7 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -240,14 +269,33 @@ class MySQL(Dialect): "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, "MONTHNAME": lambda args: exp.TimeToStr( - this=seq_get(args, 0), + this=exp.TsOrDsToDate(this=seq_get(args, 0)), format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, + "TO_DAYS": lambda args: exp.paren( + exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")), + unit=exp.var("DAY"), + ) + + 1 + ), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "WEEK": lambda args: exp.Week( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "CHAR": lambda self: self._parse_chr(), "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -531,6 +579,18 @@ class MySQL(Dialect): return super()._parse_type(parse_interval=parse_interval) + def _parse_chr(self) -> t.Optional[exp.Expression]: + expressions = self._parse_csv(self._parse_conjunction) + kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)} + + if len(expressions) > 1: + kwargs["expressions"] = expressions[1:] + + if self._match(TokenType.USING): + kwargs["charset"] = self._parse_var() + + return self.expression(exp.Chr, **kwargs) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -544,25 +604,33 @@ class MySQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, - exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), - exp.DateAdd: _date_add_sql("ADD"), + exp.DateDiff: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression") + ), + exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")), exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("SUB"), + exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")), exp.DateTrunc: _date_trunc_sql, - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.Day: _remove_ts_or_ds_to_date(), + exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), + exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), + exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] ), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -573,10 +641,16 @@ class MySQL(Dialect): exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), - exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.TimeToStr: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)) + ), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, - exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.TsOrDsAdd: _date_add_sql("ADD"), + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Week: _remove_ts_or_ds_to_date(), + exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), + exp.Year: _remove_ts_or_ds_to_date(), } UNSIGNED_TYPE_MAPPING = { @@ -585,6 +659,7 @@ class MySQL(Dialect): exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", exp.DataType.Type.USMALLINT: "SMALLINT", exp.DataType.Type.UTINYINT: "TINYINT", + exp.DataType.Type.UDECIMAL: "DECIMAL", } TIMESTAMP_TYPE_MAPPING = { @@ -717,3 +792,9 @@ class MySQL(Dialect): limit_offset = f"{offset}, {limit}" if offset else limit return f" LIMIT {limit_offset}" return "" + + def chr_sql(self, expression: exp.Chr) -> str: + this = self.expressions(sqls=[expression.this] + expression.expressions) + charset = expression.args.get("charset") + using = f" USING {self.sql(charset)}" if charset else "" + return f"CHAR({this}{using})" |