Diffstat (limited to 'sqlglot/dialects')
-rw-r--r--  sqlglot/dialects/bigquery.py   |  57
-rw-r--r--  sqlglot/dialects/clickhouse.py |  24
-rw-r--r--  sqlglot/dialects/databricks.py |   4
-rw-r--r--  sqlglot/dialects/dialect.py    |  52
-rw-r--r--  sqlglot/dialects/duckdb.py     |  33
-rw-r--r--  sqlglot/dialects/hive.py       |  57
-rw-r--r--  sqlglot/dialects/mysql.py      | 329
-rw-r--r--  sqlglot/dialects/oracle.py     |  20
-rw-r--r--  sqlglot/dialects/postgres.py   |  25
-rw-r--r--  sqlglot/dialects/presto.py     |  41
-rw-r--r--  sqlglot/dialects/redshift.py   |  13
-rw-r--r--  sqlglot/dialects/snowflake.py  |  46
-rw-r--r--  sqlglot/dialects/spark.py      |  37
-rw-r--r--  sqlglot/dialects/sqlite.py     |  24
-rw-r--r--  sqlglot/dialects/starrocks.py  |   7
-rw-r--r--  sqlglot/dialects/tableau.py    |  14
-rw-r--r--  sqlglot/dialects/trino.py      |   4
-rw-r--r--  sqlglot/dialects/tsql.py       |  54
18 files changed, 596 insertions, 245 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 62d042e..5bbff9d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,21 +1,21 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     inline_array_sql,
     no_ilike_sql,
     rename_func,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _date_add(expression_class):
     def func(args):
-        interval = list_get(args, 1)
+        interval = seq_get(args, 1)
         return expression_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             expression=interval.this,
             unit=interval.args.get("unit"),
         )
@@ -23,6 +23,13 @@ def _date_add(expression_class):
     return func
 
 
+def _date_trunc(args):
+    unit = seq_get(args, 1)
+    if isinstance(unit, exp.Column):
+        unit = exp.Var(this=unit.name)
+    return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
 def _date_add_sql(data_type, kind):
     def func(self, expression):
         this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
     structs = []
     for row in rows:
         aliases = [
-            exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+            exp.alias_(value, column_name)
+            for value, column_name in zip(row, expression.args["alias"].args["columns"])
         ]
         structs.append(exp.Struct(expressions=aliases))
     unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
         "%j": "%-j",
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = [
             (prefix + quote, quote) if prefix else quote
             for quote in ["'", '"', '"""', "'''"]
             for prefix in ["", "r", "R"]
         ]
+        COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         HEX_STRINGS = [("0x", ""), ("0X", "")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
             "CURRENT_TIME": TokenType.CURRENT_TIME,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
             "WINDOW": TokenType.WINDOW,
             "NOT DETERMINISTIC": TokenType.VOLATILE,
         }
+        KEYWORDS.pop("DIV")
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
+            "DATE_TRUNC": _date_trunc,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+            "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
             "TIME_ADD": _date_add(exp.TimeAdd),
             "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "DATETIME_SUB": _date_add(exp.DatetimeSub),
             "TIME_SUB": _date_add(exp.TimeSub),
             "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
-            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
+            "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
+                this=seq_get(args, 1), format=seq_get(args, 0)
+            ),
         }
 
         NO_PAREN_FUNCTIONS = {
-            **Parser.NO_PAREN_FUNCTIONS,
+            **parser.Parser.NO_PAREN_FUNCTIONS,
             TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
             TokenType.CURRENT_TIME: exp.CurrentTime,
         }
 
         NESTED_TYPE_TOKENS = {
-            *Parser.NESTED_TYPE_TOKENS,
+            *parser.Parser.NESTED_TYPE_TOKENS,
             TokenType.TABLE,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
             exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
             exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
             exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
             exp.ILike: no_ilike_sql,
+            exp.IntDiv: rename_func("DIV"),
             exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
             exp.TimeAdd: _date_add_sql("TIME", "ADD"),
             exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
             exp.Values: _derived_table_values_to_unnest,
             exp.ReturnsProperty: _returnsproperty_sql,
             exp.Create: _create_sql,
-            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
+            exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+            if e.name == "IMMUTABLE"
+            else "NOT DETERMINISTIC",
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "INT64",
             exp.DataType.Type.SMALLINT: "INT64",
             exp.DataType.Type.INT: "INT64",
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f446e6d..332b4c1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
 
 
 def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
     normalize_functions = None
     null_ordering = "nulls_are_last"
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
+        COMMENTS = ["--", "#", "#!", ("/*", "*/")]
         IDENTIFIERS = ['"', "`"]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "FINAL": TokenType.FINAL,
             "DATETIME64": TokenType.DATETIME,
             "INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
             "TUPLE": TokenType.STRUCT,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "MAP": parse_var_map,
         }
 
@@ -44,11 +46,11 @@ class ClickHouse(Dialect):
 
             return this
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.NULLABLE: "Nullable",
             exp.DataType.Type.DATETIME: "DateTime64",
             exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.Array: inline_array_sql,
             exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
             exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 9dc3c38..2498c62 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.dialect import parse_date_delta
 from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):
 
     class Generator(Spark.Generator):
         TRANSFORMS = {
-            **Spark.Generator.TRANSFORMS,
+            **Spark.Generator.TRANSFORMS,  # type: ignore
            exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
         }
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 33985a7..3af08bb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
 from enum import Enum
 
 from sqlglot import exp
 from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
 from sqlglot.parser import Parser
 from sqlglot.time import format_time
 from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):
 
 
 class _Dialect(type):
-    classes = {}
+    classes: t.Dict[str, Dialect] = {}
 
     @classmethod
     def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
         klass.generator_class = getattr(klass, "Generator", Generator)
 
         klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
-        klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
-
-        if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+        klass.identifier_start, klass.identifier_end = list(
+            klass.tokenizer_class._IDENTIFIERS.items()
+        )[0]
+
+        if (
+            klass.tokenizer_class._BIT_STRINGS
+            and exp.BitString not in klass.generator_class.TRANSFORMS
+        ):
             bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.BitString
             ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
-        if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._HEX_STRINGS
+            and exp.HexString not in klass.generator_class.TRANSFORMS
+        ):
             hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.HexString
             ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
-        if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+        if (
+            klass.tokenizer_class._BYTE_STRINGS
+            and exp.ByteString not in klass.generator_class.TRANSFORMS
+        ):
             be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
             klass.generator_class.TRANSFORMS[
                 exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
     index_offset = 0
     unnest_column_only = False
     alias_post_tablesample = False
-    normalize_functions = "upper"
+    normalize_functions: t.Optional[str] = "upper"
     null_ordering = "nulls_are_small"
 
     date_format = "'%Y-%m-%d'"
     dateint_format = "'%Y%m%d'"
     time_format = "'%Y-%m-%d %H:%M:%S'"
-    time_mapping = {}
+    time_mapping: t.Dict[str, str] = {}
 
     # autofilled
     quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
             "quote_end": self.quote_end,
             "identifier_start": self.identifier_start,
             "identifier_end": self.identifier_end,
-            "escape": self.tokenizer_class.ESCAPE,
+            "escape": self.tokenizer_class.ESCAPES[0],
             "index_offset": self.index_offset,
             "time_mapping": self.inverse_time_mapping,
             "time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
 
 
 def if_sql(self, expression):
-    expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+    expressions = self.format_args(
+        expression.this, expression.args.get("true"), expression.args.get("false")
+    )
     return f"IF({expressions})"
 
 
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
 
     def _format_time(args):
         return exp_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             format=Dialect[dialect].format_time(
-                list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+                seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
             ),
         )
 
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
             "expressions",
             [e for e in schema.expressions if e not in partitions],
         )
-        prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+        prop.replace(
+            exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+        )
         expression.set("this", schema)
 
     return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
 def parse_date_delta(exp_class, unit_mapping=None):
     def inner_func(args):
         unit_based = len(args) == 3
-        this = list_get(args, 2) if unit_based else list_get(args, 0)
-        expression = list_get(args, 1) if unit_based else list_get(args, 1)
-        unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+        this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+        expression = seq_get(args, 1) if unit_based else seq_get(args, 1)
+        unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
         unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
         return exp_class(this=this, expression=expression, unit=unit)
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f3ff6d3..781edff 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
     rename_func,
     str_position_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
 
 
 def _sort_array_reverse(args):
-    return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+    return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
 
 
 def _struct_pack_sql(self, expression):
-    args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
+    args = [
+        self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+        for e in expression.expressions
+    ]
     return f"STRUCT_PACK({', '.join(args)})"
 
 
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
 
 
 class DuckDB(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             ":=": TokenType.EQ,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
             "ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
             "EPOCH": exp.TimeToUnix.from_arg_list,
             "EPOCH_MS": lambda args: exp.UnixToTime(
                 this=exp.Div(
-                    this=list_get(args, 0),
+                    this=seq_get(args, 0),
                     expression=exp.Literal.number(1000),
                 )
             ),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
             "UNNEST": exp.Explode.from_arg_list,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         STRUCT_DELIMITER = ("(", ")")
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.Array: rename_func("LIST_VALUE"),
             exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
         }
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 03049ff..ed7357c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
     struct_extract_sql,
     var_map_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer
+from sqlglot.helper import seq_get
+from sqlglot.parser import parse_var_map
 
 # (FuncType, Multiplier)
 DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
     unit = expression.text("unit").upper()
     func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
     modified_increment = (
-        int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
+        int(expression.text("expression")) * multiplier
+        if expression.expression.is_number
+        else expression.expression
     )
     modified_increment = exp.Literal.number(modified_increment)
     return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
     dateint_format = "'yyyyMMdd'"
     time_format = "'yyyy-MM-dd HH:mm:ss'"
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         IDENTIFIERS = ["`"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
         ENCODE = "utf-8"
 
         NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
             "BD": "DECIMAL",
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
             "DATE_ADD": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATEDIFF": lambda args: exp.DateDiff(
-                this=exp.TsOrDsToDate(this=list_get(args, 0)),
-                expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+                expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
             ),
             "DATE_SUB": lambda args: exp.TsOrDsAdd(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 expression=exp.Mul(
-                    this=list_get(args, 1),
+                    this=seq_get(args, 1),
                     expression=exp.Literal.number(-1),
                 ),
                 unit=exp.Literal.string("DAY"),
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
-            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+            "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
             "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
             "LOCATE": lambda args: exp.StrPosition(
-                this=list_get(args, 1),
-                substr=list_get(args, 0),
-                position=list_get(args, 2),
+                this=seq_get(args, 1),
+                substr=seq_get(args, 0),
+                position=seq_get(args, 2),
+            ),
+            "LOG": (
+                lambda args: exp.Log.from_arg_list(args)
+                if len(args) > 1
+                else exp.Ln.from_arg_list(args)
             ),
-            "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
             "MAP": parse_var_map,
             "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
             "PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
             "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
+            exp.DataType.Type.VARBINARY: "BINARY",
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.AnonymousProperty: _property_sql,
             exp.ApproxDistinct: approx_count_distinct_sql,
             exp.ArrayAgg: rename_func("COLLECT_LIST"),
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 524390f..e742640 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,4 +1,8 @@
-from sqlglot import exp
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
     no_tablesample_sql,
     no_trycast_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
+
+
+def _show_parser(*args, **kwargs):
+    def _parse(self):
+        return self._parse_show_mysql(*args, **kwargs)
+
+    return _parse
 
 
 def _date_trunc_sql(self, expression):
-    unit = expression.text("unit").lower()
+    unit = expression.name.lower()
 
-    this = self.sql(expression.this)
+    expr = self.sql(expression.expression)
 
     if unit == "day":
-        return f"DATE({this})"
+        return f"DATE({expr})"
 
     if unit == "week":
-        concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
         date_format = "%Y %u %w"
     elif unit == "month":
-        concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
         date_format = "%Y %c %e"
     elif unit == "quarter":
-        concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
+        concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
         date_format = "%Y %c %e"
     elif unit == "year":
-        concat = f"CONCAT(YEAR({this}), ' 1 1')"
+        concat = f"CONCAT(YEAR({expr}), ' 1 1')"
         date_format = "%Y %c %e"
     else:
         self.unsupported("Unexpected interval unit: {unit}")
-        return f"DATE({this})"
+        return f"DATE({expr})"
 
     return f"STR_TO_DATE({concat}, '{date_format}')"
 
 
 def _str_to_date(args):
-    date_format = MySQL.format_time(list_get(args, 1))
-    return exp.StrToDate(this=list_get(args, 0), format=date_format)
+    date_format = MySQL.format_time(seq_get(args, 1))
+    return exp.StrToDate(this=seq_get(args, 0), format=date_format)
 
 
 def _str_to_date_sql(self, expression):
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):
 
 def _date_add(expression_class):
     def func(args):
-        interval = list_get(args, 1)
+        interval = seq_get(args, 1)
         return expression_class(
-            this=list_get(args, 0),
+            this=seq_get(args, 0),
             expression=interval.this,
             unit=exp.Literal.string(interval.text("unit").lower()),
         )
@@ -101,15 +110,16 @@ class MySQL(Dialect):
         "%l": "%-I",
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", '"']
         COMMENTS = ["--", "#", ("/*", "*/")]
         IDENTIFIERS = ["`"]
+        ESCAPES = ["'", "\\"]
         BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "SEPARATOR": TokenType.SEPARATOR,
             "_ARMSCII8": TokenType.INTRODUCER,
             "_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
             "_UTF32": TokenType.INTRODUCER,
             "_UTF8MB3": TokenType.INTRODUCER,
             "_UTF8MB4": TokenType.INTRODUCER,
+            "@@": TokenType.SESSION_PARAMETER,
         }
 
-    class Parser(Parser):
+        COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+
+    class Parser(parser.Parser):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "DATE_ADD": _date_add(exp.DateAdd),
             "DATE_SUB": _date_add(exp.DateSub),
             "STR_TO_DATE": _str_to_date,
         }
 
         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "GROUP_CONCAT": lambda self: self.expression(
                 exp.GroupConcat,
                 this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
         }
 
         PROPERTY_PARSERS = {
-            **Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,
             TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
         }
 
-    class Generator(Generator):
+        STATEMENT_PARSERS = {
+            **parser.Parser.STATEMENT_PARSERS,
+            TokenType.SHOW: lambda self: self._parse_show(),
+            TokenType.SET: lambda self: self._parse_set(),
+        }
+
+        SHOW_PARSERS = {
+            "BINARY LOGS": _show_parser("BINARY LOGS"),
+            "MASTER LOGS": _show_parser("BINARY LOGS"),
+            "BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
+            "CHARACTER SET": _show_parser("CHARACTER SET"),
+            "CHARSET": _show_parser("CHARACTER SET"),
+            "COLLATION": _show_parser("COLLATION"),
+            "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
+            "COLUMNS": _show_parser("COLUMNS", target="FROM"),
+            "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
+            "CREATE EVENT": _show_parser("CREATE EVENT", target=True),
+            "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
+            "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
+            "CREATE TABLE": _show_parser("CREATE TABLE", target=True),
+            "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
+            "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
+            "DATABASES": _show_parser("DATABASES"),
+            "ENGINE": _show_parser("ENGINE", target=True),
+            "STORAGE ENGINES": _show_parser("ENGINES"),
+            "ENGINES": _show_parser("ENGINES"),
+            "ERRORS": _show_parser("ERRORS"),
+            "EVENTS": _show_parser("EVENTS"),
+            "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
+            "FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
+            "GRANTS": _show_parser("GRANTS", target="FOR"),
+            "INDEX": _show_parser("INDEX", target="FROM"),
+            "MASTER STATUS": _show_parser("MASTER STATUS"),
+            "OPEN TABLES": _show_parser("OPEN TABLES"),
+            "PLUGINS": _show_parser("PLUGINS"),
+            "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
+            "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
+            "PRIVILEGES": _show_parser("PRIVILEGES"),
+            "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
+            "PROCESSLIST": _show_parser("PROCESSLIST"),
+            "PROFILE": _show_parser("PROFILE"),
+            "PROFILES": _show_parser("PROFILES"),
+            "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
+            "REPLICAS": _show_parser("REPLICAS"),
+            "SLAVE HOSTS": _show_parser("REPLICAS"),
+            "REPLICA STATUS": _show_parser("REPLICA STATUS"),
+            "SLAVE STATUS": _show_parser("REPLICA STATUS"),
+            "GLOBAL STATUS": _show_parser("STATUS", global_=True),
+            "SESSION STATUS": _show_parser("STATUS"),
+            "STATUS": _show_parser("STATUS"),
+            "TABLE STATUS": _show_parser("TABLE STATUS"),
+            "FULL TABLES": _show_parser("TABLES", full=True),
+            "TABLES": _show_parser("TABLES"),
+            "TRIGGERS": _show_parser("TRIGGERS"),
+            "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
+            "SESSION VARIABLES": _show_parser("VARIABLES"),
+            "VARIABLES": _show_parser("VARIABLES"),
+            "WARNINGS": _show_parser("WARNINGS"),
+        }
+
+        SET_PARSERS = {
+            "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+            "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
+            "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
+            "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+            "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+            "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+            "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+            "NAMES": lambda self: self._parse_set_item_names(),
+        }
+
+        PROFILE_TYPES = {
+            "ALL",
+            "BLOCK IO",
+            "CONTEXT SWITCHES",
+            "CPU",
+            "IPC",
+            "MEMORY",
+            "PAGE FAULTS",
+            "SOURCE",
+            "SWAPS",
+        }
+
+        def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+            if target:
+                if isinstance(target, str):
+                    self._match_text(target)
+                target_id = self._parse_id_var()
+            else:
+                target_id = None
+
+            log = self._parse_string() if self._match_text("IN") else None
+
+            if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+                position = self._parse_number() if self._match_text("FROM") else None
+                db = None
+            else:
+                position = None
+                db = self._parse_id_var() if self._match_text("FROM") else None
+
+            channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+
+            like = self._parse_string() if self._match_text("LIKE") else None
+            where = self._parse_where()
+
+            if this == "PROFILE":
+                types = self._parse_csv(self._parse_show_profile_type)
+                query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+                offset = self._parse_number() if self._match_text("OFFSET") else None
+                limit = self._parse_number() if self._match_text("LIMIT") else None
+            else:
+                types, query = None, None
+                offset, limit = self._parse_oldstyle_limit()
+
+            mutex = True if self._match_text("MUTEX") else None
+            mutex = False if self._match_text("STATUS") else mutex
+
+            return self.expression(
+                exp.Show,
+                this=this,
+                target=target_id,
+                full=full,
+                log=log,
+                position=position,
+                db=db,
+                channel=channel,
+                like=like,
+                where=where,
+                types=types,
+                query=query,
+                offset=offset,
+                limit=limit,
+                mutex=mutex,
+                **{"global": global_},
+            )
+
+        def _parse_show_profile_type(self):
+            for type_ in self.PROFILE_TYPES:
+                if self._match_text(*type_.split(" ")):
+                    return exp.Var(this=type_)
+            return None
+
+        def _parse_oldstyle_limit(self):
+            limit = None
+            offset = None
+            if self._match_text("LIMIT"):
+                parts = self._parse_csv(self._parse_number)
+                if len(parts) == 1:
+                    limit = parts[0]
+                elif len(parts) == 2:
+                    limit = parts[1]
+                    offset = parts[0]
+            return offset, limit
+
+        def _default_parse_set_item(self):
+            return self._parse_set_item_assignment(kind=None)
+
+        def _parse_set_item_assignment(self, kind):
+            left = self._parse_primary() or self._parse_id_var()
+            if not self._match(TokenType.EQ):
+                self.raise_error("Expected =")
+            right = self._parse_statement() or self._parse_id_var()
+
+            this = self.expression(
+                exp.EQ,
+                this=left,
+                expression=right,
+            )
+
+            return self.expression(
+                exp.SetItem,
+                this=this,
+                kind=kind,
+            )
+
+        def _parse_set_item_charset(self, kind):
+            this = self._parse_string() or self._parse_id_var()
+
+            return self.expression(
+                exp.SetItem,
+                this=this,
+                kind=kind,
+            )
+
+        def _parse_set_item_names(self):
+            charset = self._parse_string() or self._parse_id_var()
+            if self._match_text("COLLATE"):
+                collate = self._parse_string() or self._parse_id_var()
+            else:
+                collate = None
+            return self.expression(
+                exp.SetItem,
+                this=charset,
+                collate=collate,
+                kind="NAMES",
+            )
+
+    class Generator(generator.Generator):
         NULL_ORDERING_SUPPORTED = False
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.CurrentDate: no_paren_current_date_sql,
             exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
             exp.ILike: no_ilike_sql,
@@ -199,6 +409,8 @@ class MySQL(Dialect):
             exp.StrToDate: _str_to_date_sql,
             exp.StrToTime: _str_to_date_sql,
             exp.Trim: _trim_sql,
+            exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
+            exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
         }
 
         ROOT_PROPERTIES = {
@@ -209,4 +421,69 @@ class MySQL(Dialect):
             exp.SchemaCommentProperty,
         }
 
-        WITH_PROPERTIES = {}
+        WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
+
+        def show_sql(self, expression):
+            this = f" {expression.name}"
+            full = " FULL" if expression.args.get("full") else ""
+            global_ = " GLOBAL" if expression.args.get("global") else ""
+
+            target = self.sql(expression, "target")
+            target = f" {target}" if target else ""
+            if expression.name in {"COLUMNS", "INDEX"}:
+                target = f" FROM{target}"
+            elif expression.name == "GRANTS":
+                target = f" FOR{target}"
+
+            db = self._prefixed_sql("FROM", expression, "db")
+
+            like = self._prefixed_sql("LIKE", expression, "like")
+            where = self.sql(expression, "where")
+
+            types = self.expressions(expression, key="types")
+            types = f" {types}" if types else types
+            query = self._prefixed_sql("FOR QUERY", expression, "query")
+
+            if expression.name == "PROFILE":
+                offset = self._prefixed_sql("OFFSET", expression, "offset")
+                limit = self._prefixed_sql("LIMIT", expression, "limit")
+            else:
+                offset = ""
+                limit = self._oldstyle_limit_sql(expression)
+
+            log = self._prefixed_sql("IN", expression, "log")
+            position = self._prefixed_sql("FROM", expression, "position")
+
+            channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")
+
+            if expression.name == "ENGINE":
+                mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
+            else:
+                mutex_or_status = ""
+
+            return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
+
+        def _prefixed_sql(self, prefix, expression, arg):
+            sql = self.sql(expression, arg)
+            if not sql:
+                return ""
+            return f" {prefix} {sql}"
+
+        def _oldstyle_limit_sql(self, expression):
+            limit = self.sql(expression, "limit")
+            offset = self.sql(expression, "offset")
+            if limit:
+                limit_offset = f"{offset}, {limit}" if offset else limit
+                return f" LIMIT {limit_offset}"
+            return ""
+
+        def setitem_sql(self, expression):
+            kind = self.sql(expression, "kind")
+            kind = f"{kind} " if kind else ""
+            this = self.sql(expression, "this")
+            collate = self.sql(expression, "collate")
+            collate = f" COLLATE {collate}" if collate else ""
+            return f"{kind}{this}{collate}"
+
+        def set_sql(self, expression):
+            return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 144dba5..3bc1109 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,8 +1,9 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, tokens, transforms
 from sqlglot.dialects.dialect import Dialect, no_ilike_sql
-from sqlglot.generator import Generator
 from sqlglot.helper import csv
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
 
 def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
         "YYYY": "%Y",  # 2015
     }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "NUMBER",
             exp.DataType.Type.SMALLINT: "NUMBER",
             exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
             exp.DataType.Type.NVARCHAR: "NVARCHAR2",
             exp.DataType.Type.TEXT: "CLOB",
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ILike: no_ilike_sql,
             exp.Limit: _limit_sql,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
         def table_sql(self, expression):
             return super().table_sql(expression, sep=" ")
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "TOP": TokenType.TOP,
             "VARCHAR2": TokenType.VARCHAR,
             "NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 459e926..553a73b 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     str_position_sql,
 )
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 from sqlglot.transforms import delegate, preprocess
 
 
@@ -160,12 +160,12 @@ class Postgres(Dialect):
         "YYYY": "%Y",  # 2015
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         BIT_STRINGS = [("b'", "'"), ("B'", "'")]
         HEX_STRINGS = [("x'", "'"), ("X'", "'")]
         BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "ALWAYS": TokenType.ALWAYS,
             "BY DEFAULT": TokenType.BY_DEFAULT,
             "COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
         }
         QUOTES = ["'", "$$"]
         SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         STRICT_CAST = False
 
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "TO_TIMESTAMP": _to_timestamp,
             "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TINYINT: "SMALLINT",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
             exp.DataType.Type.BINARY: "BYTEA",
+            exp.DataType.Type.VARBINARY: "BYTEA",
             exp.DataType.Type.DATETIME: "TIMESTAMP",
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ColumnDef: preprocess(
                 [
                     _auto_increment_to_serial,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index a2d392c..11ea778 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
     Dialect,
     format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
     struct_extract_sql,
 )
 from sqlglot.dialects.mysql import MySQL
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
     index_offset = 1
     null_ordering = "nulls_are_last"
     time_format = "'%Y-%m-%d %H:%i:%S'"
-    time_mapping = MySQL.time_mapping
+    time_mapping = MySQL.time_mapping  # type: ignore
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
             "ROW": TokenType.STRUCT,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
             "CARDINALITY": exp.ArraySize.from_arg_list,
             "CONTAINS": exp.ArrayContains.from_arg_list,
             "DATE_ADD": lambda args: exp.DateAdd(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
             ),
             "DATE_DIFF": lambda args: exp.DateDiff(
-                this=list_get(args, 2),
-                expression=list_get(args, 1),
-                unit=list_get(args, 0),
+                this=seq_get(args, 2),
+                expression=seq_get(args, 1),
+                unit=seq_get(args, 0),
             ),
             "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
             "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
 
         STRUCT_DELIMITER = ("(", ")")
 
@@ -159,7 +158,7 @@ class Presto(Dialect):
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.FLOAT: "REAL",
             exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
-            **transforms.UNALIAS_GROUP,
+            **generator.Generator.TRANSFORMS,
+            **transforms.UNALIAS_GROUP,  # type: ignore
             exp.ApproxDistinct: _approx_distinct_sql,
             exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
             exp.ArrayConcat: rename_func("CONCAT"),
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index e1f7b78..a9b12fb 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.postgres import Postgres
 from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
 class Redshift(Postgres):
     time_format = "'YYYY-MM-DD HH:MI:SS'"
     time_mapping = {
-        **Postgres.time_mapping,
+        **Postgres.time_mapping,  # type: ignore
         "MON": "%b",
         "HH": "%H",
     }
 
     class Tokenizer(Postgres.Tokenizer):
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
 
         KEYWORDS = {
-            **Postgres.Tokenizer.KEYWORDS,
+            **Postgres.Tokenizer.KEYWORDS,  # type: ignore
             "GEOMETRY": TokenType.GEOMETRY,
             "GEOGRAPHY": TokenType.GEOGRAPHY,
             "HLLSKETCH": TokenType.HLLSKETCH,
             "SUPER": TokenType.SUPER,
             "TIME": TokenType.TIMESTAMP,
             "TIMETZ": TokenType.TIMESTAMPTZ,
-            "VARBYTE": TokenType.BINARY,
+            "VARBYTE": TokenType.VARBINARY,
             "SIMILAR TO": TokenType.SIMILAR_TO,
         }
 
     class Generator(Postgres.Generator):
         TYPE_MAPPING = {
-            **Postgres.Generator.TYPE_MAPPING,
+            **Postgres.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.BINARY: "VARBYTE",
+            exp.DataType.Type.VARBINARY: "VARBYTE",
             exp.DataType.Type.INT: "INTEGER",
         }
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 3b97e6d..d1aaded 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
     rename_func,
 )
 from sqlglot.expressions import Literal
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
 
 
 def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
 
         # case: <numeric_expr> [ , <scale> ]
         if second_arg.name not in ["0", "3", "9"]:
-            raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
+            raise ValueError(
+                f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+            )
 
         if second_arg.name == "0":
             timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
 
         return exp.UnixToTime(this=first_arg, scale=timescale)
 
-    first_arg = list_get(args, 0)
+    first_arg = seq_get(args, 0)
     if not isinstance(first_arg, Literal):
         # case: <variant_expr>
         return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
     return exp.UnixToTime.from_arg_list(args)
 
 
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self, expression):
     scale = expression.args.get("scale")
     timestamp = self.sql(expression, "this")
     if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
         "ff6": "%f",
     }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "ARRAYAGG": exp.ArrayAgg.from_arg_list,
             "IFF": exp.If.from_arg_list,
             "TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
         }
 
         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "DATE_PART": _parse_date_part,
         }
 
         FUNC_TOKENS = {
-            *Parser.FUNC_TOKENS,
+            *parser.Parser.FUNC_TOKENS,
             TokenType.RLIKE,
             TokenType.TABLE,
         }
 
         COLUMN_OPERATORS = {
-            **Parser.COLUMN_OPERATORS,
+            **parser.Parser.COLUMN_OPERATORS,  # type: ignore
             TokenType.COLON: lambda self, this, path: self.expression(
                 exp.Bracket,
                 this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
         }
 
         PROPERTY_PARSERS = {
-            **Parser.PROPERTY_PARSERS,
+            **parser.Parser.PROPERTY_PARSERS,
             TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
         }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         QUOTES = ["'", "$$"]
-        ESCAPE = "\\"
+        ESCAPES = ["\\"]
 
         SINGLE_TOKENS = {
-            **Tokenizer.SINGLE_TOKENS,
+            **tokens.Tokenizer.SINGLE_TOKENS,
             "$": TokenType.PARAMETER,
         }
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "QUALIFY": TokenType.QUALIFY,
             "DOUBLE PRECISION": TokenType.DOUBLE,
             "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
             "SAMPLE": TokenType.TABLE_SAMPLE,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         CREATE_TRANSIENT = True
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ArrayConcat: rename_func("ARRAY_CAT"),
             exp.If: rename_func("IFF"),
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
-            exp.UnixToTime: _unix_to_time,
+            exp.UnixToTime: _unix_to_time_sql,
             exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
             exp.Array: inline_array_sql,
             exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
         }
 
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
         }
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 572f411..4e404b8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
 from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
 from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
 
 
 def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
 class Spark(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
-            **Hive.Parser.FUNCTIONS,
+            **Hive.Parser.FUNCTIONS,  # type: ignore
             "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
             "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
             "LEFT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 start=exp.Literal.number(1),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
             ),
             "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
             ),
             "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=list_get(args, 0),
-                expression=list_get(args, 1),
+                this=seq_get(args, 0),
+                expression=seq_get(args, 1),
             ),
             "RIGHT": lambda args: exp.Substring(
-                this=list_get(args, 0),
+                this=seq_get(args, 0),
                 start=exp.Sub(
-                    this=exp.Length(this=list_get(args, 0)),
-                    expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+                    this=exp.Length(this=seq_get(args, 0)),
+                    expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
                 ),
-                length=list_get(args, 1),
+                length=seq_get(args, 1),
             ),
             "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
             "IIF": exp.If.from_arg_list,
         }
 
         FUNCTION_PARSERS = {
-            **Parser.FUNCTION_PARSERS,
+            **parser.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
 
     class Generator(Hive.Generator):
         TYPE_MAPPING = {
-            **Hive.Generator.TYPE_MAPPING,
+            **Hive.Generator.TYPE_MAPPING,  # type: ignore
             exp.DataType.Type.TINYINT: "BYTE",
             exp.DataType.Type.SMALLINT: "SHORT",
             exp.DataType.Type.BIGINT: "LONG",
         }
 
         TRANSFORMS = {
-            **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+            **Hive.Generator.TRANSFORMS,  # type: ignore
             exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
             exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
             exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
             exp.VariancePop: rename_func("VAR_POP"),
             exp.DateFromParts: rename_func("MAKE_DATE"),
         }
+        TRANSFORMS.pop(exp.ArraySort)
+        TRANSFORMS.pop(exp.ILike)
 
         WRAP_DERIVED_VALUES = False
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 62b7617..8c9fb76 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import (
     Dialect,
     arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
     no_trycast_sql,
     rename_func,
 )
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
 
 class SQLite(Dialect):
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]"), "`"]
         HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
-            "VARBINARY": TokenType.BINARY,
+            **tokens.Tokenizer.KEYWORDS,
             "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "EDITDIST3": exp.Levenshtein.from_arg_list,
         }
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "INTEGER",
             exp.DataType.Type.TINYINT: "INTEGER",
             exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
             exp.DataType.Type.VARCHAR: "TEXT",
             exp.DataType.Type.NVARCHAR: "TEXT",
             exp.DataType.Type.BINARY: "BLOB",
+            exp.DataType.Type.VARBINARY: "BLOB",
         }
 
         TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,
             exp.ILike: no_ilike_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 0cba6fe..3519c09 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
 from sqlglot.dialects.mysql import MySQL
 
 
 class StarRocks(MySQL):
-    class Generator(MySQL.Generator):
+    class Generator(MySQL.Generator):  # type: ignore
         TYPE_MAPPING = {
             **MySQL.Generator.TYPE_MAPPING,
             exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
         }
 
         TRANSFORMS = {
-            **MySQL.Generator.TRANSFORMS,
+            **MySQL.Generator.TRANSFORMS,  # type: ignore
             exp.JSONExtractScalar: arrow_json_extract_sql,
             exp.JSONExtract: arrow_json_extract_sql,
             exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
             exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.UnixToTime: rename_func("FROM_UNIXTIME"),
         }
+        TRANSFORMS.pop(exp.DateTrunc)
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 45aa041..63e7275 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
 from sqlglot.dialects.dialect import Dialect
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
 
 
 def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
 
 
 class Tableau(Dialect):
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.If: _if_sql,
             exp.Coalesce: _coalesce_sql,
             exp.Count: _count_sql,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "IFNULL": exp.Coalesce.from_arg_list,
             "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
         }
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 9a6f7fe..c7b34fe 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from sqlglot import exp
 from sqlglot.dialects.presto import Presto
 
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
 class Trino(Presto):
     class Generator(Presto.Generator):
         TRANSFORMS = {
-            **Presto.Generator.TRANSFORMS,
+            **Presto.Generator.TRANSFORMS,  # type: ignore
             exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
         }
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 0f93c75..a233d4b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,15 +1,22 @@
+from __future__ import annotations
+
 import re
 
-from sqlglot import exp
+from sqlglot import exp, generator, parser, tokens
 from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
 from sqlglot.expressions import DataType
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
 from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
 
-FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
+FULL_FORMAT_TIME_MAPPING = {
+    "weekday": "%A",
+    "dw": "%A",
+    "w": "%A",
+    "month": "%B",
+    "mm": "%B",
+    "m": "%B",
+}
 DATE_DELTA_INTERVAL = {
     "year": "year",
     "yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
 def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
     def _format_time(args):
         return exp_class(
-            this=list_get(args, 1),
+            this=seq_get(args, 1),
             format=exp.Literal.string(
                 format_time(
-                    list_get(args, 0).name or (TSQL.time_format if default is True else default),
-                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
+                    seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+                    {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+                    if full_format_mapping
+                    else TSQL.time_mapping,
                 )
             ),
         )
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
 
 
 def parse_format(args):
-    fmt = list_get(args, 1)
+    fmt = seq_get(args, 1)
     number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
     if number_fmt:
-        return exp.NumberToStr(this=list_get(args, 0), format=fmt)
+        return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
     return exp.TimeToStr(
-        this=list_get(args, 0),
+        this=seq_get(args, 0),
         format=exp.Literal.string(
             format_time(fmt.name, TSQL.format_time_mapping)
             if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
         "Y": "%a %Y",
     }
 
-    class Tokenizer(Tokenizer):
+    class Tokenizer(tokens.Tokenizer):
         IDENTIFIERS = ['"', ("[", "]")]
 
         KEYWORDS = {
-            **Tokenizer.KEYWORDS,
+            **tokens.Tokenizer.KEYWORDS,
             "BIT": TokenType.BOOLEAN,
             "REAL": TokenType.FLOAT,
             "NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
             "DATETIME2": TokenType.DATETIME,
             "DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
             "TIME": TokenType.TIMESTAMP,
-            "VARBINARY": TokenType.BINARY,
             "IMAGE": TokenType.IMAGE,
             "MONEY": TokenType.MONEY,
             "SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
             "TOP": TokenType.TOP,
         }
 
-    class Parser(Parser):
+    class Parser(parser.Parser):
         FUNCTIONS = {
-            **Parser.FUNCTIONS,
+            **parser.Parser.FUNCTIONS,
             "CHARINDEX": exp.StrPosition.from_arg_list,
             "ISNULL": exp.Coalesce.from_arg_list,
             "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
             this = self._parse_column()
 
             # Retrieve length of datatype and override to default if not specified
-            if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
+            if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
                 to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
 
             # Check whether a conversion with format is applicable
             if self._match(TokenType.COMMA):
                 format_val = self._parse_number().name
                 if format_val not in TSQL.convert_format_mapping:
-                    raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
+                    raise ValueError(
+                        f"CONVERT function at T-SQL does not support format style {format_val}"
+                    )
                 format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
 
                 # Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
             # Entails a simple cast without any format requirement
             return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
 
-    class Generator(Generator):
+    class Generator(generator.Generator):
         TYPE_MAPPING = {
-            **Generator.TYPE_MAPPING,
+            **generator.Generator.TYPE_MAPPING,
             exp.DataType.Type.BOOLEAN: "BIT",
             exp.DataType.Type.INT: "INTEGER",
             exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
         }
 
         TRANSFORMS = {
-            **Generator.TRANSFORMS,
+            **generator.Generator.TRANSFORMS,  # type: ignore
             exp.DateAdd: generate_date_delta_with_unit_sql,
             exp.DateDiff: generate_date_delta_with_unit_sql,
             exp.CurrentDate: rename_func("GETDATE"),
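
The behavior introduced above can be exercised through sqlglot's public transpile API. A minimal sketch, assuming this revision is installed; the database, table, and column names are placeholders, and the commented outputs are what the new SHOW_PARSERS/show_sql, _parse_oldstyle_limit/_oldstyle_limit_sql, and BigQuery DIV handling should produce, not verified output:

    import sqlglot
    from sqlglot.helper import seq_get  # replaces list_get throughout this commit

    # seq_get is index-safe: an out-of-range lookup yields None instead of raising,
    # which is why parsers like _date_add can probe optional arguments with it.
    assert seq_get(["a", "b"], 5) is None

    # MySQL SHOW statements now parse into exp.Show and regenerate via show_sql,
    # including the old-style "LIMIT offset, count" form.
    sql = "SHOW GLOBAL VARIABLES LIKE 'wait_%' LIMIT 5, 10"
    print(sqlglot.transpile(sql, read="mysql", write="mysql")[0])
    # expected: SHOW GLOBAL VARIABLES LIKE 'wait_%' LIMIT 5, 10

    # BigQuery drops DIV as a keyword (KEYWORDS.pop("DIV")) and instead parses it
    # as exp.IntDiv, which the generator renders back through rename_func("DIV").
    print(sqlglot.transpile("SELECT DIV(x, y)", read="bigquery", write="bigquery")[0])
    # expected: SELECT DIV(x, y)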