diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:46 +0000 |
commit | 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch) | |
tree | c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/dialects/tsql.py | |
parent | Releasing debian version 12.2.0-1. (diff) | |
download | sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip |
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 90 |
1 files changed, 54 insertions, 36 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 03de99c..f6ad888 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -16,6 +16,9 @@ from sqlglot.helper import seq_get from sqlglot.time import format_time from sqlglot.tokens import TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + FULL_FORMAT_TIME_MAPPING = { "weekday": "%A", "dw": "%A", @@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} -def _format_time_lambda(exp_class, full_format_mapping=None, default=None): - def _format_time(args): +def _format_time_lambda( + exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None +) -> t.Callable[[t.List], E]: + def _format_time(args: t.List) -> E: + assert len(args) == 2 + return exp_class( - this=seq_get(args, 1), + this=args[1], format=exp.Literal.string( format_time( - seq_get(args, 0).name or (TSQL.time_format if default is True else default), + args[0].name, {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, @@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def _parse_format(args): - fmt = seq_get(args, 1) - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) +def _parse_format(args: t.List) -> exp.Expression: + assert len(args) == 2 + + fmt = args[1] + number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + if number_fmt: - return exp.NumberToStr(this=seq_get(args, 0), format=fmt) + return exp.NumberToStr(this=args[0], format=fmt) + return exp.TimeToStr( - this=seq_get(args, 0), + this=args[0], format=exp.Literal.string( format_time(fmt.name, TSQL.format_time_mapping) if len(fmt.name) == 1 @@ -82,7 +93,7 @@ def _parse_format(args): ) -def _parse_eomonth(args): +def _parse_eomonth(args: t.List) -> exp.Expression: date = seq_get(args, 0) month_lag = seq_get(args, 1) unit = DATE_DELTA_INTERVAL.get("month") @@ -96,7 +107,7 @@ def _parse_eomonth(args): return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit)) -def _parse_hashbytes(args): +def _parse_hashbytes(args: t.List) -> exp.Expression: kind, data = args kind = kind.name.upper() if kind.is_string else "" @@ -110,40 +121,47 @@ def _parse_hashbytes(args): return exp.SHA2(this=data, length=exp.Literal.number(256)) if kind == "SHA2_512": return exp.SHA2(this=data, length=exp.Literal.number(512)) + return exp.func("HASHBYTES", *args) -def generate_date_delta_with_unit_sql(self, e): - func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" - return self.func(func, e.text("unit"), e.expression, e.this) +def generate_date_delta_with_unit_sql( + self: generator.Generator, expression: exp.DateAdd | exp.DateDiff +) -> str: + func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" + return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self, e): +def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( - e.args["format"] - if isinstance(e, exp.NumberToStr) - else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) + expression.args["format"] + if isinstance(expression, exp.NumberToStr) + else exp.Literal.string( + format_time( + expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping) + ) + ) ) - return self.func("FORMAT", e.this, fmt) + return self.func("FORMAT", expression.this, fmt) -def _string_agg_sql(self, e): - e = e.copy() +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: + expression = expression.copy() - this = e.this - distinct = e.find(exp.Distinct) + this = expression.this + distinct = expression.find(exp.Distinct) if distinct: # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") this = distinct.pop().expressions[0] order = "" - if isinstance(e.this, exp.Order): - if e.this.this: - this = e.this.this.pop() - order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space + if isinstance(expression.this, exp.Order): + if expression.this.this: + this = expression.this.this.pop() + order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space - separator = e.args.get("separator") or exp.Literal.string(",") + separator = expression.args.get("separator") or exp.Literal.string(",") return f"STRING_AGG({self.format_args(this, separator)}){order}" @@ -292,7 +310,7 @@ class TSQL(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( this=seq_get(args, 1), substr=seq_get(args, 0), @@ -332,13 +350,13 @@ class TSQL(Dialect): DataType.Type.NCHAR, } - RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore + RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { TokenType.TABLE, - *parser.Parser.TYPE_TOKENS, # type: ignore + *parser.Parser.TYPE_TOKENS, } STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, # type: ignore + **parser.Parser.STATEMENT_PARSERS, TokenType.END: lambda self: self._parse_command(), } @@ -377,7 +395,7 @@ class TSQL(Dialect): return system_time - def _parse_table_parts(self, schema: bool = False) -> exp.Expression: + def _parse_table_parts(self, schema: bool = False) -> exp.Table: table = super()._parse_table_parts(schema=schema) table.set("system_time", self._parse_system_time()) return table @@ -450,7 +468,7 @@ class TSQL(Dialect): LOCKING_READS_SUPPORTED = True TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", @@ -458,7 +476,7 @@ class TSQL(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -480,7 +498,7 @@ class TSQL(Dialect): TRANSFORMS.pop(exp.ReturnsProperty) PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } |