diff options
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 101 |
1 files changed, 94 insertions, 7 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 01d5001..0eb0906 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import re import typing as t @@ -10,6 +11,7 @@ from sqlglot.dialects.dialect import ( min_or_least, parse_date_delta, rename_func, + timestrtotime_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -52,6 +54,8 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ # N = Numeric, C=Currency TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} +DEFAULT_START_DATE = datetime.date(1900, 1, 1) + def _format_time_lambda( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None @@ -166,6 +170,34 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s return f"STRING_AGG({self.format_args(this, separator)}){order}" +def _parse_date_delta( + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None +) -> t.Callable[[t.List], E]: + def inner_func(args: t.List) -> E: + unit = seq_get(args, 0) + if unit and unit_mapping: + unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) + + start_date = seq_get(args, 1) + if start_date and start_date.is_number: + # Numeric types are valid DATETIME values + if start_date.is_int: + adds = DEFAULT_START_DATE + datetime.timedelta(days=int(start_date.this)) + start_date = exp.Literal.string(adds.strftime("%F")) + else: + # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs. + # This is not a problem when generating T-SQL code, it is when transpiling to other dialects. + return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit) + + return exp_class( + this=exp.TimeStrToTime(this=seq_get(args, 2)), + expression=exp.TimeStrToTime(this=start_date), + unit=unit, + ) + + return inner_func + + class TSQL(Dialect): RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" @@ -298,7 +330,6 @@ class TSQL(Dialect): "SMALLDATETIME": TokenType.DATETIME, "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, - "TIME": TokenType.TIMESTAMP, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, @@ -307,10 +338,6 @@ class TSQL(Dialect): "SYSTEM_USER": TokenType.CURRENT_USER, } - # TSQL allows @, # to appear as a variable/identifier prefix - SINGLE_TOKENS = tokens.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("#") - class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -320,7 +347,7 @@ class TSQL(Dialect): position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": _format_time_lambda(exp.TimeToStr), "EOMONTH": _parse_eomonth, @@ -518,6 +545,36 @@ class TSQL(Dialect): expressions = self._parse_csv(self._parse_function_parameter) return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + is_temporary = self._match(TokenType.HASH) + is_global = is_temporary and self._match(TokenType.HASH) + + this = super()._parse_id_var(any_token=any_token, tokens=tokens) + if this: + if is_global: + this.set("global", True) + elif is_temporary: + this.set("temporary", True) + + return this + + def _parse_create(self) -> exp.Create | exp.Command: + create = super()._parse_create() + + if isinstance(create, exp.Create): + table = create.this.this if isinstance(create.this, exp.Schema) else create.this + if isinstance(table, exp.Table) and table.this.args.get("temporary"): + if not create.args.get("properties"): + create.set("properties", exp.Properties(expressions=[])) + + create.args["properties"].append("expressions", exp.TemporaryProperty()) + + return create + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True @@ -526,9 +583,11 @@ class TSQL(Dialect): TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.TIMESTAMP: "DATETIME2", + exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", exp.DataType.Type.VARIANT: "SQL_VARIANT", } @@ -552,6 +611,8 @@ class TSQL(Dialect): exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this, ), + exp.TemporaryProperty: lambda self, e: "", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, } @@ -564,6 +625,22 @@ class TSQL(Dialect): LIMIT_FETCH = "FETCH" + def createable_sql( + self, + expression: exp.Create, + locations: dict[exp.Properties.Location, list[exp.Property]], + ) -> str: + sql = self.sql(expression, "this") + properties = expression.args.get("properties") + + if sql[:1] != "#" and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ): + sql = f"#{sql}" + + return sql + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" @@ -616,3 +693,13 @@ class TSQL(Dialect): this = self.sql(expression, "this") this = f" {this}" if this else "" return f"ROLLBACK TRANSACTION{this}" + + def identifier_sql(self, expression: exp.Identifier) -> str: + identifier = super().identifier_sql(expression) + + if expression.args.get("global"): + identifier = f"##{identifier}" + elif expression.args.get("temporary"): + identifier = f"#{identifier}" + + return identifier |