diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/__init__.py | 15 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 128 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 48 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 268 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 156 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 304 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 163 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 63 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 109 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 216 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 145 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 106 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 63 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 12 | ||||
-rw-r--r-- | sqlglot/dialects/tableau.py | 37 | ||||
-rw-r--r-- | sqlglot/dialects/trino.py | 10 |
16 files changed, 1843 insertions, 0 deletions
class BigQuery(Dialect):
    """BigQuery dialect: tokenizer, parser and generator customisations."""

    unnest_column_only = True

    class Tokenizer(Tokenizer):
        # Each quote style is listed bare and with an r/R prefix; the prefixed
        # entries are (opening, closing) pairs because the closing delimiter
        # does not carry the prefix.
        QUOTES = [
            (prefix + quote, quote) if prefix else quote
            for quote in ["'", '"', '"""', "'''"]
            for prefix in ["", "r", "R"]
        ]
        IDENTIFIERS = ["`"]
        ESCAPE = "\\"

        KEYWORDS = {
            **Tokenizer.KEYWORDS,
            "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
            "CURRENT_TIME": TokenType.CURRENT_TIME,
            "GEOGRAPHY": TokenType.GEOGRAPHY,
            "INT64": TokenType.BIGINT,
            "FLOAT64": TokenType.DOUBLE,
            "QUALIFY": TokenType.QUALIFY,
            "UNKNOWN": TokenType.NULL,
            "WINDOW": TokenType.WINDOW,
        }

    class Parser(Parser):
        # All *_ADD / *_SUB functions take (value, INTERVAL n unit); _date_add
        # unpacks the interval into the expression/unit args of the node.
        FUNCTIONS = {
            **Parser.FUNCTIONS,
            "DATE_ADD": _date_add(exp.DateAdd),
            "DATETIME_ADD": _date_add(exp.DatetimeAdd),
            "TIME_ADD": _date_add(exp.TimeAdd),
            "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
            "DATE_SUB": _date_add(exp.DateSub),
            "DATETIME_SUB": _date_add(exp.DatetimeSub),
            "TIME_SUB": _date_add(exp.TimeSub),
            "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
        }

        NO_PAREN_FUNCTIONS = {
            **Parser.NO_PAREN_FUNCTIONS,
            TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
            TokenType.CURRENT_TIME: exp.CurrentTime,
        }

    class Generator(Generator):
        TRANSFORMS = {
            # Start from the base generator's transforms, as every other
            # dialect in this package does, so defaults are not silently
            # dropped; the entries below override them where BigQuery differs.
            **Generator.TRANSFORMS,
            exp.Array: inline_array_sql,
            exp.ArraySize: rename_func("ARRAY_LENGTH"),
            exp.DateAdd: _date_add_sql("DATE", "ADD"),
            exp.DateSub: _date_add_sql("DATE", "SUB"),
            exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
            exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
            exp.ILike: no_ilike_sql,
            exp.TimeAdd: _date_add_sql("TIME", "ADD"),
            exp.TimeSub: _date_add_sql("TIME", "SUB"),
            exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
            exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
            exp.VariancePop: rename_func("VAR_POP"),
        }

        TYPE_MAPPING = {
            **Generator.TYPE_MAPPING,
            exp.DataType.Type.TINYINT: "INT64",
            exp.DataType.Type.SMALLINT: "INT64",
            exp.DataType.Type.INT: "INT64",
            exp.DataType.Type.BIGINT: "INT64",
            exp.DataType.Type.DECIMAL: "NUMERIC",
            exp.DataType.Type.FLOAT: "FLOAT64",
            exp.DataType.Type.DOUBLE: "FLOAT64",
            exp.DataType.Type.BOOLEAN: "BOOL",
            exp.DataType.Type.TEXT: "STRING",
            exp.DataType.Type.VARCHAR: "STRING",
            exp.DataType.Type.NVARCHAR: "STRING",
        }

        def in_unnest_op(self, unnest):
            """Render the UNNEST operand of an IN as the plain SQL of the node."""
            return self.sql(unnest)

        def union_op(self, expression):
            """Return the UNION keyword with an explicit DISTINCT or ALL qualifier."""
            return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

        def except_op(self, expression):
            """Return the EXCEPT keyword; a non-DISTINCT EXCEPT is flagged as unsupported."""
            if not expression.args.get("distinct", False):
                self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
            return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"

        def intersect_op(self, expression):
            """Return the INTERSECT keyword; a non-DISTINCT INTERSECT is flagged as unsupported."""
            if not expression.args.get("distinct", False):
                self.unsupported(
                    "INTERSECT without DISTINCT is not supported in BigQuery"
                )
            return (
                f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
            )
class _Dialect(type):
    """Metaclass that registers each dialect class and fills in its derived attributes."""

    # Registry shared by every class created through this metaclass.
    classes = {}

    @classmethod
    def __getitem__(cls, key):
        """Return the dialect class registered under ``key`` (KeyError if absent)."""
        return cls.classes[key]

    @classmethod
    def get(cls, key, default=None):
        """Return the dialect class registered under ``key``, or ``default``."""
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)

        # Register under the Dialects enum value when the class name matches a
        # member, otherwise under the lower-cased class name.
        member = Dialects.__members__.get(clsname.upper())
        registry_key = clsname.lower() if member is None else member.value
        cls.classes[registry_key] = klass

        # Time-format lookup structures, in both directions.
        forward_mapping = klass.time_mapping
        klass.time_trie = new_trie(forward_mapping)
        klass.inverse_time_mapping = {
            target: source for source, target in forward_mapping.items()
        }
        klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)

        # Nested Tokenizer/Parser/Generator classes win over the base ones.
        klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
        klass.parser_class = getattr(klass, "Parser", Parser)
        klass.generator_class = getattr(klass, "Generator", Generator)

        klass.tokenizer = klass.tokenizer_class()

        # The first configured quote / identifier delimiter pair becomes the
        # one exposed on the dialect class.
        quote_pairs = list(klass.tokenizer_class.QUOTES.items())
        klass.quote_start, klass.quote_end = quote_pairs[0]
        identifier_pairs = list(klass.tokenizer_class.IDENTIFIERS.items())
        klass.identifier_start, klass.identifier_end = identifier_pairs[0]

        return klass
class Dialect(metaclass=_Dialect):
    """Base dialect: default settings plus parse / generate / transpile entry points."""

    index_offset = 0
    unnest_column_only = False
    alias_post_tablesample = False
    normalize_functions = "upper"
    null_ordering = "nulls_are_small"

    # Format strings are stored already quoted.
    date_format = "'%Y-%m-%d'"
    dateint_format = "'%Y%m%d'"
    time_format = "'%Y-%m-%d %H:%M:%S'"
    time_mapping = {}

    # autofilled
    quote_start = None
    quote_end = None
    identifier_start = None
    identifier_end = None

    time_trie = None
    inverse_time_mapping = None
    inverse_time_trie = None
    tokenizer_class = None
    parser_class = None
    generator_class = None
    tokenizer = None

    @classmethod
    def get_or_raise(cls, dialect):
        """Look up a dialect by key; a falsy key returns ``cls`` itself.

        Raises:
            ValueError: if ``dialect`` is truthy but no dialect is found for it.
        """
        if not dialect:
            return cls
        found = cls.get(dialect)
        if found:
            return found
        raise ValueError(f"Unknown dialect '{dialect}'")

    @classmethod
    def format_time(cls, expression):
        """Translate a time format through this dialect's ``time_mapping``.

        A ``str`` is treated as a quoted format and has its first and last
        characters stripped; a truthy expression whose ``is_string`` is set
        uses its ``this`` value. Anything else is returned unchanged.
        """
        if isinstance(expression, str):
            raw_format = expression[1:-1]  # the time formats are quoted
        elif expression and expression.is_string:
            raw_format = expression.this
        else:
            return expression
        # ``format_time`` here is the module-level helper, not this method.
        converted = format_time(raw_format, cls.time_mapping, cls.time_trie)
        return exp.Literal.string(converted)

    def parse(self, sql, **opts):
        """Tokenize ``sql`` and parse it with a parser configured by ``opts``."""
        parser = self.parser(**opts)
        tokens = self.tokenizer.tokenize(sql)
        return parser.parse(tokens, sql)

    def parse_into(self, expression_type, sql, **opts):
        """Tokenize ``sql`` and parse it into ``expression_type``."""
        parser = self.parser(**opts)
        tokens = self.tokenizer.tokenize(sql)
        return parser.parse_into(expression_type, tokens, sql)

    def generate(self, expression, **opts):
        """Generate SQL for ``expression`` with a generator configured by ``opts``."""
        generator = self.generator(**opts)
        return generator.generate(expression)

    def transpile(self, code, **opts):
        """Parse ``code`` and generate it back; ``opts`` go to the generator only."""
        parsed = self.parse(code)
        return self.generate(parsed, **opts)

    def parser(self, **opts):
        """Build a parser from this dialect's settings; ``opts`` override them."""
        settings = {
            "index_offset": self.index_offset,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.parser_class(**settings)

    def generator(self, **opts):
        """Build a generator from this dialect's settings; ``opts`` override them."""
        settings = {
            "quote_start": self.quote_start,
            "quote_end": self.quote_end,
            "identifier_start": self.identifier_start,
            "identifier_end": self.identifier_end,
            "escape": self.tokenizer_class.ESCAPE,
            "index_offset": self.index_offset,
            # The generator goes the other way round, so it gets the inverse
            # time mapping and trie.
            "time_mapping": self.inverse_time_mapping,
            "time_trie": self.inverse_time_trie,
            "unnest_column_only": self.unnest_column_only,
            "alias_post_tablesample": self.alias_post_tablesample,
            "normalize_functions": self.normalize_functions,
            "null_ordering": self.null_ordering,
        }
        settings.update(opts)
        return self.generator_class(**settings)
def str_position_sql(self, expression):
    """Render a string-position expression with STRPOS.

    Without a ``position`` argument this is a plain ``STRPOS(this, substr)``.
    With one, the search runs on ``SUBSTR(this, position)`` and the result is
    shifted back by ``position - 1``.
    """
    haystack, needle, start = (
        self.sql(expression, key) for key in ("this", "substr", "position")
    )
    if not start:
        return f"STRPOS({haystack}, {needle})"
    # NOTE(review): if STRPOS yields 0 for "not found", this branch would give
    # position - 1 instead of 0 - confirm against the target engines.
    return f"STRPOS(SUBSTR({haystack}, {start}), {needle}) + {start} - 1"
f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))" + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"CAST({this} AS DATE) + INTERVAL {e} {unit}" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"{this} + INTERVAL {e} {unit}" + + +def _array_sort_sql(self, expression): + if expression.expression: + self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") + return f"ARRAY_SORT({self.sql(expression, 'this')})" + + +def _sort_array_sql(self, expression): + this = self.sql(expression, "this") + if expression.args.get("asc") == exp.FALSE: + return f"ARRAY_REVERSE_SORT({this})" + return f"ARRAY_SORT({this})" + + +def _sort_array_reverse(args): + return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE) + + +def _struct_pack_sql(self, expression): + args = [ + self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) + for e in expression.expressions + ] + return f"STRUCT_PACK({', '.join(args)})" + + +class DuckDB(Dialect): + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + ":=": TokenType.EQ, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "ARRAY_LENGTH": exp.ArraySize.from_arg_list, + "ARRAY_SORT": exp.SortArray.from_arg_list, + "ARRAY_REVERSE_SORT": 
_sort_array_reverse, + "EPOCH": exp.TimeToUnix.from_arg_list, + "EPOCH_MS": lambda args: exp.UnixToTime( + this=exp.Div( + this=list_get(args, 0), + expression=exp.Literal.number(1000), + ) + ), + "LIST_SORT": exp.SortArray.from_arg_list, + "LIST_REVERSE_SORT": _sort_array_reverse, + "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, + "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), + "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), + "STR_SPLIT": exp.Split.from_arg_list, + "STRING_SPLIT": exp.Split.from_arg_list, + "STRING_TO_ARRAY": exp.Split.from_arg_list, + "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRUCT_PACK": exp.Struct.from_arg_list, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "UNNEST": exp.Explode.from_arg_list, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})", + exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.ArraySort: _array_sort_sql, + exp.ArraySum: rename_func("LIST_SUM"), + exp.DateAdd: _date_add, + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", + exp.Explode: rename_func("UNNEST"), + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), + 
exp.SafeDivide: no_safe_divide_sql, + exp.Split: rename_func("STR_SPLIT"), + exp.SortArray: _sort_array_sql, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: _struct_pack_sql, + exp.TableSample: no_tablesample_sql, + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", + exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("EPOCH"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)", + } + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py new file mode 100644 index 0000000..e3f3f39 --- /dev/null +++ b/sqlglot/dialects/hive.py @@ -0,0 +1,304 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + approx_count_distinct_sql, + format_time_lambda, + if_sql, + no_ilike_sql, + no_recursive_cte_sql, + no_safe_divide_sql, + no_trycast_sql, + rename_func, + struct_extract_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer + + +def _parse_map(args): + keys = [] + values = [] + for 
i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return HiveMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + +def _map_sql(self, expression): + keys = expression.args["keys"] + values = expression.args["values"] + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map use SparkSQL instead.") + return f"MAP({self.sql(keys)}, {self.sql(values)})" + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + return f"MAP({csv(*args)})" + + +def _array_sort(self, expression): + if expression.expression: + self.unsupported("Hive SORT_ARRAY does not support a comparator") + return f"SORT_ARRAY({self.sql(expression, 'this')})" + + +def _property_sql(self, expression): + key = expression.name + value = self.sql(expression, "value") + return f"'{key}' = {value}" + + +def _str_to_unix(self, expression): + return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS DATE)" + + +def _str_to_time(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS TIMESTAMP)" + + +def _time_format(self, expression): + time_format = self.format_time(expression) + if time_format == Hive.time_format: + return None + return time_format + + +def _time_to_str(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + return 
def _unnest_to_explode_sql(self, expression):
    """Render a join, turning a join against UNNEST into one Lateral per unnested expression.

    Joins whose target is not an ``exp.Unnest`` go through ``self.join_sql``.
    For an UNNEST, each unnested expression is paired with a column of the
    alias and wrapped in ``exp.Explode`` (or ``exp.Posexplode`` when the
    ``ordinality`` arg is set). The pairing uses ``zip``, so an UNNEST
    without an alias pairs with an empty list and renders as an empty string.
    """
    unnest = expression.this
    if not isinstance(unnest, exp.Unnest):
        return self.join_sql(expression)

    alias = unnest.args.get("alias")
    if unnest.args.get("ordinality"):
        udtf = exp.Posexplode
    else:
        udtf = exp.Explode

    rendered = []
    for unnested, column in zip(unnest.expressions, alias.columns if alias else []):
        lateral = exp.Lateral(
            this=udtf(this=unnested),
            alias=exp.TableAlias(this=alias.this, columns=[column]),
        )
        rendered.append(self.sql(lateral))
    return "".join(rendered)
"DATE_ADD": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=list_get(args, 1), + unit=exp.Literal.string("DAY"), + ), + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=list_get(args, 0)), + expression=exp.TsOrDsToDate(this=list_get(args, 1)), + ), + "DATE_SUB": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=exp.Mul( + this=list_get(args, 1), + expression=exp.Literal.number(-1), + ), + unit=exp.Literal.string("DAY"), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), + "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), + "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, + "LOCATE": lambda args: exp.StrPosition( + this=list_get(args, 1), + substr=list_get(args, 0), + position=list_get(args, 2), + ), + "LOG": ( + lambda args: exp.Log.from_arg_list(args) + if len(args) > 1 + else exp.Ln.from_arg_list(args) + ), + "MAP": _parse_map, + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), + "PERCENTILE": exp.Quantile.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, + "SIZE": exp.ArraySize.from_arg_list, + "SPLIT": exp.RegexpSplit.from_arg_list, + "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), + "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), + } + + class Generator(Generator): + ROOT_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.LocationProperty, + exp.TableFormatProperty, + ] + WITH_PROPERTIES = [exp.AnonymousProperty] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.AnonymousProperty: _property_sql, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArrayAgg: 
rename_func("COLLECT_LIST"), + exp.ArraySize: rename_func("SIZE"), + exp.ArraySort: _array_sort, + exp.With: no_recursive_cte_sql, + exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", + exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", + exp.If: if_sql, + exp.Index: _index_sql, + exp.ILike: no_ilike_sql, + exp.Join: _unnest_to_explode_sql, + exp.JSONExtract: rename_func("GET_JSON_OBJECT"), + exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), + exp.Map: _map_sql, + HiveMap: _map_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", + exp.Quantile: rename_func("PERCENTILE"), + exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), + exp.RegexpSplit: rename_func("SPLIT"), + exp.SafeDivide: no_safe_divide_sql, + exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}", + exp.SetAgg: rename_func("COLLECT_SET"), + exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", + exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", + exp.StrToDate: _str_to_date, + exp.StrToTime: _str_to_time, + exp.StrToUnix: _str_to_unix, + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: _time_to_str, + exp.TimeToUnix: 
rename_func("UNIX_TIMESTAMP"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsToDate: _to_date_sql, + exp.TryCast: no_trycast_sql, + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), + } + + def with_properties(self, properties): + return self.properties( + properties, + prefix="TBLPROPERTIES", + ) + + def datatype_sql(self, expression): + if ( + expression.this + in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) + and not expression.expressions + ): + expression = exp.DataType.build("text") + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py new file mode 100644 index 0000000..93800a6 --- /dev/null +++ b/sqlglot/dialects/mysql.py @@ -0,0 +1,163 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + no_ilike_sql, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_trunc_sql(self, expression): + unit = expression.text("unit").lower() + + this = self.sql(expression.this) + + if unit == "day": + return f"DATE({this})" + + if unit == "week": + concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + date_format = "%Y %u %w" + elif unit == "month": + concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + date_format = "%Y %c %e" + elif unit == "quarter": + concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + date_format = "%Y %c %e" + elif unit == "year": + concat = f"CONCAT(YEAR({this}), ' 1 1')" + date_format = "%Y %c %e" + 
else: + self.unsupported("Unexpected interval unit: {unit}") + return f"DATE({this})" + + return f"STR_TO_DATE({concat}, '{date_format}')" + + +def _str_to_date(args): + date_format = MySQL.format_time(list_get(args, 1)) + return exp.StrToDate(this=list_get(args, 0), format=date_format) + + +def _str_to_date_sql(self, expression): + date_format = self.format_time(expression) + return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" + + +def _date_add(expression_class): + def func(args): + interval = list_get(args, 1) + return expression_class( + this=list_get(args, 0), + expression=interval.this, + unit=exp.Literal.string(interval.text("unit").lower()), + ) + + return func + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL {expression} {unit})" + + return func + + +class MySQL(Dialect): + # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions + time_mapping = { + "%M": "%B", + "%c": "%-m", + "%e": "%-d", + "%h": "%I", + "%i": "%M", + "%s": "%S", + "%S": "%S", + "%u": "%W", + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + COMMENTS = ["--", "#", ("/*", "*/")] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "_ARMSCII8": TokenType.INTRODUCER, + "_ASCII": TokenType.INTRODUCER, + "_BIG5": TokenType.INTRODUCER, + "_BINARY": TokenType.INTRODUCER, + "_CP1250": TokenType.INTRODUCER, + "_CP1251": TokenType.INTRODUCER, + "_CP1256": TokenType.INTRODUCER, + "_CP1257": TokenType.INTRODUCER, + "_CP850": TokenType.INTRODUCER, + "_CP852": TokenType.INTRODUCER, + "_CP866": TokenType.INTRODUCER, + "_CP932": TokenType.INTRODUCER, + "_DEC8": TokenType.INTRODUCER, + "_EUCJPMS": TokenType.INTRODUCER, + "_EUCKR": TokenType.INTRODUCER, + "_GB18030": TokenType.INTRODUCER, + "_GB2312": TokenType.INTRODUCER, + "_GBK": TokenType.INTRODUCER, + 
"_GEOSTD8": TokenType.INTRODUCER, + "_GREEK": TokenType.INTRODUCER, + "_HEBREW": TokenType.INTRODUCER, + "_HP8": TokenType.INTRODUCER, + "_KEYBCS2": TokenType.INTRODUCER, + "_KOI8R": TokenType.INTRODUCER, + "_KOI8U": TokenType.INTRODUCER, + "_LATIN1": TokenType.INTRODUCER, + "_LATIN2": TokenType.INTRODUCER, + "_LATIN5": TokenType.INTRODUCER, + "_LATIN7": TokenType.INTRODUCER, + "_MACCE": TokenType.INTRODUCER, + "_MACROMAN": TokenType.INTRODUCER, + "_SJIS": TokenType.INTRODUCER, + "_SWE7": TokenType.INTRODUCER, + "_TIS620": TokenType.INTRODUCER, + "_UCS2": TokenType.INTRODUCER, + "_UJIS": TokenType.INTRODUCER, + "_UTF8": TokenType.INTRODUCER, + "_UTF16": TokenType.INTRODUCER, + "_UTF16LE": TokenType.INTRODUCER, + "_UTF32": TokenType.INTRODUCER, + "_UTF8MB3": TokenType.INTRODUCER, + "_UTF8MB4": TokenType.INTRODUCER, + } + + class Parser(Parser): + STRICT_CAST = False + + FUNCTIONS = { + **Parser.FUNCTIONS, + "DATE_ADD": _date_add(exp.DateAdd), + "DATE_SUB": _date_add(exp.DateSub), + "STR_TO_DATE": _str_to_date, + } + + class Generator(Generator): + NULL_ORDERING_SUPPORTED = False + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.ILike: no_ilike_sql, + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateSub: _date_add_sql("SUB"), + exp.DateTrunc: _date_trunc_sql, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_date_sql, + } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py new file mode 100644 index 0000000..9c8b6f2 --- /dev/null +++ b/sqlglot/dialects/oracle.py @@ -0,0 +1,63 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import Dialect, no_ilike_sql +from sqlglot.generator import Generator +from sqlglot.helper import csv +from sqlglot.tokens import Tokenizer, TokenType + + +def _limit_sql(self, expression): + return 
self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) + + +class Oracle(Dialect): + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "NUMBER", + exp.DataType.Type.SMALLINT: "NUMBER", + exp.DataType.Type.INT: "NUMBER", + exp.DataType.Type.BIGINT: "NUMBER", + exp.DataType.Type.DECIMAL: "NUMBER", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.VARCHAR: "VARCHAR2", + exp.DataType.Type.NVARCHAR: "NVARCHAR2", + exp.DataType.Type.TEXT: "CLOB", + exp.DataType.Type.BINARY: "BLOB", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ILike: no_ilike_sql, + exp.Limit: _limit_sql, + } + + def query_modifiers(self, expression, *sqls): + return csv( + *sqls, + *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins", [])], + self.sql(expression, "where"), + self.sql(expression, "group"), + self.sql(expression, "having"), + self.sql(expression, "qualify"), + self.sql(expression, "window"), + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + self.sql(expression, "order"), + self.sql(expression, "offset"), # offset before limit in oracle + self.sql(expression, "limit"), + sep="", + ) + + def offset_sql(self, expression): + return f"{super().offset_sql(expression)} ROWS" + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "TOP": TokenType.TOP, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py new file mode 100644 index 0000000..61dff86 --- /dev/null +++ b/sqlglot/dialects/postgres.py @@ -0,0 +1,109 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + format_time_lambda, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) 
+from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_add_sql(kind): + def func(self, expression): + from sqlglot.optimizer.simplify import simplify + + this = self.sql(expression, "this") + unit = self.sql(expression, "unit") + expression = simplify(expression.args["expression"]) + + if not isinstance(expression, exp.Literal): + self.unsupported("Cannot add non literal") + + expression = expression.copy() + expression.args["is_string"] = True + expression = self.sql(expression) + return f"{this} {kind} INTERVAL {expression} {unit}" + + return func + + +class Postgres(Dialect): + null_ordering = "nulls_are_large" + time_format = "'YYYY-MM-DD HH24:MI:SS'" + time_mapping = { + "AM": "%p", # AM or PM + "D": "%w", # 1-based day of week + "DD": "%d", # day of month + "DDD": "%j", # zero padded day of year + "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres + "FMDDD": "%-j", # day of year + "FMHH12": "%-I", # 9 + "FMHH24": "%-H", # 9 + "FMMI": "%-M", # Minute + "FMMM": "%-m", # 1 + "FMSS": "%-S", # Second + "HH12": "%I", # 09 + "HH24": "%H", # 09 + "MI": "%M", # zero padded minute + "MM": "%m", # 01 + "OF": "%z", # utc offset + "SS": "%S", # zero padded second + "TMDay": "%A", # TM is locale dependent + "TMDy": "%a", + "TMMon": "%b", # Sep + "TMMonth": "%B", # September + "TZ": "%Z", # uppercase timezone name + "US": "%f", # zero padded microsecond + "WW": "%U", # 1-based week of year + "YY": "%y", # 15 + "YYYY": "%Y", # 2015 + } + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "SERIAL": TokenType.AUTO_INCREMENT, + "UUID": TokenType.UUID, + } + + class Parser(Parser): + STRICT_CAST = False + FUNCTIONS = { + **Parser.FUNCTIONS, + "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), + "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + 
exp.DataType.Type.TINYINT: "SMALLINT", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.BINARY: "BYTEA", + } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "SERIAL", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", + exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.DateAdd: _date_add_sql("+"), + exp.DateSub: _date_add_sql("-"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py new file mode 100644 index 0000000..ca913e4 --- /dev/null +++ b/sqlglot/dialects/presto.py @@ -0,0 +1,216 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + if_sql, + no_ilike_sql, + no_safe_divide_sql, + rename_func, + str_position_sql, + struct_extract_sql, +) +from sqlglot.dialects.mysql import MySQL +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _approx_distinct_sql(self, expression): + accuracy = expression.args.get("accuracy") + accuracy = ", " + self.sql(accuracy) if accuracy else "" + return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" + + +def _concat_ws_sql(self, expression): + sep, *args = expression.expressions + sep = self.sql(sep) + if len(args) > 1: + return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], 
{sep})" + return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" + + +def _datatype_sql(self, expression): + sql = self.datatype_sql(expression) + if expression.this == exp.DataType.Type.TIMESTAMPTZ: + sql = f"{sql} WITH TIME ZONE" + return sql + + +def _date_parse_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')" + + +def _explode_to_unnest_sql(self, expression): + if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + return self.sql( + exp.Join( + this=exp.Unnest( + expressions=[expression.this.this], + alias=expression.args.get("alias"), + ordinality=isinstance(expression.this, exp.Posexplode), + ), + kind="cross", + ) + ) + return self.lateral_sql(expression) + + +def _initcap_sql(self, expression): + regex = r"(\w)(\w*)" + return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" + + +def _no_sort_array(self, expression): + if expression.args.get("asc") == exp.FALSE: + comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" + else: + comparator = None + args = csv(self.sql(expression, "this"), comparator) + return f"ARRAY_SORT({args})" + + +def _schema_sql(self, expression): + if isinstance(expression.parent, exp.Property): + columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + return f"ARRAY[{columns}]" + + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) + + return self.schema_sql(expression) + + +def _quantile_sql(self, expression): + self.unsupported("Presto does not support exact quantiles") + return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" + + +def _str_to_time_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = 
self.format_time(expression) + if time_format and time_format not in (Presto.time_format, Presto.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return ( + f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" + ) + + +def _ts_or_ds_add_sql(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit") or "'day'" + return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" + + +class Presto(Dialect): + index_offset = 1 + null_ordering = "nulls_are_last" + time_format = "'%Y-%m-%d %H:%i:%S'" + time_mapping = MySQL.time_mapping + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "ROW": TokenType.STRUCT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, + "CARDINALITY": exp.ArraySize.from_arg_list, + "CONTAINS": exp.ArrayContains.from_arg_list, + "DATE_ADD": lambda args: exp.DateAdd( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), + "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, + "STRPOS": exp.StrPosition.from_arg_list, + "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + } + + class Generator(Generator): + + STRUCT_DELIMITER = ("(", ")") + + WITH_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.AnonymousProperty, + exp.TableFormatProperty, + ] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.TIMESTAMPTZ: 
"TIMESTAMP", + exp.DataType.Type.STRUCT: "ROW", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ApproxDistinct: _approx_distinct_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.ArrayContains: rename_func("CONTAINS"), + exp.ArraySize: rename_func("CARDINALITY"), + exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})", + exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.ConcatWs: _concat_ws_sql, + exp.DataType: _datatype_sql, + exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.FileFormatProperty: lambda self, e: self.property_sql(e), + exp.If: if_sql, + exp.ILike: no_ilike_sql, + exp.Initcap: _initcap_sql, + exp.Lateral: _explode_to_unnest_sql, + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", + exp.Quantile: 
_quantile_sql, + exp.SafeDivide: no_safe_divide_sql, + exp.Schema: _schema_sql, + exp.SortArray: _no_sort_array, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", + exp.TimeStrToDate: _date_parse_sql, + exp.TimeStrToTime: _date_parse_sql, + exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", + exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("TO_UNIXTIME"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py new file mode 100644 index 0000000..148dfb5 --- /dev/null +++ b/sqlglot/dialects/snowflake.py @@ -0,0 +1,145 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func +from sqlglot.expressions import Literal +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _check_int(s): + if s[0] in ("-", "+"): + return s[1:].isdigit() + return s.isdigit() + + +# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html +def _snowflake_to_timestamp(args): + 
if len(args) == 2: + first_arg, second_arg = args + if second_arg.is_string: + # case: <string_expr> [ , <format> ] + return format_time_lambda(exp.StrToTime, "snowflake")(args) + + # case: <numeric_expr> [ , <scale> ] + if second_arg.name not in ["0", "3", "9"]: + raise ValueError( + f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" + ) + + if second_arg.name == "0": + timescale = exp.UnixToTime.SECONDS + elif second_arg.name == "3": + timescale = exp.UnixToTime.MILLIS + elif second_arg.name == "9": + timescale = exp.UnixToTime.MICROS + + return exp.UnixToTime(this=first_arg, scale=timescale) + + first_arg = list_get(args, 0) + if not isinstance(first_arg, Literal): + # case: <variant_expr> + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + if first_arg.is_string: + if _check_int(first_arg.this): + # case: <integer> + return exp.UnixToTime.from_arg_list(args) + + # case: <date_expr> + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + # case: <numeric_expr> + return exp.UnixToTime.from_arg_list(args) + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in [None, exp.UnixToTime.SECONDS]: + return f"TO_TIMESTAMP({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TO_TIMESTAMP({timestamp}, 3)" + if scale == exp.UnixToTime.MICROS: + return f"TO_TIMESTAMP({timestamp}, 9)" + + raise ValueError("Improper scale for timestamp") + + +class Snowflake(Dialect): + null_ordering = "nulls_are_large" + time_format = "'yyyy-mm-dd hh24:mi:ss'" + + time_mapping = { + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "mmmm": "%B", + "MON": "%b", + "mon": "%b", + "MM": "%m", + "mm": "%m", + "DD": "%d", + "dd": "%d", + "d": "%-d", + "DY": "%w", + "dy": "%w", + "HH24": "%H", + "hh24": "%H", + "HH12": "%I", + "hh12": "%I", + "MI": "%M", + "mi": "%M", + "SS": "%S", + "ss": "%S", + 
"FF": "%f", + "ff": "%f", + "FF6": "%f", + "ff6": "%f", + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "IFF": exp.If.from_arg_list, + "TO_TIMESTAMP": _snowflake_to_timestamp, + } + + COLUMN_OPERATORS = { + **Parser.COLUMN_OPERATORS, + TokenType.COLON: lambda self, this, path: self.expression( + exp.Bracket, + this=this, + expressions=[path], + ), + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", "$$"] + ESCAPE = "\\" + KEYWORDS = { + **Tokenizer.KEYWORDS, + "QUALIFY": TokenType.QUALIFY, + "DOUBLE PRECISION": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: rename_func("IFF"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + } + + def except_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("EXCEPT with All is not supported in Snowflake") + return super().except_op(expression) + + def intersect_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("INTERSECT with All is not supported in Snowflake") + return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py new file mode 100644 index 0000000..89c7ed5 --- /dev/null +++ b/sqlglot/dialects/spark.py @@ -0,0 +1,106 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import no_ilike_sql, rename_func +from sqlglot.dialects.hive import Hive, HiveMap +from sqlglot.helper import list_get + + +def _create_sql(self, e): + kind = e.args.get("kind") + temporary = e.args.get("temporary") + + if kind.upper() == "TABLE" and temporary is True: + return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" + return self.create_sql(e) + + +def _map_sql(self, expression): + keys = self.sql(expression.args["keys"]) + values = self.sql(expression.args["values"]) + return 
f"MAP_FROM_ARRAYS({keys}, {values})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Hive.date_format: + return f"TO_DATE({this})" + return f"TO_DATE({this}, {time_format})" + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale is None: + return f"FROM_UNIXTIME({timestamp})" + if scale == exp.UnixToTime.SECONDS: + return f"TIMESTAMP_SECONDS({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TIMESTAMP_MILLIS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"TIMESTAMP_MICROS({timestamp})" + + raise ValueError("Improper scale for timestamp") + + +class Spark(Hive): + class Parser(Hive.Parser): + FUNCTIONS = { + **Hive.Parser.FUNCTIONS, + "MAP_FROM_ARRAYS": exp.Map.from_arg_list, + "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "LEFT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Literal.number(1), + length=list_get(args, 1), + ), + "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "RIGHT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Sub( + this=exp.Length(this=list_get(args, 0)), + expression=exp.Add( + this=list_get(args, 1), expression=exp.Literal.number(1) + ), + ), + length=list_get(args, 1), + ), + } + + class Generator(Hive.Generator): + TYPE_MAPPING = { + **Hive.Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "BYTE", + exp.DataType.Type.SMALLINT: "SHORT", + exp.DataType.Type.BIGINT: "LONG", + } + + TRANSFORMS = { + **{ + k: v + for k, v in Hive.Generator.TRANSFORMS.items() + if k not in {exp.ArraySort} + }, + exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.BitwiseLeftShift: 
rename_func("SHIFTLEFT"), + exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", + exp.ILike: no_ilike_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.Create: _create_sql, + exp.Map: _map_sql, + exp.Reduce: rename_func("AGGREGATE"), + exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", + HiveMap: _map_sql, + } + + def bitstring_sql(self, expression): + return f"X'{self.sql(expression, 'this')}'" diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py new file mode 100644 index 0000000..6cf5022 --- /dev/null +++ b/sqlglot/dialects/sqlite.py @@ -0,0 +1,63 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + no_ilike_sql, + no_tablesample_sql, + no_trycast_sql, + rename_func, +) +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +class SQLite(Dialect): + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', ("[", "]"), "`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "AUTOINCREMENT": TokenType.AUTO_INCREMENT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "EDITDIST3": exp.Levenshtein.from_arg_list, + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.BIGINT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "REAL", + exp.DataType.Type.DECIMAL: "REAL", + exp.DataType.Type.CHAR: "TEXT", + exp.DataType.Type.NCHAR: "TEXT", + exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + exp.DataType.Type.BINARY: "BLOB", 
+ } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "AUTOINCREMENT", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ILike: no_ilike_sql, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.Levenshtein: rename_func("EDITDIST3"), + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py new file mode 100644 index 0000000..b9cd584 --- /dev/null +++ b/sqlglot/dialects/starrocks.py @@ -0,0 +1,12 @@ +from sqlglot import exp +from sqlglot.dialects.mysql import MySQL + + +class StarRocks(MySQL): + class Generator(MySQL.Generator): + TYPE_MAPPING = { + **MySQL.Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "DATETIME", + } diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py new file mode 100644 index 0000000..e571749 --- /dev/null +++ b/sqlglot/dialects/tableau.py @@ -0,0 +1,37 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser + + +def _if_sql(self, expression): + return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END" + + +def _coalesce_sql(self, expression): + return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" + + +def _count_sql(self, expression): + this = expression.this + if isinstance(this, exp.Distinct): + return f"COUNTD({self.sql(this, 'this')})" + return f"COUNT({self.sql(expression, 'this')})" + + +class Tableau(Dialect): + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: _if_sql, + exp.Coalesce: _coalesce_sql, + exp.Count: _count_sql, + } + + class 
Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "IFNULL": exp.Coalesce.from_arg_list, + "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))), + } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py new file mode 100644 index 0000000..805106c --- /dev/null +++ b/sqlglot/dialects/trino.py @@ -0,0 +1,10 @@ +from sqlglot import exp +from sqlglot.dialects.presto import Presto + + +class Trino(Presto): + class Generator(Presto.Generator): + TRANSFORMS = { + **Presto.Generator.TRANSFORMS, + exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + } |