diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-07 11:39:48 +0000 |
commit | f73e9af131151f1e058446361c35b05c4c90bf10 (patch) | |
tree | ed425b89f12d3f5e4709290bdc03d876f365bc97 /sqlglot/dialects/tsql.py | |
parent | Releasing debian version 17.12.0-1. (diff) | |
download | sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.tar.xz sqlglot-f73e9af131151f1e058446361c35b05c4c90bf10.zip |
Merging upstream version 18.2.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 157 |
1 files changed, 91 insertions, 66 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 131307f..b26f499 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -7,6 +7,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, max_or_greatest, min_or_least, parse_date_delta, @@ -79,22 +80,23 @@ def _format_time_lambda( def _parse_format(args: t.List) -> exp.Expression: - assert len(args) == 2 + this = seq_get(args, 0) + fmt = seq_get(args, 1) + culture = seq_get(args, 2) - fmt = args[1] - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)) if number_fmt: - return exp.NumberToStr(this=args[0], format=fmt) + return exp.NumberToStr(this=this, format=fmt, culture=culture) - return exp.TimeToStr( - this=args[0], - format=exp.Literal.string( + if fmt: + fmt = exp.Literal.string( format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING) if len(fmt.name) == 1 else format_time(fmt.name, TSQL.TIME_MAPPING) - ), - ) + ) + + return exp.TimeToStr(this=this, format=fmt, culture=culture) def _parse_eomonth(args: t.List) -> exp.Expression: @@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: def generate_date_delta_with_unit_sql( - self: generator.Generator, expression: exp.DateAdd | exp.DateDiff + self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff ) -> str: func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: +def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( expression.args["format"] if isinstance(expression, exp.NumberToStr) @@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim ) ) ) - return self.func("FORMAT", expression.this, fmt) + return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) -def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: +def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() this = expression.this @@ -332,10 +334,12 @@ class TSQL(Dialect): "SQL_VARIANT": TokenType.VARIANT, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "UPDATE STATISTICS": TokenType.COMMAND, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } class Parser(parser.Parser): @@ -395,7 +399,9 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True - def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + ALTER_TABLE_ADD_COLUMN_KEYWORD = False + + def _parse_projections(self) -> t.List[exp.Expression]: """ T-SQL supports the syntax alias = expression in the SELECT's projection list, so we transform all parsed Selects to convert their EQ projections into Aliases. @@ -458,43 +464,6 @@ class TSQL(Dialect): return self._parse_as_command(self._prev) - def _parse_system_time(self) -> t.Optional[exp.Expression]: - if not self._match_text_seq("FOR", "SYSTEM_TIME"): - return None - - if self._match_text_seq("AS", "OF"): - system_time = self.expression( - exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" - ) - elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): - kind = self._prev.text - this = self._parse_bitwise() - self._match_texts(("TO", "AND")) - expression = self._parse_bitwise() - system_time = self.expression( - exp.SystemTime, this=this, expression=expression, kind=kind - ) - elif self._match_text_seq("CONTAINED", "IN"): - args = self._parse_wrapped_csv(self._parse_bitwise) - system_time = self.expression( - exp.SystemTime, - this=seq_get(args, 0), - expression=seq_get(args, 1), - kind="CONTAINED IN", - ) - elif self._match(TokenType.ALL): - system_time = self.expression(exp.SystemTime, kind="ALL") - else: - system_time = None - self.raise_error("Unable to parse FOR SYSTEM_TIME clause") - - return system_time - - def _parse_table_parts(self, schema: bool = False) -> exp.Table: - table = super()._parse_table_parts(schema=schema) - table.set("system_time", self._parse_system_time()) - return table - def _parse_returns(self) -> exp.ReturnsProperty: table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) returns = super()._parse_returns() @@ -589,14 +558,36 @@ class TSQL(Dialect): return create + def _parse_if(self) -> t.Optional[exp.Expression]: + index = self._index + + if self._match_text_seq("OBJECT_ID"): + self._parse_wrapped_csv(self._parse_string) + if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP): + return self._parse_drop(exists=True) + self._retreat(index) + + return super()._parse_if() + + def _parse_unique(self) -> exp.UniqueColumnConstraint: + return self.expression( + exp.UniqueColumnConstraint, + this=None + if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"} + else self._parse_schema(self._parse_id_var(any_token=False)), + ) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False + NVL2_SUPPORTED = False + ALTER_TABLE_ADD_COLUMN_KEYWORD = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.INT: "INTEGER", @@ -607,6 +598,8 @@ class TSQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, + exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -651,25 +644,44 @@ class TSQL(Dialect): return sql - def offset_sql(self, expression: exp.Offset) -> str: - return f"{super().offset_sql(expression)} ROWS" + def create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + kind = self.sql(expression, "kind").upper() + exists = expression.args.pop("exists", None) + sql = super().create_sql(expression) + + if exists: + table = expression.find(exp.Table) + identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) + if kind == "SCHEMA": + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')""" + elif kind == "TABLE": + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')""" + elif kind == "INDEX": + index = self.sql(exp.Literal.string(expression.this.text("this"))) + sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')""" + elif expression.args.get("replace"): + sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - def systemtime_sql(self, expression: exp.SystemTime) -> str: - kind = expression.args["kind"] - if kind == "ALL": - return "FOR SYSTEM_TIME ALL" + return sql - start = self.sql(expression, "this") - if kind == "AS OF": - return f"FOR SYSTEM_TIME AS OF {start}" + def offset_sql(self, expression: exp.Offset) -> str: + return f"{super().offset_sql(expression)} ROWS" - end = self.sql(expression, "expression") - if kind == "FROM": - return f"FOR SYSTEM_TIME FROM {start} TO {end}" - if kind == "BETWEEN": - return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + def version_sql(self, expression: exp.Version) -> str: + name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name + this = f"FOR {name}" + expr = expression.expression + kind = expression.text("kind") + if kind in ("FROM", "BETWEEN"): + args = expr.expressions + sep = "TO" if kind == "FROM" else "AND" + expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}" + else: + expr_sql = self.sql(expr) - return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + expr_sql = f" {expr_sql}" if expr_sql else "" + return f"{this} {kind}{expr_sql}" def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") @@ -713,3 +725,16 @@ class TSQL(Dialect): identifier = f"#{identifier}" return identifier + + def constraint_sql(self, expression: exp.Constraint) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True, sep=" ") + return f"CONSTRAINT {this} {expressions}" + + # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + start = self.sql(expression, "start") or "1" + increment = self.sql(expression, "increment") or "1" + return f"IDENTITY({start}, {increment})" |