From 7db33518a4264e422294a1e20fbd1c1505d9a62d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 12 Sep 2023 10:28:54 +0200 Subject: Merging upstream version 18.3.0. Signed-off-by: Daniel Baumann --- sqlglot/dialects/doris.py | 2 ++ sqlglot/dialects/mysql.py | 36 +++++++++++++++++++++++++++++++++++- sqlglot/dialects/postgres.py | 6 +++++- sqlglot/dialects/spark.py | 7 +++++++ sqlglot/dialects/teradata.py | 8 ++++++++ sqlglot/dialects/tsql.py | 14 +++++++++++++- sqlglot/expressions.py | 15 ++++++--------- sqlglot/generator.py | 8 ++++---- sqlglot/parser.py | 43 +++++++++++++++++++++++++++++++++++-------- sqlglot/tokens.py | 1 + sqlglot/transforms.py | 2 +- 11 files changed, 117 insertions(+), 25 deletions(-) (limited to 'sqlglot') diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 4b8919c..bd7e0f2 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -33,6 +33,8 @@ class Doris(MySQL): exp.DataType.Type.TIMESTAMPTZ: "DATETIME", } + TIMESTAMP_FUNC_TYPES = set() + TRANSFORMS = { **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index f9249eb..6327796 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -555,7 +555,26 @@ class MySQL(Dialect): exp.WeekOfYear: rename_func("WEEKOFYEAR"), } - TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy() + UNSIGNED_TYPE_MAPPING = { + exp.DataType.Type.UBIGINT: "BIGINT", + exp.DataType.Type.UINT: "INT", + exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", + exp.DataType.Type.USMALLINT: "SMALLINT", + exp.DataType.Type.UTINYINT: "TINYINT", + } + + TIMESTAMP_TYPE_MAPPING = { + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + } + + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + **UNSIGNED_TYPE_MAPPING, + **TIMESTAMP_TYPE_MAPPING, + } + TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) @@ -580,6 +599,18 @@ class MySQL(Dialect): exp.DataType.Type.VARCHAR: "CHAR", } + TIMESTAMP_FUNC_TYPES = { + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + } + + def datatype_sql(self, expression: exp.DataType) -> str: + # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html + result = super().datatype_sql(expression) + if expression.this in self.UNSIGNED_TYPE_MAPPING: + result = f"{result} UNSIGNED" + return result + def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: # MySQL requires simple literal values for its LIMIT clause. expression = simplify_literal(expression.copy()) @@ -599,6 +630,9 @@ class MySQL(Dialect): return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + if expression.to.this in self.TIMESTAMP_FUNC_TYPES: + return self.func("TIMESTAMP", expression.this) + to = self.CAST_MAPPING.get(expression.to.this) if to: diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c26e121..5027013 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -190,7 +190,11 @@ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Merge): alias = expression.this.args.get("alias") - normalize = lambda identifier: Postgres.normalize_identifier(identifier).name + normalize = ( + lambda identifier: Postgres.normalize_identifier(identifier).name + if identifier + else None + ) targets = {normalize(expression.this.this)} diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index a4435f6..9d4a1ab 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -35,6 +35,13 @@ def _parse_datediff(args: t.List) -> exp.Expression: class Spark(Spark2): + class Tokenizer(Spark2.Tokenizer): + RAW_STRINGS = [ + (prefix + q, q) + for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES) + for prefix in ("r", "R") + ] + class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 163cc13..d9de968 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -45,6 +45,7 @@ class Teradata(Dialect): "MOD": TokenType.MOD, "NE": TokenType.NEQ, "NOT=": TokenType.NEQ, + "SAMPLE": TokenType.TABLE_SAMPLE, "SEL": TokenType.SELECT, "ST_GEOMETRY": TokenType.GEOMETRY, "TOP": TokenType.TOP, @@ -55,6 +56,8 @@ class Teradata(Dialect): SINGLE_TOKENS.pop("%") class Parser(parser.Parser): + TABLESAMPLE_CSV = True + CHARSET_TRANSLATORS = { "GRAPHIC_TO_KANJISJIS", "GRAPHIC_TO_LATIN", @@ -171,6 +174,11 @@ class Teradata(Dialect): exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } + def tablesample_sql( + self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + ) -> str: + return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}" + def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: return f"PARTITION BY {self.sql(expression, 'this')}" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b26f499..19c586e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -57,6 +57,8 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} DEFAULT_START_DATE = datetime.date(1900, 1, 1) +BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} + def _format_time_lambda( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -584,6 +586,7 @@ class TSQL(Dialect): RETURNING_END = False NVL2_SUPPORTED = False ALTER_TABLE_ADD_COLUMN_KEYWORD = False + LIMIT_FETCH = "FETCH" TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -630,7 +633,16 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "FETCH" + def boolean_sql(self, expression: exp.Boolean) -> str: + if type(expression.parent) in BIT_TYPES: + return "1" if expression.this else "0" + + return "(1 = 1)" if expression.this else "(1 = 0)" + + def is_sql(self, expression: exp.Is) -> str: + if isinstance(expression.expression, exp.Boolean): + return self.binary(expression, "=") + return self.binary(expression, "IS") def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: sql = self.sql(expression, "this") diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 0479da0..877e9fd 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3350,6 +3350,7 @@ class Subquery(DerivedTable, Unionable): class TableSample(Expression): arg_types = { "this": False, + "expressions": False, "method": False, "bucket_numerator": False, "bucket_denominator": False, @@ -3542,6 +3543,7 @@ class DataType(Expression): UINT = auto() UINT128 = auto() UINT256 = auto() + UMEDIUMINT = auto() UNIQUEIDENTIFIER = auto() UNKNOWN = auto() # Sentinel value, useful for type annotation USERDEFINED = "USER-DEFINED" @@ -3708,7 +3710,7 @@ class Rollback(Expression): class AlterTable(Expression): - arg_types = {"this": True, "actions": True, "exists": False} + arg_types = {"this": True, "actions": True, "exists": False, "only": False} class AddConstraint(Expression): @@ -3992,16 +3994,11 @@ class TimeUnit(Expression): super().__init__(**args) -# https://www.oracletutorial.com/oracle-basics/oracle-interval/ -# https://trino.io/docs/current/language/types.html#interval-year-to-month -class IntervalYearToMonthSpan(Expression): - arg_types = {} - - # https://www.oracletutorial.com/oracle-basics/oracle-interval/ # https://trino.io/docs/current/language/types.html#interval-day-to-second -class IntervalDayToSecondSpan(Expression): - arg_types = {} +# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html +class IntervalSpan(Expression): + arg_types = {"this": True, "expression": True} class Interval(TimeUnit): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 306df81..1074e9a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -72,8 +72,7 @@ class Generator: exp.ExternalProperty: lambda self, e: "EXTERNAL", exp.HeapProperty: lambda self, e: "HEAP", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", - exp.IntervalDayToSecondSpan: "DAY TO SECOND", - exp.IntervalYearToMonthSpan: "YEAR TO MONTH", + exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", @@ -953,7 +952,7 @@ class Generator: def filter_sql(self, expression: exp.Filter) -> str: this = self.sql(expression, "this") - where = self.sql(expression, "expression")[1:] # where has a leading space + where = self.sql(expression, "expression").strip() return f"{this} FILTER({where})" def hint_sql(self, expression: exp.Hint) -> str: @@ -2290,7 +2289,8 @@ class Generator: actions = self.expressions(expression, key="actions") exists = " IF EXISTS" if expression.args.get("exists") else "" - return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + only = " ONLY" if expression.args.get("only") else "" + return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}" def droppartition_sql(self, expression: exp.DropPartition) -> str: expressions = self.expressions(expression) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index f8690d5..939303f 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -137,6 +137,7 @@ class Parser(metaclass=_Parser): TokenType.INT256, TokenType.UINT256, TokenType.MEDIUMINT, + TokenType.UMEDIUMINT, TokenType.FIXEDSTRING, TokenType.FLOAT, TokenType.DOUBLE, @@ -206,6 +207,14 @@ class Parser(metaclass=_Parser): *NESTED_TYPE_TOKENS, } + SIGNED_TO_UNSIGNED_TYPE_TOKEN = { + TokenType.BIGINT: TokenType.UBIGINT, + TokenType.INT: TokenType.UINT, + TokenType.MEDIUMINT: TokenType.UMEDIUMINT, + TokenType.SMALLINT: TokenType.USMALLINT, + TokenType.TINYINT: TokenType.UTINYINT, + } + SUBQUERY_PREDICATES = { TokenType.ANY: exp.Any, TokenType.ALL: exp.All, @@ -856,6 +865,9 @@ class Parser(metaclass=_Parser): # Whether or not ADD is present for each column added by ALTER TABLE ALTER_TABLE_ADD_COLUMN_KEYWORD = True + # Whether or not the table sample clause expects CSV syntax + TABLESAMPLE_CSV = False + __slots__ = ( "error_level", "error_message_context", @@ -2672,7 +2684,12 @@ class Parser(metaclass=_Parser): self._match(TokenType.L_PAREN) - num = self._parse_number() + if self.TABLESAMPLE_CSV: + num = None + expressions = self._parse_csv(self._parse_primary) + else: + expressions = None + num = self._parse_number() if self._match_text_seq("BUCKET"): bucket_numerator = self._parse_number() @@ -2684,7 +2701,7 @@ class Parser(metaclass=_Parser): percent = num elif self._match(TokenType.ROWS): rows = num - else: + elif num: size = num self._match(TokenType.R_PAREN) @@ -2698,6 +2715,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.TableSample, + expressions=expressions, method=method, bucket_numerator=bucket_numerator, bucket_denominator=bucket_denominator, @@ -3325,15 +3343,14 @@ class Parser(metaclass=_Parser): elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): maybe_func = False elif type_token == TokenType.INTERVAL: - if self._match_text_seq("YEAR", "TO", "MONTH"): - span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()] - elif self._match_text_seq("DAY", "TO", "SECOND"): - span = [exp.IntervalDayToSecondSpan()] + unit = self._parse_var() + + if self._match_text_seq("TO"): + span = [exp.IntervalSpan(this=unit, expression=self._parse_var())] else: span = None - unit = not span and self._parse_var() - if not unit: + if span or not unit: this = self.expression( exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span ) @@ -3351,6 +3368,13 @@ class Parser(metaclass=_Parser): self._retreat(index2) if not this: + if self._match_text_seq("UNSIGNED"): + unsigned_type_token = self.SIGNED_TO_UNSIGNED_TYPE_TOKEN.get(type_token) + if not unsigned_type_token: + self.raise_error(f"Cannot convert {type_token.value} to unsigned.") + + type_token = unsigned_type_token or type_token + this = exp.DataType( this=exp.DataType.Type[type_token.value], expressions=expressions, @@ -4761,6 +4785,7 @@ class Parser(metaclass=_Parser): return self._parse_as_command(start) exists = self._parse_exists() + only = self._match_text_seq("ONLY") this = self._parse_table(schema=True) if self._next: @@ -4776,7 +4801,9 @@ class Parser(metaclass=_Parser): this=this, exists=exists, actions=actions, + only=only, ) + return self._parse_as_command(start) def _parse_merge(self) -> exp.Merge: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 83b97d6..3ba8195 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -86,6 +86,7 @@ class TokenType(AutoName): SMALLINT = auto() USMALLINT = auto() MEDIUMINT = auto() + UMEDIUMINT = auto() INT = auto() UINT = auto() BIGINT = auto() diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 48ea8dc..66ab884 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -76,7 +76,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: return ( exp.select(*outer_selects) - .from_(expression.subquery()) + .from_(expression.subquery("_t")) .where(exp.column(row_number).eq(1)) ) -- cgit v1.2.3