summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/mysql.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/dialects/mysql.py107
1 files changed, 94 insertions, 13 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 554241d..59a0a2a 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate:
return exp.StrToDate(this=seq_get(args, 0), format=date_format)
-def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str:
+def _str_to_date_sql(
+ self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate
+) -> str:
date_format = self.format_time(expression)
return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
@@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str:
return f"TRIM({trim_type}{remove_chars}{from_part}{target})"
-def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(
+ kind: str,
+) -> t.Callable[[MySQL.Generator, exp.Expression], str]:
+ def func(self: MySQL.Generator, expression: exp.Expression) -> str:
this = self.sql(expression, "this")
unit = expression.text("unit").upper() or "DAY"
return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})"
@@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D
return func
+def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str:
+ time_format = expression.args.get("format")
+ if time_format:
+ return _str_to_date_sql(self, expression)
+ return f"DATE({self.sql(expression, 'this')})"
+
+
+def _remove_ts_or_ds_to_date(
+ to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None,
+ args: t.Tuple[str, ...] = ("this",),
+) -> t.Callable[[MySQL.Generator, exp.Func], str]:
+ def func(self: MySQL.Generator, expression: exp.Func) -> str:
+ expression = expression.copy()
+
+ for arg_key in args:
+ arg = expression.args.get(arg_key)
+ if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"):
+ expression.set(arg_key, arg.this)
+
+ return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression)
+
+ return func
+
+
class MySQL(Dialect):
# https://dev.mysql.com/doc/refman/8.0/en/identifiers.html
IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -233,6 +261,7 @@ class MySQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
+ "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)),
"DATE_ADD": parse_date_delta_with_interval(exp.DateAdd),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"),
"DATE_SUB": parse_date_delta_with_interval(exp.DateSub),
@@ -240,14 +269,33 @@ class MySQL(Dialect):
"ISNULL": isnull_to_is_null,
"LOCATE": locate_to_strposition,
"MONTHNAME": lambda args: exp.TimeToStr(
- this=seq_get(args, 0),
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
format=exp.Literal.string("%B"),
),
"STR_TO_DATE": _str_to_date,
+ "TO_DAYS": lambda args: exp.paren(
+ exp.DateDiff(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")),
+ unit=exp.var("DAY"),
+ )
+ + 1
+ ),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "WEEK": lambda args: exp.Week(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1)
+ ),
+ "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+ "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
}
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "CHAR": lambda self: self._parse_chr(),
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -531,6 +579,18 @@ class MySQL(Dialect):
return super()._parse_type(parse_interval=parse_interval)
+ def _parse_chr(self) -> t.Optional[exp.Expression]:
+ expressions = self._parse_csv(self._parse_conjunction)
+ kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)}
+
+ if len(expressions) > 1:
+ kwargs["expressions"] = expressions[1:]
+
+ if self._match(TokenType.USING):
+ kwargs["charset"] = self._parse_var()
+
+ return self.expression(exp.Chr, **kwargs)
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
@@ -544,25 +604,33 @@ class MySQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
- exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression),
- exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateDiff: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression")
+ ),
+ exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")),
exp.DateStrToDate: datestrtodate_sql,
- exp.DateSub: _date_add_sql("SUB"),
+ exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")),
exp.DateTrunc: _date_trunc_sql,
- exp.DayOfMonth: rename_func("DAYOFMONTH"),
- exp.DayOfWeek: rename_func("DAYOFWEEK"),
- exp.DayOfYear: rename_func("DAYOFYEAR"),
+ exp.Day: _remove_ts_or_ds_to_date(),
+ exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")),
+ exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")),
+ exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONKeyValue: json_keyvalue_comma_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
+ exp.Month: _remove_ts_or_ds_to_date(),
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
exp.Pivot: no_pivot_sql,
exp.Select: transforms.preprocess(
- [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
+ [
+ transforms.eliminate_distinct_on,
+ transforms.eliminate_semi_and_anti_joins,
+ transforms.eliminate_qualify,
+ ]
),
exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date_sql,
@@ -573,10 +641,16 @@ class MySQL(Dialect):
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
- exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)),
+ exp.TimeToStr: _remove_ts_or_ds_to_date(
+ lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
+ ),
exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
- exp.WeekOfYear: rename_func("WEEKOFYEAR"),
+ exp.TsOrDsAdd: _date_add_sql("ADD"),
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Week: _remove_ts_or_ds_to_date(),
+ exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")),
+ exp.Year: _remove_ts_or_ds_to_date(),
}
UNSIGNED_TYPE_MAPPING = {
@@ -585,6 +659,7 @@ class MySQL(Dialect):
exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
exp.DataType.Type.USMALLINT: "SMALLINT",
exp.DataType.Type.UTINYINT: "TINYINT",
+ exp.DataType.Type.UDECIMAL: "DECIMAL",
}
TIMESTAMP_TYPE_MAPPING = {
@@ -717,3 +792,9 @@ class MySQL(Dialect):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""
+
+ def chr_sql(self, expression: exp.Chr) -> str:
+ this = self.expressions(sqls=[expression.this] + expression.expressions)
+ charset = expression.args.get("charset")
+ using = f" USING {self.sql(charset)}" if charset else ""
+ return f"CHAR({this}{using})"