diff options
Diffstat (limited to 'sqlglot/dialects/drill.py')
-rw-r--r-- | sqlglot/dialects/drill.py | 174 |
1 files changed, 174 insertions, 0 deletions
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py new file mode 100644 index 0000000..eb420aa --- /dev/null +++ b/sqlglot/dialects/drill.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +import re + +from sqlglot import exp, generator, parser, tokens +from sqlglot.dialects.dialect import ( + Dialect, + create_with_partitions_sql, + format_time_lambda, + no_pivot_sql, + no_trycast_sql, + rename_func, + str_position_sql, +) +from sqlglot.dialects.postgres import _lateral_sql + + +def _to_timestamp(args): + # TO_TIMESTAMP accepts either a single double argument or (text, text) + if len(args) == 1 and args[0].is_number: + return exp.UnixToTime.from_arg_list(args) + return format_time_lambda(exp.StrToTime, "drill")(args) + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (Drill.time_format, Drill.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})" + + return func + + +def if_sql(self, expression): + """ + Drill requires backticks around certain SQL reserved words, IF being one of them, This function + adds the backticks around the keyword IF. + Args: + self: The Drill dialect + expression: The input IF expression + + Returns: The expression with IF in backticks. + + """ + expressions = self.format_args( + expression.this, expression.args.get("true"), expression.args.get("false") + ) + return f"`IF`({expressions})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Drill.date_format: + return f"CAST({this} AS DATE)" + return f"TO_DATE({this}, {time_format})" + + +class Drill(Dialect): + normalize_functions = None + null_ordering = "nulls_are_last" + date_format = "'yyyy-MM-dd'" + dateint_format = "'yyyyMMdd'" + time_format = "'yyyy-MM-dd HH:mm:ss'" + + time_mapping = { + "y": "%Y", + "Y": "%Y", + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "hh": "%I", + "h": "%-I", + "mm": "%M", + "m": "%-M", + "ss": "%S", + "s": "%-S", + "SSSSSS": "%f", + "a": "%p", + "DD": "%j", + "D": "%-j", + "E": "%a", + "EE": "%a", + "EEE": "%a", + "EEEE": "%A", + "''T''": "T", + } + + class Tokenizer(tokens.Tokenizer): + QUOTES = ["'"] + IDENTIFIERS = ["`"] + ESCAPES = ["\\"] + ENCODE = "utf-8" + + class Parser(parser.Parser): + STRICT_CAST = False + + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), + } + + class Generator(generator.Generator): + TYPE_MAPPING = { + **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.NCHAR: "VARCHAR", + exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.DATETIME: "TIMESTAMP", + } + + ROOT_PROPERTIES = {exp.PartitionedByProperty} + + TRANSFORMS = { + **generator.Generator.TRANSFORMS, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.Lateral: _lateral_sql, + exp.ArrayContains: rename_func("REPEATED_CONTAINS"), + exp.ArraySize: rename_func("REPEATED_COUNT"), + exp.Create: create_with_partitions_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateSub: _date_add_sql("SUB"), + exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", + exp.If: if_sql, + exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}", + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}", + exp.Pivot: no_pivot_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.StrPosition: str_position_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TryCast: no_trycast_sql, + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)", + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + } + + def normalize_func(self, name): + return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`" |