summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py128
1 files changed, 101 insertions, 27 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 165a703..b9c347c 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import (
max_or_greatest,
min_or_least,
parse_date_delta,
+ path_to_jsonpath,
rename_func,
timestrtotime_sql,
- ts_or_ds_to_date_sql,
+ trim_sql,
)
from sqlglot.expressions import DataType
from sqlglot.helper import seq_get
@@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
-def _parse_eomonth(args: t.List) -> exp.Expression:
- date = seq_get(args, 0)
+def _parse_eomonth(args: t.List) -> exp.LastDay:
+ date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
- unit = DATE_DELTA_INTERVAL.get("month")
if month_lag is None:
- return exp.LastDateOfMonth(this=date)
+ this: exp.Expression = date
+ else:
+ unit = DATE_DELTA_INTERVAL.get("month")
+ this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit))
- # Remove month lag argument in parser as its compared with the number of arguments of the resulting class
- args.remove(month_lag)
-
- return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+ return exp.LastDay(this=this)
def _parse_hashbytes(args: t.List) -> exp.Expression:
@@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression:
return exp.func("HASHBYTES", *args)
-DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"}
+DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"}
def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
- fmt = (
- expression.args["format"]
- if isinstance(expression, exp.NumberToStr)
- else exp.Literal.string(
- format_time(
- expression.text("format"),
- t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
- )
- )
- )
+ fmt = expression.args["format"]
- # There is no format for "quarter"
- if fmt.name.lower() in DATEPART_ONLY_FORMATS:
- return self.func("DATEPART", fmt.name, expression.this)
+ if not isinstance(expression, exp.NumberToStr):
+ if fmt.is_string:
+ mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING)
- return self.func("FORMAT", expression.this, fmt, expression.args.get("culture"))
+ name = (mapped_fmt or "").upper()
+ if name in DATEPART_ONLY_FORMATS:
+ return self.func("DATEPART", name, expression.this)
+
+ fmt_sql = self.sql(exp.Literal.string(mapped_fmt))
+ else:
+ fmt_sql = self.format_time(expression) or self.sql(fmt)
+ else:
+ fmt_sql = self.sql(fmt)
+
+ return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture"))
def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
@@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
return expression
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
+ return exp.TimestampFromParts(
+ year=seq_get(args, 0),
+ month=seq_get(args, 1),
+ day=seq_get(args, 2),
+ hour=seq_get(args, 3),
+ min=seq_get(args, 4),
+ sec=seq_get(args, 5),
+ milli=seq_get(args, 6),
+ )
+
+
+# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
+def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
+ return exp.TimeFromParts(
+ hour=seq_get(args, 0),
+ min=seq_get(args, 1),
+ sec=seq_get(args, 2),
+ fractions=seq_get(args, 3),
+ precision=seq_get(args, 4),
+ )
+
+
class TSQL(Dialect):
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE
TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'"
@@ -352,7 +377,7 @@ class TSQL(Dialect):
}
class Tokenizer(tokens.Tokenizer):
- IDENTIFIERS = ['"', ("[", "]")]
+ IDENTIFIERS = [("[", "]"), '"']
QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]
VAR_SINGLE_TOKENS = {"@", "$", "#"}
@@ -362,6 +387,7 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"DECLARE": TokenType.COMMAND,
+ "EXEC": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"NTEXT": TokenType.TEXT,
@@ -397,6 +423,7 @@ class TSQL(Dialect):
"DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
+ "DATETIMEFROMPARTS": _parse_datetimefromparts,
"EOMONTH": _parse_eomonth,
"FORMAT": _parse_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
@@ -411,6 +438,7 @@ class TSQL(Dialect):
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
+ "TIMEFROMPARTS": _parse_timefromparts,
}
JOIN_HINTS = {
@@ -440,6 +468,7 @@ class TSQL(Dialect):
LOG_DEFAULTS_TO_LN = True
ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False
+ STRING_ALIASES = True
def _parse_projections(self) -> t.List[exp.Expression]:
"""
@@ -630,8 +659,10 @@ class TSQL(Dialect):
COMPUTED_COLUMN_WITH_TYPE = False
CTE_RECURSIVE_KEYWORD_REQUIRED = False
ENSURE_BOOLS = True
- NULL_ORDERING_SUPPORTED = False
+ NULL_ORDERING_SUPPORTED = None
SUPPORTS_SINGLE_ARG_CONCAT = False
+ TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
+ SUPPORTS_SELECT_INTO = True
EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
@@ -667,13 +698,16 @@ class TSQL(Dialect):
exp.CurrentTimestamp: rename_func("GETDATE"),
exp.Extract: rename_func("DATEPART"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
+ exp.GetPath: path_to_jsonpath("JSON_VALUE"),
exp.GroupConcat: _string_agg_sql,
exp.If: rename_func("IIF"),
+ exp.LastDay: lambda self, e: self.func("EOMONTH", e.this),
exp.Length: rename_func("LEN"),
exp.Max: max_or_greatest,
exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this),
exp.Min: min_or_least,
exp.NumberToStr: _format_sql,
+ exp.ParseJSON: lambda self, e: self.sql(e, "this"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
@@ -689,9 +723,9 @@ class TSQL(Dialect):
exp.TemporaryProperty: lambda self, e: "",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: _format_sql,
+ exp.Trim: trim_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
- exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"),
}
TRANSFORMS.pop(exp.ReturnsProperty)
@@ -701,6 +735,46 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
+ def lateral_op(self, expression: exp.Lateral) -> str:
+ cross_apply = expression.args.get("cross_apply")
+ if cross_apply is True:
+ return "CROSS APPLY"
+ if cross_apply is False:
+ return "OUTER APPLY"
+
+ # TODO: perhaps we can check if the parent is a Join and transpile it appropriately
+ self.unsupported("LATERAL clause is not supported.")
+ return "LATERAL"
+
+ def timefromparts_sql(self, expression: exp.TimeFromParts) -> str:
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.")
+
+ if expression.args.get("fractions") is None:
+ expression.set("fractions", exp.Literal.number(0))
+ if expression.args.get("precision") is None:
+ expression.set("precision", exp.Literal.number(0))
+
+ return rename_func("TIMEFROMPARTS")(self, expression)
+
+ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
+ zone = expression.args.get("zone")
+ if zone is not None:
+ zone.pop()
+ self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.")
+
+ nano = expression.args.get("nano")
+ if nano is not None:
+ nano.pop()
+ self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.")
+
+ if expression.args.get("milli") is None:
+ expression.set("milli", exp.Literal.number(0))
+
+ return rename_func("DATETIMEFROMPARTS")(self, expression)
+
def set_operation(self, expression: exp.Union, op: str) -> str:
limit = expression.args.get("limit")
if limit: