diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/tsql.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 118 |
1 files changed, 90 insertions, 28 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a281297..c3d4f0a 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -7,7 +7,9 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + NormalizationStrategy, any_value_to_max_sql, + date_delta_sql, generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, @@ -135,11 +137,7 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: return exp.func("HASHBYTES", *args) -def generate_date_delta_with_unit_sql( - self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff -) -> str: - func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" - return self.func(func, expression.text("unit"), expression.expression, expression.this) +DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"} def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: @@ -153,6 +151,11 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt ) ) ) + + # There is no format for "quarter" + if fmt.name.lower() in DATEPART_ONLY_FORMATS: + return self.func("DATEPART", fmt.name, expression.this) + return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) @@ -202,18 +205,50 @@ def _parse_date_delta( return inner_func +def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: + """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries.""" + alias = expression.args.get("alias") + + if ( + isinstance(expression, (exp.CTE, exp.Subquery)) + and isinstance(alias, exp.TableAlias) + and not alias.columns + ): + from sqlglot.optimizer.qualify_columns import qualify_outputs + + # We keep track of the unaliased column projection indexes instead of the expressions + # themselves, because the latter are going to be replaced by new nodes when the aliases + # are added and hence we won't be able to reach these newly added Alias parents + subqueryable = expression.this + unaliased_column_indexes = ( + i + for i, c in enumerate(subqueryable.selects) + if isinstance(c, exp.Column) and not c.alias + ) + + qualify_outputs(subqueryable) + + # Preserve the quoting information of columns for newly added Alias nodes + subqueryable_selects = subqueryable.selects + for select_index in unaliased_column_indexes: + alias = subqueryable_selects[select_index] + column = alias.this + if isinstance(column.this, exp.Identifier): + alias.args["alias"].set("quoted", column.this.quoted) + + return expression + + class TSQL(Dialect): - RESOLVES_IDENTIFIERS_AS_UPPERCASE = None - NULL_ORDERING = "nulls_are_small" + NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" SUPPORTS_SEMI_ANTI_JOIN = False LOG_BASE_FIRST = False + TYPED_DIVISION = True + CONCAT_COALESCE = True TIME_MAPPING = { "year": "%Y", - "qq": "%q", - "q": "%q", - "quarter": "%q", "dayofyear": "%j", "day": "%d", "dy": "%d", @@ -320,6 +355,7 @@ class TSQL(Dialect): IDENTIFIERS = ['"', ("[", "]")] QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] + VAR_SINGLE_TOKENS = {"@", "$", "#"} KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -403,9 +439,7 @@ class TSQL(Dialect): LOG_DEFAULTS_TO_LN = True - CONCAT_NULL_OUTPUTS_STRING = True - - ALTER_TABLE_ADD_COLUMN_KEYWORD = False + ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -433,7 +467,7 @@ class TSQL(Dialect): """ rollback = self._prev.token_type == TokenType.ROLLBACK - self._match_texts({"TRAN", "TRANSACTION"}) + self._match_texts(("TRAN", "TRANSACTION")) this = self._parse_id_var() if rollback: @@ -579,23 +613,35 @@ class TSQL(Dialect): return super()._parse_if() def _parse_unique(self) -> exp.UniqueColumnConstraint: - return self.expression( - exp.UniqueColumnConstraint, - this=None - if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"} - else self._parse_schema(self._parse_id_var(any_token=False)), - ) + if self._match_texts(("CLUSTERED", "NONCLUSTERED")): + this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self) + else: + this = self._parse_schema(self._parse_id_var(any_token=False)) + + return self.expression(exp.UniqueColumnConstraint, this=this) class Generator(generator.Generator): LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False NVL2_SUPPORTED = False - ALTER_TABLE_ADD_COLUMN_KEYWORD = False + ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" COMPUTED_COLUMN_WITH_TYPE = False - SUPPORTS_NESTED_CTES = False CTE_RECURSIVE_KEYWORD_REQUIRED = False + ENSURE_BOOLS = True + NULL_ORDERING_SUPPORTED = False + SUPPORTS_SINGLE_ARG_CONCAT = False + + EXPRESSIONS_WITHOUT_NESTED_CTES = { + exp.Delete, + exp.Insert, + exp.Merge, + exp.Select, + exp.Subquery, + exp.Union, + exp.Update, + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -614,14 +660,16 @@ class TSQL(Dialect): **generator.Generator.TRANSFORMS, exp.AnyValue: any_value_to_max_sql, exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", - exp.DateAdd: generate_date_delta_with_unit_sql, - exp.DateDiff: generate_date_delta_with_unit_sql, + exp.DateAdd: date_delta_sql("DATEADD"), + exp.DateDiff: date_delta_sql("DATEDIFF"), + exp.CTE: transforms.preprocess([qualify_derived_table_outputs]), exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, @@ -633,15 +681,16 @@ class TSQL(Dialect): transforms.eliminate_qualify, ] ), + exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", - exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), - e.this, + "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this ), exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), + exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } @@ -690,8 +739,21 @@ class TSQL(Dialect): table = expression.find(exp.Table) + # Convert CTAS statement to SELECT .. INTO .. if kind == "TABLE" and expression.expression: - sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp" + ctas_with = expression.expression.args.get("with") + if ctas_with: + ctas_with = ctas_with.pop() + + subquery = expression.expression + if isinstance(subquery, exp.Subqueryable): + subquery = subquery.subquery() + + select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True)) + select_into.set("into", exp.Into(this=table)) + select_into.set("with", ctas_with) + + sql = self.sql(select_into) if exists: identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) |