diff options
Diffstat (limited to 'sqlglot/dialects/mysql.py')
-rw-r--r-- | sqlglot/dialects/mysql.py | 36 |
1 files changed, 35 insertions, 1 deletions
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index f9249eb..6327796 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -555,7 +555,26 @@ class MySQL(Dialect): exp.WeekOfYear: rename_func("WEEKOFYEAR"), } - TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() + UNSIGNED_TYPE_MAPPING = { + exp.DataType.Type.UBIGINT: "BIGINT", + exp.DataType.Type.UINT: "INT", + exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", + exp.DataType.Type.USMALLINT: "SMALLINT", + exp.DataType.Type.UTINYINT: "TINYINT", + } + + TIMESTAMP_TYPE_MAPPING = { + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + } + + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + **UNSIGNED_TYPE_MAPPING, + **TIMESTAMP_TYPE_MAPPING, + } + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) @@ -580,6 +599,18 @@ class MySQL(Dialect): exp.DataType.Type.VARCHAR: "CHAR", } + TIMESTAMP_FUNC_TYPES = { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + } + + def datatype_sql(self, expression: exp.DataType) -> str: + # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html + result = super().datatype_sql(expression) + if expression.this in self.UNSIGNED_TYPE_MAPPING: + result = f"{result} UNSIGNED" + return result + def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: # MySQL requires simple literal values for its LIMIT clause. expression = simplify_literal(expression.copy()) @@ -599,6 +630,9 @@ class MySQL(Dialect): return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + if expression.to.this in self.TIMESTAMP_FUNC_TYPES: + return self.func("TIMESTAMP", expression.this) + to = self.CAST_MAPPING.get(expression.to.this) if to: |