summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
author: Daniel Baumann <mail@daniel-baumann.ch>, 2023-12-10 10:46:01 +0000
committer: Daniel Baumann <mail@daniel-baumann.ch>, 2023-12-10 10:46:01 +0000
commit: 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree: 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/tsql.py
parent: Releasing debian version 19.0.1-1. (diff)
download: sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
download: sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- sqlglot/dialects/tsql.py 118
1 files changed, 90 insertions, 28 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a281297..c3d4f0a 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -7,7 +7,9 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
+ NormalizationStrategy,
any_value_to_max_sql,
+ date_delta_sql,
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
@@ -135,11 +137,7 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
-def generate_date_delta_with_unit_sql(
- self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff
-) -> str:
- func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
- return self.func(func, expression.text("unit"), expression.expression, expression.this)
+DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
@@ -153,6 +151,11 @@ def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToSt
)
)
)
+
+ # There is no format for "quarter"
+ if fmt.name.lower() in DATEPART_ONLY_FORMATS:
+ return self.func("DATEPART", fmt.name, expression.this)
+
return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
@@ -202,18 +205,50 @@ def _parse_date_delta(
return inner_func
+def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
+ """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries."""
+ alias = expression.args.get("alias")
+
+ if (
+ isinstance(expression, (exp.CTE, exp.Subquery))
+ and isinstance(alias, exp.TableAlias)
+ and not alias.columns
+ ):
+ from sqlglot.optimizer.qualify_columns import qualify_outputs
+
+ # We keep track of the unaliased column projection indexes instead of the expressions
+ # themselves, because the latter are going to be replaced by new nodes when the aliases
+ # are added and hence we won't be able to reach these newly added Alias parents
+ subqueryable = expression.this
+ unaliased_column_indexes = (
+ i
+ for i, c in enumerate(subqueryable.selects)
+ if isinstance(c, exp.Column) and not c.alias
+ )
+
+ qualify_outputs(subqueryable)
+
+ # Preserve the quoting information of columns for newly added Alias nodes
+ subqueryable_selects = subqueryable.selects
+ for select_index in unaliased_column_indexes:
+ alias = subqueryable_selects[select_index]
+ column = alias.this
+ if isinstance(column.this, exp.Identifier):
+ alias.args["alias"].set("quoted", column.this.quoted)
+
+ return expression
+
+
class TSQL(Dialect):
- RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
- NULL_ORDERING = "nulls_are_small"
+ NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
SUPPORTS_SEMI_ANTI_JOIN = False
LOG_BASE_FIRST = False
+ TYPED_DIVISION = True
+ CONCAT_COALESCE = True
TIME_MAPPING = {
"year": "%Y",
- "qq": "%q",
- "q": "%q",
- "quarter": "%q",
"dayofyear": "%j",
"day": "%d",
"dy": "%d",
@@ -320,6 +355,7 @@ class TSQL(Dialect):
IDENTIFIERS = ['"', ("[", "]")]
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
+ VAR_SINGLE_TOKENS = {"@", "$", "#"}
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -403,9 +439,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
- CONCAT_NULL_OUTPUTS_STRING = True
-
- ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+ ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -433,7 +467,7 @@ class TSQL(Dialect):
"""
rollback = self._prev.token_type == TokenType.ROLLBACK
- self._match_texts({"TRAN", "TRANSACTION"})
+ self._match_texts(("TRAN", "TRANSACTION"))
this = self._parse_id_var()
if rollback:
@@ -579,23 +613,35 @@ class TSQL(Dialect):
return super()._parse_if()
def _parse_unique(self) -> exp.UniqueColumnConstraint:
- return self.expression(
- exp.UniqueColumnConstraint,
- this=None
- if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"}
- else self._parse_schema(self._parse_id_var(any_token=False)),
- )
+ if self._match_texts(("CLUSTERED", "NONCLUSTERED")):
+ this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self)
+ else:
+ this = self._parse_schema(self._parse_id_var(any_token=False))
+
+ return self.expression(exp.UniqueColumnConstraint, this=this)
class Generator(generator.Generator):
LIMIT_IS_TOP = True
QUERY_HINTS = False
RETURNING_END = False
NVL2_SUPPORTED = False
- ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+ ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False
LIMIT_FETCH = "FETCH"
COMPUTED_COLUMN_WITH_TYPE = False
- SUPPORTS_NESTED_CTES = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
+ ENSURE_BOOLS = True
+ NULL_ORDERING_SUPPORTED = False
+ SUPPORTS_SINGLE_ARG_CONCAT = False
+
+ EXPRESSIONS_WITHOUT_NESTED_CTES = {
+ exp.Delete,
+ exp.Insert,
+ exp.Merge,
+ exp.Select,
+ exp.Subquery,
+ exp.Union,
+ exp.Update,
+ }
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -614,14 +660,16 @@ class TSQL(Dialect):
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
- exp.DateAdd: generate_date_delta_with_unit_sql,
- exp.DateDiff: generate_date_delta_with_unit_sql,
+ exp.DateAdd: date_delta_sql("DATEADD"),
+ exp.DateDiff: date_delta_sql("DATEDIFF"),
+ exp.CTE: transforms.preprocess([qualify_derived_table_outputs]),
exp.CurrentDate: rename_func("GETDATE"),
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
@@ -633,15 +681,16 @@ class TSQL(Dialect):
transforms.eliminate_qualify,
]
),
+ exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
- "HASHBYTES",
- exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
- e.this,
+ "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
),
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
+ exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
@@ -690,8 +739,21 @@ class TSQL(Dialect):
table = expression.find(exp.Table)
+ # Convert CTAS statement to SELECT .. INTO ..
if kind == "TABLE" and expression.expression:
- sql = f"SELECT * INTO {self.sql(table)} FROM ({self.sql(expression.expression)}) AS temp"
+ ctas_with = expression.expression.args.get("with")
+ if ctas_with:
+ ctas_with = ctas_with.pop()
+
+ subquery = expression.expression
+ if isinstance(subquery, exp.Subqueryable):
+ subquery = subquery.subquery()
+
+ select_into = exp.select("*").from_(exp.alias_(subquery, "temp", table=True))
+ select_into.set("into", exp.Into(this=table))
+ select_into.set("with", ctas_with)
+
+ sql = self.sql(select_into)
if exists:
identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else ""))