summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/mysql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r--sqlglot/dialects/mysql.py36
1 files changed, 35 insertions, 1 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index f9249eb..6327796 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -555,7 +555,26 @@ class MySQL(Dialect):
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
- TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
+ UNSIGNED_TYPE_MAPPING = {
+ exp.DataType.Type.UBIGINT: "BIGINT",
+ exp.DataType.Type.UINT: "INT",
+ exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
+ exp.DataType.Type.USMALLINT: "SMALLINT",
+ exp.DataType.Type.UTINYINT: "TINYINT",
+ }
+
+ TIMESTAMP_TYPE_MAPPING = {
+ exp.DataType.Type.TIMESTAMP: "DATETIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
+ }
+
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING,
+ **UNSIGNED_TYPE_MAPPING,
+ **TIMESTAMP_TYPE_MAPPING,
+ }
+
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
@@ -580,6 +599,18 @@ class MySQL(Dialect):
exp.DataType.Type.VARCHAR: "CHAR",
}
+ TIMESTAMP_FUNC_TYPES = {
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ }
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html
+ result = super().datatype_sql(expression)
+ if expression.this in self.UNSIGNED_TYPE_MAPPING:
+ result = f"{result} UNSIGNED"
+ return result
+
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression.copy())
@@ -599,6 +630,9 @@ class MySQL(Dialect):
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ if expression.to.this in self.TIMESTAMP_FUNC_TYPES:
+ return self.func("TIMESTAMP", expression.this)
+
to = self.CAST_MAPPING.get(expression.to.this)
if to: