diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-03 14:11:07 +0000 |
commit | 42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch) | |
tree | 5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/dialects/tsql.py | |
parent | Releasing debian version 21.1.2-1. (diff) | |
download | sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip |
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 189 |
1 files changed, 117 insertions, 72 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 5955352..b6f491f 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, trim_sql, ) -from sqlglot.expressions import DataType from sqlglot.helper import seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -63,6 +62,44 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1) BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} +# Unsupported options: +# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = <literal_constant> } [ , ...n ] ) +# - TABLE HINT +OPTIONS: parser.OPTIONS_TYPE = { + **dict.fromkeys( + ( + "DISABLE_OPTIMIZED_PLAN_FORCING", + "FAST", + "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX", + "LABEL", + "MAXDOP", + "MAXRECURSION", + "MAX_GRANT_PERCENT", + "MIN_GRANT_PERCENT", + "NO_PERFORMANCE_SPOOL", + "QUERYTRACEON", + "RECOMPILE", + ), + tuple(), + ), + "CONCAT": ("UNION",), + "DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"), + "EXPAND": ("VIEWS",), + "FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"), + "HASH": ("GROUP", "JOIN", "UNION"), + "KEEP": ("PLAN",), + "KEEPFIXED": ("PLAN",), + "LOOP": ("JOIN",), + "MERGE": ("JOIN", "UNION"), + "OPTIMIZE": (("FOR", "UNKNOWN"),), + "ORDER": ("GROUP",), + "PARAMETERIZATION": ("FORCED", "SIMPLE"), + "ROBUST": ("PLAN",), + "USE": ("PLAN",), +} + +OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL") + def _build_formatted_time( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -221,19 +258,17 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: # We keep track of the unaliased column projection indexes instead of the expressions # themselves, because the latter are going to be replaced by new nodes when the aliases # are added and hence we won't be able to reach these newly added Alias parents - subqueryable = expression.this + query = expression.this unaliased_column_indexes = ( - i - for i, c in enumerate(subqueryable.selects) - if isinstance(c, exp.Column) and not c.alias + i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias ) - qualify_outputs(subqueryable) + qualify_outputs(query) # Preserve the quoting information of columns for newly added Alias nodes - subqueryable_selects = subqueryable.selects + query_selects = query.selects for select_index in unaliased_column_indexes: - alias = subqueryable_selects[select_index] + alias = query_selects[select_index] column = alias.this if isinstance(column.this, exp.Identifier): alias.args["alias"].set("quoted", column.this.quoted) @@ -420,7 +455,6 @@ class TSQL(Dialect): "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "NTEXT": TokenType.TEXT, - "NVARCHAR(MAX)": TokenType.TEXT, "PRINT": TokenType.COMMAND, "PROC": TokenType.PROCEDURE, "REAL": TokenType.FLOAT, @@ -431,15 +465,24 @@ class TSQL(Dialect): "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "UPDATE STATISTICS": TokenType.COMMAND, - "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, + "OPTION": TokenType.OPTION, } class Parser(parser.Parser): SET_REQUIRES_ASSIGNMENT_DELIMITER = False + LOG_DEFAULTS_TO_LN = True + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False + STRING_ALIASES = True + NO_PAREN_IF_COMMANDS = False + + QUERY_MODIFIER_PARSERS = { + **parser.Parser.QUERY_MODIFIER_PARSERS, + TokenType.OPTION: lambda self: ("options", self._parse_options()), + } FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -472,19 +515,7 @@ class TSQL(Dialect): "TIMEFROMPARTS": _build_timefromparts, } - JOIN_HINTS = { - "LOOP", - "HASH", - "MERGE", - "REMOTE", - } - - VAR_LENGTH_DATATYPES = { - DataType.Type.NVARCHAR, - DataType.Type.VARCHAR, - DataType.Type.CHAR, - DataType.Type.NCHAR, - } + JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, @@ -496,11 +527,21 @@ class TSQL(Dialect): TokenType.END: lambda self: self._parse_command(), } - LOG_DEFAULTS_TO_LN = True + def _parse_options(self) -> t.Optional[t.List[exp.Expression]]: + if not self._match(TokenType.OPTION): + return None - ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False - STRING_ALIASES = True - NO_PAREN_IF_COMMANDS = False + def _parse_option() -> t.Optional[exp.Expression]: + option = self._parse_var_from_options(OPTIONS) + if not option: + return None + + self._match(TokenType.EQ) + return self.expression( + exp.QueryOption, this=option, expression=self._parse_primary_or_var() + ) + + return self._parse_wrapped_csv(_parse_option) def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -576,48 +617,13 @@ class TSQL(Dialect): def _parse_convert( self, strict: bool, safe: t.Optional[bool] = None ) -> t.Optional[exp.Expression]: - to = self._parse_types() + this = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_conjunction() - - if not to or not this: - return None - - # Retrieve length of datatype and override to default if not specified - if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: - to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) - - # Check whether a conversion with format is applicable - if self._match(TokenType.COMMA): - format_val = self._parse_number() - format_val_name = format_val.name if format_val else "" - - if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING: - raise ValueError( - f"CONVERT function at T-SQL does not support format style {format_val_name}" - ) - - format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name]) - - # Check whether the convert entails a string to date format - if to.this == DataType.Type.DATE: - return self.expression(exp.StrToDate, this=this, format=format_norm) - # Check whether the convert entails a string to datetime format - elif to.this == DataType.Type.DATETIME: - return self.expression(exp.StrToTime, this=this, format=format_norm) - # Check whether the convert entails a date to string format - elif to.this in self.VAR_LENGTH_DATATYPES: - return self.expression( - exp.Cast if strict else exp.TryCast, - to=to, - this=self.expression(exp.TimeToStr, this=this, format=format_norm), - safe=safe, - ) - elif to.this == DataType.Type.TEXT: - return self.expression(exp.TimeToStr, this=this, format=format_norm) - - # Entails a simple cast without any format requirement - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe) + args = [this, *self._parse_csv(self._parse_conjunction)] + convert = exp.Convert.from_arg_list(args) + convert.set("safe", safe) + convert.set("strict", strict) + return convert def _parse_user_defined_function( self, kind: t.Optional[TokenType] = None @@ -683,6 +689,26 @@ class TSQL(Dialect): return self.expression(exp.UniqueColumnConstraint, this=this) + def _parse_partition(self) -> t.Optional[exp.Partition]: + if not self._match_text_seq("WITH", "(", "PARTITIONS"): + return None + + def parse_range(): + low = self._parse_bitwise() + high = self._parse_bitwise() if self._match_text_seq("TO") else None + + return ( + self.expression(exp.PartitionRange, this=low, expression=high) if high else low + ) + + partition = self.expression( + exp.Partition, expressions=self._parse_wrapped_csv(parse_range) + ) + + self._match_r_paren() + + return partition + class Generator(generator.Generator): LIMIT_IS_TOP = True QUERY_HINTS = False @@ -728,6 +754,9 @@ class TSQL(Dialect): exp.DataType.Type.VARIANT: "SQL_VARIANT", } + TYPE_MAPPING.pop(exp.DataType.Type.NCHAR) + TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR) + TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, @@ -779,6 +808,20 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def convert_sql(self, expression: exp.Convert) -> str: + name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT" + return self.func( + name, expression.this, expression.expression, expression.args.get("style") + ) + + def queryoption_sql(self, expression: exp.QueryOption) -> str: + option = self.sql(expression, "this") + value = self.sql(expression, "expression") + if value: + optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else "" + return f"{option} {optional_equal_sign}{value}" + return option + def lateral_op(self, expression: exp.Lateral) -> str: cross_apply = expression.args.get("cross_apply") if cross_apply is True: @@ -876,11 +919,10 @@ class TSQL(Dialect): if ctas_with: ctas_with = ctas_with.pop() - subquery = ctas_expression - if isinstance(subquery, exp.Subqueryable): - subquery = subquery.subquery() + if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES): + ctas_expression = ctas_expression.subquery() - select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True)) + select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True)) select_into.set("into", exp.Into(this=table)) select_into.set("with", ctas_with) @@ -993,3 +1035,6 @@ class TSQL(Dialect): this_sql = self.sql(this) expression_sql = self.sql(expression, "expression") return self.func(name, this_sql, expression_sql if expression_sql else None) + + def partition_sql(self, expression: exp.Partition) -> str: + return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))" |