summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/mysql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r--sqlglot/dialects/mysql.py31
1 files changed, 13 insertions, 18 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 97c891d..e549f62 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -9,7 +9,7 @@ from sqlglot.dialects.dialect import (
arrow_json_extract_sql,
date_add_interval_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
isnull_to_is_null,
locate_to_strposition,
max_or_greatest,
@@ -19,8 +19,8 @@ from sqlglot.dialects.dialect import (
no_pivot_sql,
no_tablesample_sql,
no_trycast_sql,
- parse_date_delta,
- parse_date_delta_with_interval,
+ build_date_delta,
+ build_date_delta_with_interval,
rename_func,
strposition_to_locate_sql,
)
@@ -39,9 +39,6 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
expr = self.sql(expression, "this")
unit = expression.text("unit").upper()
- if unit == "DAY":
- return f"DATE({expr})"
-
if unit == "WEEK":
concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
@@ -55,10 +52,11 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
- self.unsupported(f"Unexpected interval unit: {unit}")
- return f"DATE({expr})"
+ if unit != "DAY":
+ self.unsupported(f"Unexpected interval unit: {unit}")
+ return self.func("DATE", expr)
- return f"STR_TO_DATE({concat}, '{date_format}')"
+ return self.func("STR_TO_DATE", concat, f"'{date_format}'")
# All specifiers for time parts (as opposed to date parts)
@@ -93,8 +91,7 @@ def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
def _str_to_date_sql(
self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
) -> str:
- date_format = self.format_time(expression)
- return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+ return self.func("STR_TO_DATE", expression.this, self.format_time(expression))
def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
@@ -127,9 +124,7 @@ def _date_add_sql(
def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = expression.args.get("format")
- if time_format:
- return _str_to_date_sql(self, expression)
- return f"DATE({self.sql(expression, 'this')})"
+ return _str_to_date_sql(self, expression) if time_format else self.func("DATE", expression.this)
def _remove_ts_or_ds_to_date(
@@ -289,9 +284,9 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
- "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
- "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
- "DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
+ "DATE_ADD": build_date_delta_with_interval(exp.DateAdd),
+ "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"),
+ "DATE_SUB": build_date_delta_with_interval(exp.DateSub),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
@@ -306,7 +301,7 @@ class MySQL(Dialect):
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
- "TIMESTAMPDIFF": parse_date_delta(exp.TimestampDiff),
+ "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff),
"TO_DAYS": lambda args: exp.paren(
exp.DateDiff(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),