diff options
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 128 |
1 files changed, 101 insertions, 27 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 165a703..b9c347c 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,9 +14,10 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, parse_date_delta, + path_to_jsonpath, rename_func, timestrtotime_sql, - ts_or_ds_to_date_sql, + trim_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -105,18 +106,17 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr(this=this, format=fmt, culture=culture) -def _parse_eomonth(args: t.List) -> exp.Expression: - date = seq_get(args, 0) +def _parse_eomonth(args: t.List) -> exp.LastDay: + date = exp.TsOrDsToDate(this=seq_get(args, 0)) month_lag = seq_get(args, 1) - unit = DATE_DELTA_INTERVAL.get("month") if month_lag is None: - return exp.LastDateOfMonth(this=date) + this: exp.Expression = date + else: + unit = DATE_DELTA_INTERVAL.get("month") + this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit)) - # Remove month lag argument in parser as its compared with the number of arguments of the resulting class - args.remove(month_lag) - - return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) + return exp.LastDay(this=this) def _parse_hashbytes(args: t.List) -> exp.Expression: @@ -137,26 +137,27 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: return exp.func("HASHBYTES", *args) -DATEPART_ONLY_FORMATS = {"dw", "hour", "quarter"} +DATEPART_ONLY_FORMATS = {"DW", "HOUR", "QUARTER"} def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: - fmt = ( - expression.args["format"] - if isinstance(expression, exp.NumberToStr) - else exp.Literal.string( - format_time( - expression.text("format"), - t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), - ) - ) - ) + fmt = expression.args["format"] - # There is no format for "quarter" - if fmt.name.lower() in DATEPART_ONLY_FORMATS: - return self.func("DATEPART", fmt.name, expression.this) + if not isinstance(expression, exp.NumberToStr): + if fmt.is_string: + mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING) - return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) + name = (mapped_fmt or "").upper() + if name in DATEPART_ONLY_FORMATS: + return self.func("DATEPART", name, expression.this) + + fmt_sql = self.sql(exp.Literal.string(mapped_fmt)) + else: + fmt_sql = self.format_time(expression) or self.sql(fmt) + else: + fmt_sql = self.sql(fmt) + + return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture")) def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: @@ -239,6 +240,30 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: return expression +# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: + return exp.TimestampFromParts( + year=seq_get(args, 0), + month=seq_get(args, 1), + day=seq_get(args, 2), + hour=seq_get(args, 3), + min=seq_get(args, 4), + sec=seq_get(args, 5), + milli=seq_get(args, 6), + ) + + +# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax +def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: + return exp.TimeFromParts( + hour=seq_get(args, 0), + min=seq_get(args, 1), + sec=seq_get(args, 2), + fractions=seq_get(args, 3), + precision=seq_get(args, 4), + ) + + class TSQL(Dialect): NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" @@ -352,7 +377,7 @@ class TSQL(Dialect): } class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ['"', ("[", "]")] + IDENTIFIERS = [("[", "]"), '"'] QUOTES = ["'", '"'] HEX_STRINGS = [("0x", ""), ("0X", "")] VAR_SINGLE_TOKENS = {"@", "$", "#"} @@ -362,6 +387,7 @@ class TSQL(Dialect): "DATETIME2": TokenType.DATETIME, "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, "DECLARE": TokenType.COMMAND, + "EXEC": TokenType.COMMAND, "IMAGE": TokenType.IMAGE, "MONEY": TokenType.MONEY, "NTEXT": TokenType.TEXT, @@ -397,6 +423,7 @@ class TSQL(Dialect): "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), + "DATETIMEFROMPARTS": _parse_datetimefromparts, "EOMONTH": _parse_eomonth, "FORMAT": _parse_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, @@ -411,6 +438,7 @@ class TSQL(Dialect): "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, + "TIMEFROMPARTS": _parse_timefromparts, } JOIN_HINTS = { @@ -440,6 +468,7 @@ class TSQL(Dialect): LOG_DEFAULTS_TO_LN = True ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False + STRING_ALIASES = True def _parse_projections(self) -> t.List[exp.Expression]: """ @@ -630,8 +659,10 @@ class TSQL(Dialect): COMPUTED_COLUMN_WITH_TYPE = False CTE_RECURSIVE_KEYWORD_REQUIRED = False ENSURE_BOOLS = True - NULL_ORDERING_SUPPORTED = False + NULL_ORDERING_SUPPORTED = None SUPPORTS_SINGLE_ARG_CONCAT = False + TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" + SUPPORTS_SELECT_INTO = True EXPRESSIONS_WITHOUT_NESTED_CTES = { exp.Delete, @@ -667,13 +698,16 @@ class TSQL(Dialect): exp.CurrentTimestamp: rename_func("GETDATE"), exp.Extract: rename_func("DATEPART"), exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, + exp.GetPath: path_to_jsonpath("JSON_VALUE"), exp.GroupConcat: _string_agg_sql, exp.If: rename_func("IIF"), + exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), exp.Length: rename_func("LEN"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, exp.NumberToStr: _format_sql, + exp.ParseJSON: lambda self, e: self.sql(e, "this"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, @@ -689,9 +723,9 @@ class TSQL(Dialect): exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.Trim: trim_sql, exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } TRANSFORMS.pop(exp.ReturnsProperty) @@ -701,6 +735,46 @@ class TSQL(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def lateral_op(self, expression: exp.Lateral) -> str: + cross_apply = expression.args.get("cross_apply") + if cross_apply is True: + return "CROSS APPLY" + if cross_apply is False: + return "OUTER APPLY" + + # TODO: perhaps we can check if the parent is a Join and transpile it appropriately + self.unsupported("LATERAL clause is not supported.") + return "LATERAL" + + def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.") + + if expression.args.get("fractions") is None: + expression.set("fractions", exp.Literal.number(0)) + if expression.args.get("precision") is None: + expression.set("precision", exp.Literal.number(0)) + + return rename_func("TIMEFROMPARTS")(self, expression) + + def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: + zone = expression.args.get("zone") + if zone is not None: + zone.pop() + self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.") + + nano = expression.args.get("nano") + if nano is not None: + nano.pop() + self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.") + + if expression.args.get("milli") is None: + expression.set("milli", exp.Literal.number(0)) + + return rename_func("DATETIMEFROMPARTS")(self, expression) + def set_operation(self, expression: exp.Union, op: str) -> str: limit = expression.args.get("limit") if limit: |