summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/mysql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r--sqlglot/dialects/mysql.py38
1 files changed, 31 insertions, 7 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 2185a85..c78aa9e 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -60,9 +60,33 @@ def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str:
return f"STR_TO_DATE({concat}, '{date_format}')"
-def _str_to_date(args: t.List) -> exp.StrToDate:
- date_format = MySQL.format_time(seq_get(args, 1))
- return exp.StrToDate(this=seq_get(args, 0), format=date_format)
+# All specifiers for time parts (as opposed to date parts)
+# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format
+TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"}
+
+
+def _has_time_specifier(date_format: str) -> bool:
+ i = 0
+ length = len(date_format)
+
+ while i < length:
+ if date_format[i] == "%":
+ i += 1
+ if i < length and date_format[i] in TIME_SPECIFIERS:
+ return True
+ i += 1
+ return False
+
+
+def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime:
+ mysql_date_format = seq_get(args, 1)
+ date_format = MySQL.format_time(mysql_date_format)
+ this = seq_get(args, 0)
+
+ if mysql_date_format and _has_time_specifier(mysql_date_format.name):
+ return exp.StrToTime(this=this, format=date_format)
+
+ return exp.StrToDate(this=this, format=date_format)
def _str_to_date_sql(
@@ -93,7 +117,9 @@ def _date_add_sql(
def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
- return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
+ return (
+ f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression, unit=unit))})"
+ )
return func
@@ -110,8 +136,6 @@ def _remove_ts_or_ds_to_date(
args: t.Tuple[str, ...] = ("this",),
) -> t.Callable[[MySQL.Generator, exp.Func], str]:
def func(self: MySQL.Generator, expression: exp.Func) -> str:
- expression = expression.copy()
-
for arg_key in args:
arg = expression.args.get(arg_key)
if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
@@ -629,6 +653,7 @@ class MySQL(Dialect):
transforms.eliminate_distinct_on,
transforms.eliminate_semi_and_anti_joins,
transforms.eliminate_qualify,
+ transforms.eliminate_full_outer_join,
]
),
exp.StrPosition: strposition_to_locate_sql,
@@ -728,7 +753,6 @@ class MySQL(Dialect):
to = self.CAST_MAPPING.get(expression.to.this)
if to:
- expression = expression.copy()
expression.to.set("this", to)
return super().cast_sql(expression)