summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-02-20 09:38:01 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-02-20 09:38:01 +0000
commitccb96d1393ae2c16620ea8e8dc749d9642b94e9b (patch)
treed21a77d0cc7da73a84cd6d6ef8212602f5d762e8 /sqlglot/dialects/tsql.py
parentReleasing debian version 21.1.1-1. (diff)
downloadsqlglot-ccb96d1393ae2c16620ea8e8dc749d9642b94e9b.tar.xz
sqlglot-ccb96d1393ae2c16620ea8e8dc749d9642b94e9b.zip
Merging upstream version 21.1.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py63
1 files changed, 33 insertions, 30 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 85b2e12..5955352 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
generatedasidentitycolumnconstraint_sql,
max_or_greatest,
min_or_least,
- parse_date_delta,
+ build_date_delta,
rename_func,
timestrtotime_sql,
trim_sql,
@@ -64,10 +64,10 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1)
BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
-def _format_time_lambda(
+def _build_formatted_time(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
) -> t.Callable[[t.List], E]:
- def _format_time(args: t.List) -> E:
+ def _builder(args: t.List) -> E:
assert len(args) == 2
return exp_class(
@@ -84,10 +84,10 @@ def _format_time_lambda(
),
)
- return _format_time
+ return _builder
-def _parse_format(args: t.List) -> exp.Expression:
+def _build_format(args: t.List) -> exp.NumberToStr | exp.TimeToStr:
this = seq_get(args, 0)
fmt = seq_get(args, 1)
culture = seq_get(args, 2)
@@ -107,7 +107,7 @@ def _parse_format(args: t.List) -> exp.Expression:
return exp.TimeToStr(this=this, format=fmt, culture=culture)
-def _parse_eomonth(args: t.List) -> exp.LastDay:
+def _build_eomonth(args: t.List) -> exp.LastDay:
date = exp.TsOrDsToDate(this=seq_get(args, 0))
month_lag = seq_get(args, 1)
@@ -120,7 +120,7 @@ def _parse_eomonth(args: t.List) -> exp.LastDay:
return exp.LastDay(this=this)
-def _parse_hashbytes(args: t.List) -> exp.Expression:
+def _build_hashbytes(args: t.List) -> exp.Expression:
kind, data = args
kind = kind.name.upper() if kind.is_string else ""
@@ -179,10 +179,10 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str:
return f"STRING_AGG({self.format_args(this, separator)}){order}"
-def _parse_date_delta(
+def _build_date_delta(
exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None
) -> t.Callable[[t.List], E]:
- def inner_func(args: t.List) -> E:
+ def _builder(args: t.List) -> E:
unit = seq_get(args, 0)
if unit and unit_mapping:
unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name))
@@ -204,7 +204,7 @@ def _parse_date_delta(
unit=unit,
)
- return inner_func
+ return _builder
def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
@@ -242,7 +242,7 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression:
# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax
-def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
+def _build_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
return exp.TimestampFromParts(
year=seq_get(args, 0),
month=seq_get(args, 1),
@@ -255,7 +255,7 @@ def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts:
# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax
-def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
+def _build_timefromparts(args: t.List) -> exp.TimeFromParts:
return exp.TimeFromParts(
hour=seq_get(args, 0),
min=seq_get(args, 1),
@@ -265,7 +265,7 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts:
)
-def _parse_as_text(
+def _build_with_arg_as_text(
klass: t.Type[exp.Expression],
) -> t.Callable[[t.List[exp.Expression]], exp.Expression]:
def _parse(args: t.List[exp.Expression]) -> exp.Expression:
@@ -288,8 +288,8 @@ def _parse_as_text(
def _json_extract_sql(
self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
- json_query = rename_func("JSON_QUERY")(self, expression)
- json_value = rename_func("JSON_VALUE")(self, expression)
+ json_query = self.func("JSON_QUERY", expression.this, expression.expression)
+ json_value = self.func("JSON_VALUE", expression.this, expression.expression)
return self.func("ISNULL", json_query, json_value)
@@ -448,28 +448,28 @@ class TSQL(Dialect):
substr=seq_get(args, 0),
position=seq_get(args, 2),
),
- "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
- "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
- "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
- "DATEPART": _format_time_lambda(exp.TimeToStr),
- "DATETIMEFROMPARTS": _parse_datetimefromparts,
- "EOMONTH": _parse_eomonth,
- "FORMAT": _parse_format,
+ "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
+ "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True),
+ "DATEPART": _build_formatted_time(exp.TimeToStr),
+ "DATETIMEFROMPARTS": _build_datetimefromparts,
+ "EOMONTH": _build_eomonth,
+ "FORMAT": _build_format,
"GETDATE": exp.CurrentTimestamp.from_arg_list,
- "HASHBYTES": _parse_hashbytes,
+ "HASHBYTES": _build_hashbytes,
"ISNULL": exp.Coalesce.from_arg_list,
- "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract),
- "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
- "LEN": _parse_as_text(exp.Length),
- "LEFT": _parse_as_text(exp.Left),
- "RIGHT": _parse_as_text(exp.Right),
+ "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract),
+ "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar),
+ "LEN": _build_with_arg_as_text(exp.Length),
+ "LEFT": _build_with_arg_as_text(exp.Left),
+ "RIGHT": _build_with_arg_as_text(exp.Right),
"REPLICATE": exp.Repeat.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"SUSER_NAME": exp.CurrentUser.from_arg_list,
"SUSER_SNAME": exp.CurrentUser.from_arg_list,
"SYSTEM_USER": exp.CurrentUser.from_arg_list,
- "TIMEFROMPARTS": _parse_timefromparts,
+ "TIMEFROMPARTS": _build_timefromparts,
}
JOIN_HINTS = {
@@ -756,6 +756,9 @@ class TSQL(Dialect):
transforms.eliminate_qualify,
]
),
+ exp.StrPosition: lambda self, e: self.func(
+ "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position")
+ ),
exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
@@ -855,7 +858,7 @@ class TSQL(Dialect):
return sql
def create_sql(self, expression: exp.Create) -> str:
- kind = self.sql(expression, "kind").upper()
+ kind = expression.kind
exists = expression.args.pop("exists", None)
sql = super().create_sql(expression)