summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py116
1 files changed, 109 insertions, 7 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 9342e6b..9f9099e 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import re
+import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
@@ -251,6 +252,7 @@ class TSQL(Dialect):
"NTEXT": TokenType.TEXT,
"NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
+ "PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
"SMALLDATETIME": TokenType.DATETIME,
@@ -263,6 +265,11 @@ class TSQL(Dialect):
"XML": TokenType.XML,
}
+ # TSQL allows @, # to appear as a variable/identifier prefix
+ SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
+ SINGLE_TOKENS.pop("@")
+ SINGLE_TOKENS.pop("#")
+
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
@@ -293,26 +300,82 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
- TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ TokenType.TABLE,
+ *parser.Parser.TYPE_TOKENS, # type: ignore
+ }
+
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ TokenType.END: lambda self: self._parse_command(),
+ }
+
+ def _parse_system_time(self) -> t.Optional[exp.Expression]:
+ if not self._match_text_seq("FOR", "SYSTEM_TIME"):
+ return None
- def _parse_convert(self, strict):
+ if self._match_text_seq("AS", "OF"):
+ system_time = self.expression(
+ exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
+ )
+ elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
+ kind = self._prev.text
+ this = self._parse_bitwise()
+ self._match_texts(("TO", "AND"))
+ expression = self._parse_bitwise()
+ system_time = self.expression(
+ exp.SystemTime, this=this, expression=expression, kind=kind
+ )
+ elif self._match_text_seq("CONTAINED", "IN"):
+ args = self._parse_wrapped_csv(self._parse_bitwise)
+ system_time = self.expression(
+ exp.SystemTime,
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
+ kind="CONTAINED IN",
+ )
+ elif self._match(TokenType.ALL):
+ system_time = self.expression(exp.SystemTime, kind="ALL")
+ else:
+ system_time = None
+ self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
+
+ return system_time
+
+ def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ table = super()._parse_table_parts(schema=schema)
+ table.set("system_time", self._parse_system_time())
+ return table
+
+ def _parse_returns(self) -> exp.Expression:
+ table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
+ returns = super()._parse_returns()
+ returns.set("table", table)
+ return returns
+
+ def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]:
to = self._parse_types()
self._match(TokenType.COMMA)
this = self._parse_conjunction()
+ if not to or not this:
+ return None
+
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
- format_val = self._parse_number().name
- if format_val not in TSQL.convert_format_mapping:
+ format_val = self._parse_number()
+ format_val_name = format_val.name if format_val else ""
+
+ if format_val_name not in TSQL.convert_format_mapping:
raise ValueError(
- f"CONVERT function at T-SQL does not support format style {format_val}"
+ f"CONVERT function at T-SQL does not support format style {format_val_name}"
)
- format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
+
+ format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val_name])
# Check whether the convert entails a string to date format
if to.this == DataType.Type.DATE:
@@ -333,6 +396,21 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+ def _parse_user_defined_function(
+ self, kind: t.Optional[TokenType] = None
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_user_defined_function(kind=kind)
+
+ if (
+ kind == TokenType.FUNCTION
+ or isinstance(this, exp.UserDefinedFunction)
+ or self._match(TokenType.ALIAS, advance=False)
+ ):
+ return this
+
+ expressions = self._parse_csv(self._parse_udf_kwarg)
+ return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -354,3 +432,27 @@ class TSQL(Dialect):
exp.TimeToStr: _format_sql,
exp.GroupConcat: _string_agg_sql,
}
+
+ TRANSFORMS.pop(exp.ReturnsProperty)
+
+ def systemtime_sql(self, expression: exp.SystemTime) -> str:
+ kind = expression.args["kind"]
+ if kind == "ALL":
+ return "FOR SYSTEM_TIME ALL"
+
+ start = self.sql(expression, "this")
+ if kind == "AS OF":
+ return f"FOR SYSTEM_TIME AS OF {start}"
+
+ end = self.sql(expression, "expression")
+ if kind == "FROM":
+ return f"FOR SYSTEM_TIME FROM {start} TO {end}"
+ if kind == "BETWEEN":
+ return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+
+ return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+
+ def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
+ table = expression.args.get("table")
+ table = f"{table} " if table else ""
+ return f"RETURNS {table}{self.sql(expression, 'this')}"