summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-02 23:59:40 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-06-02 23:59:46 +0000
commit20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch)
treec000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/dialects/tsql.py
parentReleasing debian version 12.2.0-1. (diff)
downloadsqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz
sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py90
1 files changed, 54 insertions, 36 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 03de99c..f6ad888 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -16,6 +16,9 @@ from sqlglot.helper import seq_get
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
+if t.TYPE_CHECKING:
+ from sqlglot._typing import E
+
FULL_FORMAT_TIME_MAPPING = {
"weekday": "%A",
"dw": "%A",
@@ -50,13 +53,17 @@ DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{
TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
-def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
- def _format_time(args):
+def _format_time_lambda(
+ exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
+) -> t.Callable[[t.List], E]:
+ def _format_time(args: t.List) -> E:
+ assert len(args) == 2
+
return exp_class(
- this=seq_get(args, 1),
+ this=args[1],
format=exp.Literal.string(
format_time(
- seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+ args[0].name,
{**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
if full_format_mapping
else TSQL.time_mapping,
@@ -67,13 +74,17 @@ def _format_time_lambda(exp_class, full_format_mapping=None, default=None):
return _format_time
-def _parse_format(args):
- fmt = seq_get(args, 1)
- number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
+def _parse_format(args: t.List) -> exp.Expression:
+ assert len(args) == 2
+
+ fmt = args[1]
+ number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)
+
if number_fmt:
- return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
+ return exp.NumberToStr(this=args[0], format=fmt)
+
return exp.TimeToStr(
- this=seq_get(args, 0),
+ this=args[0],
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -82,7 +93,7 @@ def _parse_format(args):
)
-def _parse_eomonth(args):
+def _parse_eomonth(args: t.List) -> exp.Expression:
date = seq_get(args, 0)
month_lag = seq_get(args, 1)
unit = DATE_DELTA_INTERVAL.get("month")
@@ -96,7 +107,7 @@ def _parse_eomonth(args):
return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
-def _parse_hashbytes(args):
+def _parse_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@@ -110,40 +121,47 @@ def _parse_hashbytes(args):
return exp.SHA2(this=data, length=exp.Literal.number(256))
if kind == "SHA2_512":
return exp.SHA2(this=data, length=exp.Literal.number(512))
+
return exp.func("HASHBYTES", *args)
-def generate_date_delta_with_unit_sql(self, e):
- func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
- return self.func(func, e.text("unit"), e.expression, e.this)
+def generate_date_delta_with_unit_sql(
+ self: generator.Generator, expression: exp.DateAdd | exp.DateDiff
+) -> str:
+ func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF"
+ return self.func(func, expression.text("unit"), expression.expression, expression.this)
-def _format_sql(self, e):
+def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str:
fmt = (
- e.args["format"]
- if isinstance(e, exp.NumberToStr)
- else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping))
+ expression.args["format"]
+ if isinstance(expression, exp.NumberToStr)
+ else exp.Literal.string(
+ format_time(
+ expression.text("format"), t.cast(t.Dict[str, str], TSQL.inverse_time_mapping)
+ )
+ )
)
- return self.func("FORMAT", e.this, fmt)
+ return self.func("FORMAT", expression.this, fmt)
-def _string_agg_sql(self, e):
- e = e.copy()
+def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str:
+ expression = expression.copy()
- this = e.this
- distinct = e.find(exp.Distinct)
+ this = expression.this
+ distinct = expression.find(exp.Distinct)
if distinct:
# exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression
self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.")
this = distinct.pop().expressions[0]
order = ""
- if isinstance(e.this, exp.Order):
- if e.this.this:
- this = e.this.this.pop()
- order = f" WITHIN GROUP ({self.sql(e.this)[1:]})" # Order has a leading space
+ if isinstance(expression.this, exp.Order):
+ if expression.this.this:
+ this = expression.this.this.pop()
+ order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" # Order has a leading space
- separator = e.args.get("separator") or exp.Literal.string(",")
+ separator = expression.args.get("separator") or exp.Literal.string(",")
return f"STRING_AGG({self.format_args(this, separator)}){order}"
@@ -292,7 +310,7 @@ class TSQL(Dialect):
class Parser(parser.Parser):
FUNCTIONS = {
- **parser.Parser.FUNCTIONS, # type: ignore
+ **parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
this=seq_get(args, 1),
substr=seq_get(args, 0),
@@ -332,13 +350,13 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
- RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { # type: ignore
+ RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - {
TokenType.TABLE,
- *parser.Parser.TYPE_TOKENS, # type: ignore
+ *parser.Parser.TYPE_TOKENS,
}
STATEMENT_PARSERS = {
- **parser.Parser.STATEMENT_PARSERS, # type: ignore
+ **parser.Parser.STATEMENT_PARSERS,
TokenType.END: lambda self: self._parse_command(),
}
@@ -377,7 +395,7 @@ class TSQL(Dialect):
return system_time
- def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
+ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
@@ -450,7 +468,7 @@ class TSQL(Dialect):
LOCKING_READS_SUPPORTED = True
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING, # type: ignore
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
exp.DataType.Type.DATETIME: "DATETIME2",
@@ -458,7 +476,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS, # type: ignore
+ **generator.Generator.TRANSFORMS,
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
@@ -480,7 +498,7 @@ class TSQL(Dialect):
TRANSFORMS.pop(exp.ReturnsProperty)
PROPERTIES_LOCATION = {
- **generator.Generator.PROPERTIES_LOCATION, # type: ignore
+ **generator.Generator.PROPERTIES_LOCATION,
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}