summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-03 14:11:07 +0000
commit42a1548cecf48d18233f56e3385cf9c89abcb9c2 (patch)
tree5e0fff4ecbd1fd7dd1022a7580139038df2a824c /sqlglot/dialects/tsql.py
parentReleasing debian version 21.1.2-1. (diff)
downloadsqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.tar.xz
sqlglot-42a1548cecf48d18233f56e3385cf9c89abcb9c2.zip
Merging upstream version 22.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py189
1 files changed, 117 insertions, 72 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 5955352..b6f491f 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -18,7 +18,6 @@ from sqlglot.dialects.dialect import (
timestrtotime_sql,
trim_sql,
)
-from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
@@ -63,6 +62,44 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1)
BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
+# Unsupported options:
+# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = <literal_constant> } [ , ...n ] )
+# - TABLE HINT
+OPTIONS: parser.OPTIONS_TYPE = {
+ **dict.fromkeys(
+ (
+ "DISABLE_OPTIMIZED_PLAN_FORCING",
+ "FAST",
+ "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX",
+ "LABEL",
+ "MAXDOP",
+ "MAXRECURSION",
+ "MAX_GRANT_PERCENT",
+ "MIN_GRANT_PERCENT",
+ "NO_PERFORMANCE_SPOOL",
+ "QUERYTRACEON",
+ "RECOMPILE",
+ ),
+ tuple(),
+ ),
+ "CONCAT": ("UNION",),
+ "DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"),
+ "EXPAND": ("VIEWS",),
+ "FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"),
+ "HASH": ("GROUP", "JOIN", "UNION"),
+ "KEEP": ("PLAN",),
+ "KEEPFIXED": ("PLAN",),
+ "LOOP": ("JOIN",),
+ "MERGE": ("JOIN", "UNION"),
+ "OPTIMIZE": (("FOR", "UNKNOWN"),),
+ "ORDER": ("GROUP",),
+ "PARAMETERIZATION": ("FORCED", "SIMPLE"),
+ "ROBUST": ("PLAN",),
+ "USE": ("PLAN",),
+}
+
+OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL")
+
def _build_formatted_time(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@@ -221,19 +258,17 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
# We keep track of the unaliased column projection indexes instead of the expressions
# themselves, because the latter are going to be replaced by new nodes when the aliases
# are added and hence we won't be able to reach these newly added Alias parents
- subqueryable = expression.this
+ query = expression.this
unaliased_column_indexes = (
- i
- for i, c in enumerate(subqueryable.selects)
- if isinstance(c, exp.Column) and not c.alias
+ i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias
)
- qualify_outputs(subqueryable)
+ qualify_outputs(query)
# Preserve the quoting information of columns for newly added Alias nodes
- subqueryable_selects = subqueryable.selects
+ query_selects = query.selects
for select_index in unaliased_column_indexes:
- alias = subqueryable_selects[select_index]
+ alias = query_selects[select_index]
column = alias.this
if isinstance(column.this, exp.Identifier):
alias.args["alias"].set("quoted", column.this.quoted)
@@ -420,7 +455,6 @@ class TSQL(Dialect):
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
- "NVARCHAR(MAX)": TokenType.TEXT,
"PRINT": TokenType.COMMAND,
"PROC": TokenType.PROCEDURE,
"REAL": TokenType.FLOAT,
@@ -431,15 +465,24 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"UPDATE STATISTICS": TokenType.COMMAND,
- "VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
"FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
+ "OPTION": TokenType.OPTION,
}
class Parser(parser.Parser):
SET_REQUIRES_ASSIGNMENT_DELIMITER = False
+ LOG_DEFAULTS_TO_LN = True
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
+ STRING_ALIASES = True
+ NO_PAREN_IF_COMMANDS = False
+
+ QUERY_MODIFIER_PARSERS = {
+ **parser.Parser.QUERY_MODIFIER_PARSERS,
+ TokenType.OPTION: lambda self: ("options", self._parse_options()),
+ }
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -472,19 +515,7 @@ class TSQL(Dialect):
"TIMEFROMPARTS": _build_timefromparts,
}
- JOIN_HINTS = {
- "LOOP",
- "HASH",
- "MERGE",
- "REMOTE",
- }
-
- VAR_LENGTH_DATATYPES = {
- DataType.Type.NVARCHAR,
- DataType.Type.VARCHAR,
- DataType.Type.CHAR,
- DataType.Type.NCHAR,
- }
+ JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"}
RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
@@ -496,11 +527,21 @@ class TSQL(Dialect):
TokenType.END: lambda self: self._parse_command(),
}
- LOG_DEFAULTS_TO_LN = True
+ def _parse_options(self) -> t.Optional[t.List[exp.Expression]]:
+ if not self._match(TokenType.OPTION):
+ return None
- ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
- STRING_ALIASES = True
- NO_PAREN_IF_COMMANDS = False
+ def _parse_option() -> t.Optional[exp.Expression]:
+ option = self._parse_var_from_options(OPTIONS)
+ if not option:
+ return None
+
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.QueryOption, this=option, expression=self._parse_primary_or_var()
+ )
+
+ return self._parse_wrapped_csv(_parse_option)
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -576,48 +617,13 @@ class TSQL(Dialect):
def _parse_convert(
self, strict: bool, safe: t.Optional[bool] = None
) -> t.Optional[exp.Expression]:
- to = self._parse_types()
+ this = self._parse_types()
self._match(TokenType.COMMA)
- this = self._parse_conjunction()
-
- if not to or not this:
- return None
-
- # Retrieve length of datatype and override to default if not specified
- if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
- to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
-
- # Check whether a conversion with format is applicable
- if self._match(TokenType.COMMA):
- format_val = self._parse_number()
- format_val_name = format_val.name if format_val else ""
-
- if format_val_name not in TSQL.CONVERT_FORMAT_MAPPING:
- raise ValueError(
- f"CONVERT function at T-SQL does not support format style {format_val_name}"
- )
-
- format_norm = exp.Literal.string(TSQL.CONVERT_FORMAT_MAPPING[format_val_name])
-
- # Check whether the convert entails a string to date format
- if to.this == DataType.Type.DATE:
- return self.expression(exp.StrToDate, this=this, format=format_norm)
- # Check whether the convert entails a string to datetime format
- elif to.this == DataType.Type.DATETIME:
- return self.expression(exp.StrToTime, this=this, format=format_norm)
- # Check whether the convert entails a date to string format
- elif to.this in self.VAR_LENGTH_DATATYPES:
- return self.expression(
- exp.Cast if strict else exp.TryCast,
- to=to,
- this=self.expression(exp.TimeToStr, this=this, format=format_norm),
- safe=safe,
- )
- elif to.this == DataType.Type.TEXT:
- return self.expression(exp.TimeToStr, this=this, format=format_norm)
-
- # Entails a simple cast without any format requirement
- return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe)
+ args = [this, *self._parse_csv(self._parse_conjunction)]
+ convert = exp.Convert.from_arg_list(args)
+ convert.set("safe", safe)
+ convert.set("strict", strict)
+ return convert
def _parse_user_defined_function(
self, kind: t.Optional[TokenType] = None
@@ -683,6 +689,26 @@ class TSQL(Dialect):
return self.expression(exp.UniqueColumnConstraint, this=this)
+ def _parse_partition(self) -> t.Optional[exp.Partition]:
+ if not self._match_text_seq("WITH", "(", "PARTITIONS"):
+ return None
+
+ def parse_range():
+ low = self._parse_bitwise()
+ high = self._parse_bitwise() if self._match_text_seq("TO") else None
+
+ return (
+ self.expression(exp.PartitionRange, this=low, expression=high) if high else low
+ )
+
+ partition = self.expression(
+ exp.Partition, expressions=self._parse_wrapped_csv(parse_range)
+ )
+
+ self._match_r_paren()
+
+ return partition
+
class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
@@ -728,6 +754,9 @@ class TSQL(Dialect):
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
+ TYPE_MAPPING.pop(exp.DataType.Type.NCHAR)
+ TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR)
+
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
@@ -779,6 +808,20 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def convert_sql(self, expression: exp.Convert) -> str:
+ name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT"
+ return self.func(
+ name, expression.this, expression.expression, expression.args.get("style")
+ )
+
+ def queryoption_sql(self, expression: exp.QueryOption) -> str:
+ option = self.sql(expression, "this")
+ value = self.sql(expression, "expression")
+ if value:
+ optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else ""
+ return f"{option} {optional_equal_sign}{value}"
+ return option
+
def lateral_op(self, expression: exp.Lateral) -> str:
cross_apply = expression.args.get("cross_apply")
if cross_apply is True:
@@ -876,11 +919,10 @@ class TSQL(Dialect):
if ctas_with:
ctas_with = ctas_with.pop()
- subquery = ctas_expression
- if isinstance(subquery, exp.Subqueryable):
- subquery = subquery.subquery()
+ if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES):
+ ctas_expression = ctas_expression.subquery()
- select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
+ select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True))
select_into.set("into", exp.Into(this=table))
select_into.set("with", ctas_with)
@@ -993,3 +1035,6 @@ class TSQL(Dialect):
this_sql = self.sql(this)
expression_sql = self.sql(expression, "expression")
return self.func(name, this_sql, expression_sql if expression_sql else None)
+
+ def partition_sql(self, expression: exp.Partition) -> str:
+ return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))"