diff options
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r-- | sqlglot/dialects/mysql.py | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 2185a85..c78aa9e 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: return f"STR_TO_DATE({concat}, '{date_format}')" -def _str_to_date(args: t.List) -> exp.StrToDate: - date_format = MySQL.format_time(seq_get(args, 1)) - return exp.StrToDate(this=seq_get(args, 0), format=date_format) +# All specifiers for time parts (as opposed to date parts) +# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format +TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"} + + +def _has_time_specifier(date_format: str) -> bool: + i = 0 + length = len(date_format) + + while i < length: + if date_format[i] == "%": + i += 1 + if i < length and date_format[i] in TIME_SPECIFIERS: + return True + i += 1 + return False + + +def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: + mysql_date_format = seq_get(args, 1) + date_format = MySQL.format_time(mysql_date_format) + this = seq_get(args, 0) + + if mysql_date_format and _has_time_specifier(mysql_date_format.name): + return exp.StrToTime(this=this, format=date_format) + + return exp.StrToDate(this=this, format=date_format) def _str_to_date_sql( @@ -93,7 +117,9 @@ def _date_add_sql( def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" - return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" + return ( + f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})" + ) return func @@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date( args: t.Tuple[str, ...] = ("this",), ) -> t.Callable[[MySQL.Generator, exp.Func], str]: def func(self: MySQL.Generator, expression: exp.Func) -> str: - expression = expression.copy() - for arg_key in args: arg = expression.args.get(arg_key) if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): @@ -629,6 +653,7 @@ class MySQL(Dialect): transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins, transforms.eliminate_qualify, + transforms.eliminate_full_outer_join, ] ), exp.StrPosition: strposition_to_locate_sql, @@ -728,7 +753,6 @@ class MySQL(Dialect): to = self.CAST_MAPPING.get(expression.to.this) if to: - expression = expression.copy() expression.to.set("this", to) return super().cast_sql(expression) |