diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/mysql.py | 31 |
1 files changed, 13 insertions, 18 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 97c891d..e549f62 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -9,7 +9,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_sql, date_add_interval_sql, datestrtodate_sql, - format_time_lambda, + build_formatted_time, isnull_to_is_null, locate_to_strposition, max_or_greatest, @@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_tablesample_sql, no_trycast_sql, - parse_date_delta, - parse_date_delta_with_interval, + build_date_delta, + build_date_delta_with_interval, rename_func, strposition_to_locate_sql, ) @@ -39,9 +39,6 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit").upper() - if unit == "DAY": - return f"DATE({expr})" - if unit == "WEEK": concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" date_format = "%Y %u %w" @@ -55,10 +52,11 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: concat = f"CONCAT(YEAR({expr}), ' 1 1')" date_format = "%Y %c %e" else: - self.unsupported(f"Unexpected interval unit: {unit}") - return f"DATE({expr})" + if unit != "DAY": + self.unsupported(f"Unexpected interval unit: {unit}") + return self.func("DATE", expr) - return f"STR_TO_DATE({concat}, '{date_format}')" + return self.func("STR_TO_DATE", concat, f"'{date_format}'") # All specifiers for time parts (as opposed to date parts) @@ -93,8 +91,7 @@ def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: def _str_to_date_sql( self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: - date_format = self.format_time(expression) - return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" + return self.func("STR_TO_DATE", expression.this, self.format_time(expression)) def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: @@ -127,9 +124,7 @@ def _date_add_sql( def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: time_format = expression.args.get("format") - if time_format: - return _str_to_date_sql(self, expression) - return f"DATE({self.sql(expression, 'this')})" + return _str_to_date_sql(self, expression) if time_format else self.func("DATE", expression.this) def _remove_ts_or_ds_to_date( @@ -289,9 +284,9 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), - "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), - "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), + "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), + "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"), + "DATE_SUB": build_date_delta_with_interval(exp.DateSub), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), @@ -306,7 +301,7 @@ class MySQL(Dialect): format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, - "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff), + "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), "TO_DAYS": lambda args: exp.paren( exp.DateDiff( this=exp.TsOrDsToDate(this=seq_get(args, 0)), |