summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:48 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-07 11:39:48 +0000
commitf73e9af131151f1e058446361c35b05c4c90bf10 (patch)
treeed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dialects/tsql.py
parentReleasing debian version 17.12.0-1. (diff)
downloadsqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz
sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py157
1 files changed, 91 insertions, 66 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 131307f..b26f499 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -7,6 +7,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ any_value_to_max_sql,
max_or_greatest,
min_or_least,
parse_date_delta,
@@ -79,22 +80,23 @@ def _format_time_lambda(
def _parse_format(args: t.List) -> exp.Expression:
- assert len(args) == 2
+ this = seq_get(args, 0)
+ fmt = seq_get(args, 1)
+ culture = seq_get(args, 2)
- fmt = args[1]
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+ number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name))
if number_fmt:
- return exp.NumberToStr(this=args[0], format=fmt)
+ return exp.NumberToStr(this=this, format=fmt, culture=culture)
- return exp.TimeToStr(
- this=args[0],
- format=exp.Literal.string(
+ if fmt:
+ fmt = exp.Literal.string(
format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING)
if len(fmt.name) == 1
else format_time(fmt.name, TSQL.TIME_MAPPING)
- ),
- )
+ )
+
+ return exp.TimeToStr(this=this, format=fmt, culture=culture)
def _parse_eomonth(args: t.List) -> exp.Expression:
@@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
def generate_date_delta_with_unit_sql(
- self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+ self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
) -> str:
func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
+def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
expression.args["format"]
if isinstance(expression, exp.NumberToStr)
@@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
)
)
)
- return self.func("FORMAT", expression.this, fmt)
+ return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
-def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
expression = expression.copy()
this = expression.this
@@ -332,10 +334,12 @@ class TSQL(Dialect):
"SQL_VARIANT": TokenType.VARIANT,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+ "UPDATE STATISTICS": TokenType.COMMAND,
"VARCHAR(MAX)": TokenType.TEXT,
"XML": TokenType.XML,
"OUTPUT": TokenType.RETURNING,
"SYSTEM_USER": TokenType.CURRENT_USER,
+ "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT,
}
class Parser(parser.Parser):
@@ -395,7 +399,9 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
- def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]:
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+
+ def _parse_projections(self) -> t.List[exp.Expression]:
"""
T-SQL supports the syntax alias = expression in the SELECT's projection list,
so we transform all parsed Selects to convert their EQ projections into Aliases.
@@ -458,43 +464,6 @@ class TSQL(Dialect):
return self._parse_as_command(self._prev)
- def _parse_system_time(self) -> t.Optional[exp.Expression]:
- if not self._match_text_seq("FOR", "SYSTEM_TIME"):
- return None
-
- if self._match_text_seq("AS", "OF"):
- system_time = self.expression(
- exp.SystemTime, this=self._parse_bitwise(), kind="AS OF"
- )
- elif self._match_set((TokenType.FROM, TokenType.BETWEEN)):
- kind = self._prev.text
- this = self._parse_bitwise()
- self._match_texts(("TO", "AND"))
- expression = self._parse_bitwise()
- system_time = self.expression(
- exp.SystemTime, this=this, expression=expression, kind=kind
- )
- elif self._match_text_seq("CONTAINED", "IN"):
- args = self._parse_wrapped_csv(self._parse_bitwise)
- system_time = self.expression(
- exp.SystemTime,
- this=seq_get(args, 0),
- expression=seq_get(args, 1),
- kind="CONTAINED IN",
- )
- elif self._match(TokenType.ALL):
- system_time = self.expression(exp.SystemTime, kind="ALL")
- else:
- system_time = None
- self.raise_error("Unable to parse FOR SYSTEM_TIME clause")
-
- return system_time
-
- def _parse_table_parts(self, schema: bool = False) -> exp.Table:
- table = super()._parse_table_parts(schema=schema)
- table.set("system_time", self._parse_system_time())
- return table
-
def _parse_returns(self) -> exp.ReturnsProperty:
table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS)
returns = super()._parse_returns()
@@ -589,14 +558,36 @@ class TSQL(Dialect):
return create
+ def _parse_if(self) -> t.Optional[exp.Expression]:
+ index = self._index
+
+ if self._match_text_seq("OBJECT_ID"):
+ self._parse_wrapped_csv(self._parse_string)
+ if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP):
+ return self._parse_drop(exists=True)
+ self._retreat(index)
+
+ return super()._parse_if()
+
+ def _parse_unique(self) -> exp.UniqueColumnConstraint:
+ return self.expression(
+ exp.UniqueColumnConstraint,
+ this=None
+ if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
+ else self._parse_schema(self._parse_id_var(any_token=False)),
+ )
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
+ NVL2_SUPPORTED = False
+ ALTER_TABLE_ADD_COLUMN_KEYWORD = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
exp.DataType.Type.INT: "INTEGER",
@@ -607,6 +598,8 @@ class TSQL(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
+ exp.AnyValue: any_value_to_max_sql,
+ exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -651,25 +644,44 @@ class TSQL(Dialect):
return sql
- def offset_sql(self, expression: exp.Offset) -> str:
- return f"{super().offset_sql(expression)} ROWS"
+ def create_sql(self, expression: exp.Create) -> str:
+ expression = expression.copy()
+ kind = self.sql(expression, "kind").upper()
+ exists = expression.args.pop("exists", None)
+ sql = super().create_sql(expression)
+
+ if exists:
+ table = expression.find(exp.Table)
+ identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))
+ if kind == "SCHEMA":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "TABLE":
+ sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')"""
+ elif kind == "INDEX":
+ index = self.sql(exp.Literal.string(expression.this.text("this")))
+ sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')"""
+ elif expression.args.get("replace"):
+ sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1)
- def systemtime_sql(self, expression: exp.SystemTime) -> str:
- kind = expression.args["kind"]
- if kind == "ALL":
- return "FOR SYSTEM_TIME ALL"
+ return sql
- start = self.sql(expression, "this")
- if kind == "AS OF":
- return f"FOR SYSTEM_TIME AS OF {start}"
+ def offset_sql(self, expression: exp.Offset) -> str:
+ return f"{super().offset_sql(expression)} ROWS"
- end = self.sql(expression, "expression")
- if kind == "FROM":
- return f"FOR SYSTEM_TIME FROM {start} TO {end}"
- if kind == "BETWEEN":
- return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}"
+ def version_sql(self, expression: exp.Version) -> str:
+ name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name
+ this = f"FOR {name}"
+ expr = expression.expression
+ kind = expression.text("kind")
+ if kind in ("FROM", "BETWEEN"):
+ args = expr.expressions
+ sep = "TO" if kind == "FROM" else "AND"
+ expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}"
+ else:
+ expr_sql = self.sql(expr)
- return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})"
+ expr_sql = f" {expr_sql}" if expr_sql else ""
+ return f"{this} {kind}{expr_sql}"
def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str:
table = expression.args.get("table")
@@ -713,3 +725,16 @@ class TSQL(Dialect):
identifier = f"#{identifier}"
return identifier
+
+ def constraint_sql(self, expression: exp.Constraint) -> str:
+ this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True, sep=" ")
+ return f"CONSTRAINT {this} {expressions}"
+
+ # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
+ start = self.sql(expression, "start") or "1"
+ increment = self.sql(expression, "increment") or "1"
+ return f"IDENTITY({start}, {increment})"