diff options
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 63 |
1 files changed, 33 insertions, 30 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 85b2e12..5955352 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import ( generatedasidentitycolumnconstraint_sql, max_or_greatest, min_or_least, - parse_date_delta, + build_date_delta, rename_func, timestrtotime_sql, trim_sql, @@ -64,10 +64,10 @@ DEFAULT_START_DATE = datetime.date(1900, 1, 1) BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} -def _format_time_lambda( +def _build_formatted_time( exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None ) -> t.Callable[[t.List], E]: - def _format_time(args: t.List) -> E: + def _builder(args: t.List) -> E: assert len(args) == 2 return exp_class( @@ -84,10 +84,10 @@ def _format_time_lambda( ), ) - return _format_time + return _builder -def _parse_format(args: t.List) -> exp.Expression: +def _build_format(args: t.List) -> exp.NumberToStr | exp.TimeToStr: this = seq_get(args, 0) fmt = seq_get(args, 1) culture = seq_get(args, 2) @@ -107,7 +107,7 @@ def _parse_format(args: t.List) -> exp.Expression: return exp.TimeToStr(this=this, format=fmt, culture=culture) -def _parse_eomonth(args: t.List) -> exp.LastDay: +def _build_eomonth(args: t.List) -> exp.LastDay: date = exp.TsOrDsToDate(this=seq_get(args, 0)) month_lag = seq_get(args, 1) @@ -120,7 +120,7 @@ def _parse_eomonth(args: t.List) -> exp.LastDay: return exp.LastDay(this=this) -def _parse_hashbytes(args: t.List) -> exp.Expression: +def _build_hashbytes(args: t.List) -> exp.Expression: kind, data = args kind = kind.name.upper() if kind.is_string else "" @@ -179,10 +179,10 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: return f"STRING_AGG({self.format_args(this, separator)}){order}" -def _parse_date_delta( +def _build_date_delta( exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None ) -> t.Callable[[t.List], E]: - def inner_func(args: t.List) -> E: + def _builder(args: t.List) -> E: unit = seq_get(args, 0) if unit and unit_mapping: unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) @@ -204,7 +204,7 @@ def _parse_date_delta( unit=unit, ) - return inner_func + return _builder def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: @@ -242,7 +242,7 @@ def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: # https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax -def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: +def _build_datetimefromparts(args: t.List) -> exp.TimestampFromParts: return exp.TimestampFromParts( year=seq_get(args, 0), month=seq_get(args, 1), @@ -255,7 +255,7 @@ def _parse_datetimefromparts(args: t.List) -> exp.TimestampFromParts: # https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax -def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: +def _build_timefromparts(args: t.List) -> exp.TimeFromParts: return exp.TimeFromParts( hour=seq_get(args, 0), min=seq_get(args, 1), @@ -265,7 +265,7 @@ def _parse_timefromparts(args: t.List) -> exp.TimeFromParts: ) -def _parse_as_text( +def _build_with_arg_as_text( klass: t.Type[exp.Expression], ) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: def _parse(args: t.List[exp.Expression]) -> exp.Expression: @@ -288,8 +288,8 @@ def _parse_as_text( def _json_extract_sql( self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar ) -> str: - json_query = rename_func("JSON_QUERY")(self, expression) - json_value = rename_func("JSON_VALUE")(self, expression) + json_query = self.func("JSON_QUERY", expression.this, expression.expression) + json_value = self.func("JSON_VALUE", expression.this, expression.expression) return self.func("ISNULL", json_query, json_value) @@ -448,28 +448,28 @@ class TSQL(Dialect): substr=seq_get(args, 0), position=seq_get(args, 2), ), - "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": _parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), - "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True), - "DATEPART": _format_time_lambda(exp.TimeToStr), - "DATETIMEFROMPARTS": _parse_datetimefromparts, - "EOMONTH": _parse_eomonth, - "FORMAT": _parse_format, + "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True), + "DATEPART": _build_formatted_time(exp.TimeToStr), + "DATETIMEFROMPARTS": _build_datetimefromparts, + "EOMONTH": _build_eomonth, + "FORMAT": _build_format, "GETDATE": exp.CurrentTimestamp.from_arg_list, - "HASHBYTES": _parse_hashbytes, + "HASHBYTES": _build_hashbytes, "ISNULL": exp.Coalesce.from_arg_list, - "JSON_QUERY": parser.parse_extract_json_with_path(exp.JSONExtract), - "JSON_VALUE": parser.parse_extract_json_with_path(exp.JSONExtractScalar), - "LEN": _parse_as_text(exp.Length), - "LEFT": _parse_as_text(exp.Left), - "RIGHT": _parse_as_text(exp.Right), + "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract), + "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar), + "LEN": _build_with_arg_as_text(exp.Length), + "LEFT": _build_with_arg_as_text(exp.Left), + "RIGHT": _build_with_arg_as_text(exp.Right), "REPLICATE": exp.Repeat.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "SYSDATETIME": exp.CurrentTimestamp.from_arg_list, "SUSER_NAME": exp.CurrentUser.from_arg_list, "SUSER_SNAME": exp.CurrentUser.from_arg_list, "SYSTEM_USER": exp.CurrentUser.from_arg_list, - "TIMEFROMPARTS": _parse_timefromparts, + "TIMEFROMPARTS": _build_timefromparts, } JOIN_HINTS = { @@ -756,6 +756,9 @@ class TSQL(Dialect): transforms.eliminate_qualify, ] ), + exp.StrPosition: lambda self, e: self.func( + "CHARINDEX", e.args.get("substr"), e.this, e.args.get("position") + ), exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( @@ -855,7 +858,7 @@ class TSQL(Dialect): return sql def create_sql(self, expression: exp.Create) -> str: - kind = self.sql(expression, "kind").upper() + kind = expression.kind exists = expression.args.pop("exists", None) sql = super().create_sql(expression) |