summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py101
1 files changed, 94 insertions, 7 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 01d5001..0eb0906 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import datetime
import re
import typing as t
@@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import (
min_or_least,
parse_date_delta,
rename_func,
+ timestrtotime_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
# N = Numeric, C=Currency
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
+DEFAULT_START_DATE = datetime.date(1900, 1, 1)
+
def _format_time_lambda(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s
return f"STRING_AGG({self.format_args(this, separator)}){order}"
+def _parse_date_delta(
+ exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
+) -> t.Callable[[t.List], E]:
+ def inner_func(args: t.List) -> E:
+ unit = seq_get(args, 0)
+ if unit and unit_mapping:
+ unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
+
+ start_date = seq_get(args, 1)
+ if start_date and start_date.is_number:
+ # Numeric types are valid DATETIME values
+ if start_date.is_int:
+ adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this))
+ start_date = exp.Literal.string(adds.strftime("%F"))
+ else:
+ # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs.
+ # This is not a problem when generating T-SQL code, it is when transpiling to other dialects.
+ return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit)
+
+ return exp_class(
+ this=exp.TimeStrToTime(this=seq_get(args, 2)),
+ expression=exp.TimeStrToTime(this=start_date),
+ unit=unit,
+ )
+
+ return inner_func
+
+
class TSQL(Dialect):
RESOLVES_IDENTIFIERS_AS_UPPERCASE = None
NULL_ORDERING = "nulls_are_small"
@@ -298,7 +330,6 @@ class TSQL(Dialect):
"SMALLDATETIME": TokenType.DATETIME,
"SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
- "TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
"UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
"VARCHAR(MAX)": TokenType.TEXT,
@@ -307,10 +338,6 @@ class TSQL(Dialect):
"SYSTEM_USER": TokenType.CURRENT_USER,
}
- # TSQL allows @, # to appear as a variable/identifier prefix
- SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy()
- SINGLE_TOKENS.pop("#")
-
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
@@ -320,7 +347,7 @@ class TSQL(Dialect):
position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
- "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
"EOMONTH": _parse_eomonth,
@@ -518,6 +545,36 @@ class TSQL(Dialect):
expressions = self._parse_csv(self._parse_function_parameter)
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
+ def _parse_id_var(
+ self,
+ any_token: bool = True,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
+ is_temporary = self._match(TokenType.HASH)
+ is_global = is_temporary and self._match(TokenType.HASH)
+
+ this = super()._parse_id_var(any_token=any_token, tokens=tokens)
+ if this:
+ if is_global:
+ this.set("global", True)
+ elif is_temporary:
+ this.set("temporary", True)
+
+ return this
+
+ def _parse_create(self) -> exp.Create | exp.Command:
+ create = super()._parse_create()
+
+ if isinstance(create, exp.Create):
+ table = create.this.this if isinstance(create.this, exp.Schema) else create.this
+ if isinstance(table, exp.Table) and table.this.args.get("temporary"):
+ if not create.args.get("properties"):
+ create.set("properties", exp.Properties(expressions=[]))
+
+ create.args["properties"].append("expressions", exp.TemporaryProperty())
+
+ return create
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
LIMIT_IS_TOP = True
@@ -526,9 +583,11 @@ class TSQL(Dialect):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
- exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.TIMESTAMP: "DATETIME2",
+ exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET",
exp.DataType.Type.VARIANT: "SQL_VARIANT",
}
@@ -552,6 +611,8 @@ class TSQL(Dialect):
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
e.this,
),
+ exp.TemporaryProperty: lambda self, e: "",
+ exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
}
@@ -564,6 +625,22 @@ class TSQL(Dialect):
LIMIT_FETCH = "FETCH"
+ def createable_sql(
+ self,
+ expression: exp.Create,
+ locations: dict[exp.Properties.Location, list[exp.Property]],
+ ) -> str:
+ sql = self.sql(expression, "this")
+ properties = expression.args.get("properties")
+
+ if sql[:1] != "#" and any(
+ isinstance(prop, exp.TemporaryProperty)
+ for prop in (properties.expressions if properties else [])
+ ):
+ sql = f"#{sql}"
+
+ return sql
+
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
@@ -616,3 +693,13 @@ class TSQL(Dialect):
this = self.sql(expression, "this")
this = f" {this}" if this else ""
return f"ROLLBACK TRANSACTION{this}"
+
+ def identifier_sql(self, expression: exp.Identifier) -> str:
+ identifier = super().identifier_sql(expression)
+
+ if expression.args.get("global"):
+ identifier = f"##{identifier}"
+ elif expression.args.get("temporary"):
+ identifier = f"#{identifier}"
+
+ return identifier