diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/hive.py | 40 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 197 |
4 files changed, 236 insertions, 7 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 0810e0c..63fdb85 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -18,6 +18,36 @@ from sqlglot.helper import list_get from sqlglot.parser import Parser, parse_var_map from sqlglot.tokens import Tokenizer +# (FuncType, Multiplier) +DATE_DELTA_INTERVAL = { + "YEAR": ("ADD_MONTHS", 12), + "MONTH": ("ADD_MONTHS", 1), + "QUARTER": ("ADD_MONTHS", 3), + "WEEK": ("DATE_ADD", 7), + "DAY": ("DATE_ADD", 1), +} + +DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") + + +def _add_date_sql(self, expression): + unit = expression.text("unit").upper() + func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) + modified_increment = ( + int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression + ) + modified_increment = exp.Literal.number(modified_increment) + return f"{func}({self.format_args(expression.this, modified_increment.this)})" + + +def _date_diff_sql(self, expression): + unit = expression.text("unit").upper() + sql_func = "MONTHS_BETWEEN" if unit in DIFF_MONTH_SWITCH else "DATEDIFF" + _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) + multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" + diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" + return f"{diff_sql}{multiplier_sql}" + def _array_sort(self, expression): if expression.expression: @@ -120,10 +150,14 @@ class Hive(Dialect): "m": "%-M", "ss": "%S", "s": "%-S", - "S": "%f", + "SSSSSS": "%f", "a": "%p", "DD": "%j", "D": "%-j", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", } date_format = "'yyyy-MM-dd'" @@ -207,8 +241,8 @@ class Hive(Dialect): exp.ArraySize: rename_func("SIZE"), exp.ArraySort: _array_sort, exp.With: no_recursive_cte_sql, - exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", - exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateAdd: _add_date_sql, + exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 6bf4ff0..572f411 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -71,6 +71,7 @@ class Spark(Hive): length=list_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "IIF": exp.If.from_arg_list, } FUNCTION_PARSERS = { @@ -111,6 +112,7 @@ class Spark(Hive): exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), + exp.DateFromParts: rename_func("MAKE_DATE"), } WRAP_DERIVED_VALUES = False diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index ef8c82d..0cba6fe 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -1,5 +1,5 @@ from sqlglot import exp -from sqlglot.dialects.dialect import rename_func +from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func from sqlglot.dialects.mysql import MySQL @@ -14,6 +14,8 @@ class StarRocks(MySQL): TRANSFORMS = { **MySQL.Generator.TRANSFORMS, + exp.JSONExtractScalar: arrow_json_extract_sql, + exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: rename_func("TO_DATE"), diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 1f2e50d..107ace7 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,14 +1,149 @@ from sqlglot import exp -from sqlglot.dialects.dialect import Dialect +from sqlglot.dialects.dialect import Dialect, rename_func +from sqlglot.expressions import DataType from sqlglot.generator import Generator +from sqlglot.helper import list_get from sqlglot.parser import Parser +from sqlglot.time import format_time from sqlglot.tokens import Tokenizer, TokenType +FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"} +DATE_DELTA_INTERVAL = { + "year": "year", + "yyyy": "year", + "yy": "year", + "quarter": "quarter", + "qq": "quarter", + "q": "quarter", + "month": "month", + "mm": "month", + "m": "month", + "week": "week", + "ww": "week", + "wk": "week", + "day": "day", + "dd": "day", + "d": "day", +} + + +def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): + def _format_time(args): + return exp_class( + this=list_get(args, 1), + format=exp.Literal.string( + format_time( + list_get(args, 0).name or (TSQL.time_format if default is True else default), + {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping, + ) + ), + ) + + return _format_time + + +def parse_date_delta(exp_class): + def inner_func(args): + unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day") + return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit) + + return inner_func + + +def generate_date_delta(self, e): + func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" + return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" + time_mapping = { + "yyyy": "%Y", + "yy": "%y", + "year": "%Y", + "qq": "%q", + "q": "%q", + "quarter": "%q", + "dayofyear": "%j", + "day": "%d", + "dy": "%d", + "y": "%Y", + "week": "%W", + "ww": "%W", + "wk": "%W", + "hour": "%h", + "hh": "%I", + "minute": "%M", + "mi": "%M", + "n": "%M", + "second": "%S", + "ss": "%S", + "s": "%-S", + "millisecond": "%f", + "ms": "%f", + "weekday": "%W", + "dw": "%W", + "month": "%m", + "mm": "%M", + "m": "%-M", + "Y": "%Y", + "YYYY": "%Y", + "YY": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "h": "%-I", + "S": "%f", + } + + convert_format_mapping = { + "0": "%b %d %Y %-I:%M%p", + "1": "%m/%d/%y", + "2": "%y.%m.%d", + "3": "%d/%m/%y", + "4": "%d.%m.%y", + "5": "%d-%m-%y", + "6": "%d %b %y", + "7": "%b %d, %y", + "8": "%H:%M:%S", + "9": "%b %d %Y %-I:%M:%S:%f%p", + "10": "mm-dd-yy", + "11": "yy/mm/dd", + "12": "yymmdd", + "13": "%d %b %Y %H:%M:ss:%f", + "14": "%H:%M:%S:%f", + "20": "%Y-%m-%d %H:%M:%S", + "21": "%Y-%m-%d %H:%M:%S.%f", + "22": "%m/%d/%y %-I:%M:%S %p", + "23": "%Y-%m-%d", + "24": "%H:%M:%S", + "25": "%Y-%m-%d %H:%M:%S.%f", + "100": "%b %d %Y %-I:%M%p", + "101": "%m/%d/%Y", + "102": "%Y.%m.%d", + "103": "%d/%m/%Y", + "104": "%d.%m.%Y", + "105": "%d-%m-%Y", + "106": "%d %b %Y", + "107": "%b %d, %Y", + "108": "%H:%M:%S", + "109": "%b %d %Y %-I:%M:%S:%f%p", + "110": "%m-%d-%Y", + "111": "%Y/%m/%d", + "112": "%Y%m%d", + "113": "%d %b %Y %H:%M:%S:%f", + "114": "%H:%M:%S:%f", + "120": "%Y-%m-%d %H:%M:%S", + "121": "%Y-%m-%d %H:%M:%S.%f", + } + class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]")] @@ -29,19 +164,67 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "XML": TokenType.XML, "SQL_VARIANT": TokenType.VARIANT, + "NVARCHAR(MAX)": TokenType.TEXT, + "VARCHAR(MAX)": TokenType.TEXT, } class Parser(Parser): FUNCTIONS = { **Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, + "ISNULL": exp.Coalesce.from_arg_list, + "DATEADD": parse_date_delta(exp.DateAdd), + "DATEDIFF": parse_date_delta(exp.DateDiff), + "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), + "DATEPART": tsql_format_time_lambda(exp.TimeToStr), + "GETDATE": exp.CurrentDate.from_arg_list, + "IIF": exp.If.from_arg_list, + "LEN": exp.Length.from_arg_list, + "REPLICATE": exp.Repeat.from_arg_list, + "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, + } + + VAR_LENGTH_DATATYPES = { + DataType.Type.NVARCHAR, + DataType.Type.VARCHAR, + DataType.Type.CHAR, + DataType.Type.NCHAR, } - def _parse_convert(self): + def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) this = self._parse_field() - return self.expression(exp.Cast, this=this, to=to) + + # Retrieve length of datatype and override to default if not specified + if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: + to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) + + # Check whether a conversion with format is applicable + if self._match(TokenType.COMMA): + format_val = self._parse_number().name + if format_val not in TSQL.convert_format_mapping: + raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}") + format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val]) + + # Check whether the convert entails a string to date format + if to.this == DataType.Type.DATE: + return self.expression(exp.StrToDate, this=this, format=format_norm) + # Check whether the convert entails a string to datetime format + elif to.this == DataType.Type.DATETIME: + return self.expression(exp.StrToTime, this=this, format=format_norm) + # Check whether the convert entails a date to string format + elif to.this in self.VAR_LENGTH_DATATYPES: + return self.expression( + exp.Cast if strict else exp.TryCast, + to=to, + this=self.expression(exp.TimeToStr, this=this, format=format_norm), + ) + elif to.this == DataType.Type.TEXT: + return self.expression(exp.TimeToStr, this=this, format=format_norm) + + # Entails a simple cast without any format requirement + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) class Generator(Generator): TYPE_MAPPING = { @@ -52,3 +235,11 @@ class TSQL(Dialect): exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.VARIANT: "SQL_VARIANT", } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.DateAdd: lambda self, e: generate_date_delta(self, e), + exp.DateDiff: lambda self, e: generate_date_delta(self, e), + exp.CurrentDate: rename_func("GETDATE"), + exp.If: rename_func("IIF"), + } |