diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:37 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-30 17:08:37 +0000 |
commit | be1cb18ea28222fca384a5459a024b7e9af5cadb (patch) | |
tree | 4698c9069380a7c30ceb51129f93f6c8662315e4 /sqlglot/dialects/tsql.py | |
parent | Releasing debian version 10.5.6-1. (diff) | |
download | sqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.tar.xz sqlglot-be1cb18ea28222fca384a5459a024b7e9af5cadb.zip |
Merging upstream version 10.5.10.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 116 |
1 files changed, 109 insertions, 7 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9342e6b..9f9099e 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import typing as t from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func @@ -251,6 +252,7 @@ class TSQL(Dialect): "NTEXT": TokenType.TEXT, "NVARCHAR(MAX)": TokenType.TEXT, "PRINT": TokenType.COMMAND, + "PROC": TokenType.PROCEDURE, "REAL": TokenType.FLOAT, "ROWVERSION": TokenType.ROWVERSION, "SMALLDATETIME": TokenType.DATETIME, @@ -263,6 +265,11 @@ class TSQL(Dialect): "XML": TokenType.XML, } + # TSQL allows @, # to appear as a variable/identifier prefix + SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() + SINGLE_TOKENS.pop("@") + SINGLE_TOKENS.pop("#") + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore @@ -293,26 +300,82 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table - TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER} + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + TokenType.TABLE, + *parser.Parser.TYPE_TOKENS, # type: ignore + } + + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, # type: ignore + TokenType.END: lambda self: self._parse_command(), + } + + def _parse_system_time(self) -> t.Optional[exp.Expression]: + if not self._match_text_seq("FOR", "SYSTEM_TIME"): + return None - def _parse_convert(self, strict): + if self._match_text_seq("AS", "OF"): + system_time = self.expression( + exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" + ) + elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): + kind = self._prev.text + this = self._parse_bitwise() + self._match_texts(("TO", "AND")) + expression = self._parse_bitwise() + system_time = self.expression( + exp.SystemTime, this=this, expression=expression, kind=kind + ) + elif self._match_text_seq("CONTAINED", "IN"): + args = self._parse_wrapped_csv(self._parse_bitwise) + system_time = self.expression( + exp.SystemTime, + this=seq_get(args, 0), + expression=seq_get(args, 1), + kind="CONTAINED IN", + ) + elif self._match(TokenType.ALL): + system_time = self.expression(exp.SystemTime, kind="ALL") + else: + system_time = None + self.raise_error("Unable to parse FOR SYSTEM_TIME clause") + + return system_time + + def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + table = super()._parse_table_parts(schema=schema) + table.set("system_time", self._parse_system_time()) + return table + + def _parse_returns(self) -> exp.Expression: + table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) + returns = super()._parse_returns() + returns.set("table", table) + return returns + + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_conjunction() + if not to or not this: + return None + # Retrieve length of datatype and override to default if not specified if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) # Check whether a conversion with format is applicable if self._match(TokenType.COMMA): - format_val = self._parse_number().name - if format_val not in TSQL.convert_format_mapping: + format_val = self._parse_number() + format_val_name = format_val.name if format_val else "" + + if format_val_name not in TSQL.convert_format_mapping: raise ValueError( - f"CONVERT function at T-SQL does not support format style {format_val}" + f"CONVERT function at T-SQL does not support format style {format_val_name}" ) - format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) + + format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name]) # Check whether the convert entails a string to date format if to.this == DataType.Type.DATE: @@ -333,6 +396,21 @@ class TSQL(Dialect): # Entails a simple cast without any format requirement return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + def _parse_user_defined_function( + self, kind: t.Optional[TokenType] = None + ) -> t.Optional[exp.Expression]: + this = super()._parse_user_defined_function(kind=kind) + + if ( + kind == TokenType.FUNCTION + or isinstance(this, exp.UserDefinedFunction) + or self._match(TokenType.ALIAS, advance=False) + ): + return this + + expressions = self._parse_csv(self._parse_udf_kwarg) + return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + class Generator(generator.Generator): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore @@ -354,3 +432,27 @@ class TSQL(Dialect): exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, } + + TRANSFORMS.pop(exp.ReturnsProperty) + + def systemtime_sql(self, expression: exp.SystemTime) -> str: + kind = expression.args["kind"] + if kind == "ALL": + return "FOR SYSTEM_TIME ALL" + + start = self.sql(expression, "this") + if kind == "AS OF": + return f"FOR SYSTEM_TIME AS OF {start}" + + end = self.sql(expression, "expression") + if kind == "FROM": + return f"FOR SYSTEM_TIME FROM {start} TO {end}" + if kind == "BETWEEN": + return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + + return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + + def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: + table = expression.args.get("table") + table = f"{table} " if table else "" + return f"RETURNS {table}{self.sql(expression, 'this')}" |