diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/dialects/__init__.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 51 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 42 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 66 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 45 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 13 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 18 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 240 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 238 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 43 | ||||
-rw-r--r-- | sqlglot/dialects/starrocks.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/tableau.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 11 |
21 files changed, 536 insertions, 314 deletions
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 191e703..fc34262 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -70,6 +70,7 @@ from sqlglot.dialects.presto import Presto from sqlglot.dialects.redshift import Redshift from sqlglot.dialects.snowflake import Snowflake from sqlglot.dialects.spark import Spark +from sqlglot.dialects.spark2 import Spark2 from sqlglot.dialects.sqlite import SQLite from sqlglot.dialects.starrocks import StarRocks from sqlglot.dialects.tableau import Tableau diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1a88654..9705b35 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -39,18 +39,26 @@ def _date_add_sql( def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: if not isinstance(expression.unnest().parent, exp.From): - expression = t.cast(exp.Values, transforms.remove_precision_parameterized_types(expression)) return self.values_sql(expression) - rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)] - structs = [] - for row in rows: - aliases = [ - exp.alias_(value, column_name) - for value, column_name in zip(row, expression.args["alias"].args["columns"]) - ] - structs.append(exp.Struct(expressions=aliases)) - unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)]) - return self.unnest_sql(unnest_exp) + + alias = expression.args.get("alias") + + structs = [ + exp.Struct( + expressions=[ + exp.alias_(value, column_name) + for value, column_name in zip( + t.expressions, + alias.columns + if alias and alias.columns + else (f"_c{i}" for i in range(len(t.expressions))), + ) + ] + ) + for t in expression.find_all(exp.Tuple) + ] + + return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)])) def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: @@ -128,6 +136,7 @@ class BigQuery(Dialect): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] HEX_STRINGS = [("0x", ""), ("0X", "")] + BYTE_STRINGS = [("b'", "'"), ("B'", "'")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -139,6 +148,7 @@ class BigQuery(Dialect): "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, + "BYTES": TokenType.BINARY, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, } @@ -153,7 +163,7 @@ class BigQuery(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "DATE_TRUNC": lambda args: exp.DateTrunc( - unit=exp.Literal.string(seq_get(args, 1).name), # type: ignore + unit=exp.Literal.string(str(seq_get(args, 1))), this=seq_get(args, 0), ), "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), @@ -206,6 +216,12 @@ class BigQuery(Dialect): "NOT DETERMINISTIC": lambda self: self.expression( exp.StabilityProperty, this=exp.Literal.string("VOLATILE") ), + "OPTIONS": lambda self: self._parse_with_property(), + } + + CONSTRAINT_PARSERS = { + **parser.Parser.CONSTRAINT_PARSERS, # type: ignore + "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()), } class Generator(generator.Generator): @@ -217,11 +233,11 @@ class BigQuery(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.AtTimeZone: lambda self, e: self.func( "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) ), + exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), @@ -234,7 +250,9 @@ class BigQuery(Dialect): exp.IntDiv: rename_func("DIV"), exp.Max: max_or_greatest, exp.Min: min_or_least, - exp.Select: transforms.preprocess([_unqualify_unnest]), + exp.Select: transforms.preprocess( + [_unqualify_unnest, transforms.eliminate_distinct_on] + ), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", exp.TimeAdd: _date_add_sql("TIME", "ADD"), exp.TimeSub: _date_add_sql("TIME", "SUB"), @@ -259,6 +277,7 @@ class BigQuery(Dialect): **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.BINARY: "BYTES", exp.DataType.Type.BOOLEAN: "BOOL", exp.DataType.Type.CHAR: "STRING", exp.DataType.Type.DECIMAL: "NUMERIC", @@ -272,6 +291,7 @@ class BigQuery(Dialect): exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.VARBINARY: "BYTES", exp.DataType.Type.VARCHAR: "STRING", exp.DataType.Type.VARIANT: "ANY TYPE", } @@ -310,3 +330,6 @@ class BigQuery(Dialect): if not expression.args.get("distinct", False): self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery") return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + + def with_properties(self, properties: exp.Properties) -> str: + return self.properties(properties, prefix=self.seg("OPTIONS")) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index e91b0bf..2a49066 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -22,6 +22,8 @@ class ClickHouse(Dialect): class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] + BIT_STRINGS = [("0b", "")] + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -31,10 +33,18 @@ class ClickHouse(Dialect): "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, + "INT8": TokenType.TINYINT, + "UINT8": TokenType.UTINYINT, "INT16": TokenType.SMALLINT, + "UINT16": TokenType.USMALLINT, "INT32": TokenType.INT, + "UINT32": TokenType.UINT, "INT64": TokenType.BIGINT, - "INT8": TokenType.TINYINT, + "UINT64": TokenType.UBIGINT, + "INT128": TokenType.INT128, + "UINT128": TokenType.UINT128, + "INT256": TokenType.INT256, + "UINT256": TokenType.UINT256, "TUPLE": TokenType.STRUCT, } @@ -121,9 +131,17 @@ class ClickHouse(Dialect): exp.DataType.Type.ARRAY: "Array", exp.DataType.Type.STRUCT: "Tuple", exp.DataType.Type.TINYINT: "Int8", + exp.DataType.Type.UTINYINT: "UInt8", exp.DataType.Type.SMALLINT: "Int16", + exp.DataType.Type.USMALLINT: "UInt16", exp.DataType.Type.INT: "Int32", + exp.DataType.Type.UINT: "UInt32", exp.DataType.Type.BIGINT: "Int64", + exp.DataType.Type.UBIGINT: "UInt64", + exp.DataType.Type.INT128: "Int128", + exp.DataType.Type.UINT128: "UInt128", + exp.DataType.Type.INT256: "Int256", + exp.DataType.Type.UINT256: "UInt256", exp.DataType.Type.FLOAT: "Float32", exp.DataType.Type.DOUBLE: "Float64", } diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 138f26c..51112a0 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sqlglot import exp +from sqlglot import exp, transforms from sqlglot.dialects.dialect import parse_date_delta from sqlglot.dialects.spark import Spark from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql @@ -29,13 +29,20 @@ class Databricks(Spark): exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_distinct_on, + transforms.unnest_to_explode, + ] + ), exp.ToChar: lambda self, e: self.function_fallback_sql(e), } - TRANSFORMS.pop(exp.Select) # Remove the ELIMINATE_QUALIFY transformation PARAMETER_TOKEN = "$" class Tokenizer(Spark.Tokenizer): + HEX_STRINGS = [] + SINGLE_TOKENS = { **Spark.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 19c6f73..71269f2 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -28,6 +28,7 @@ class Dialects(str, Enum): REDSHIFT = "redshift" SNOWFLAKE = "snowflake" SPARK = "spark" + SPARK2 = "spark2" SQLITE = "sqlite" STARROCKS = "starrocks" TABLEAU = "tableau" @@ -69,30 +70,17 @@ class _Dialect(type): klass.tokenizer_class._IDENTIFIERS.items() )[0] - if ( - klass.tokenizer_class._BIT_STRINGS - and exp.BitString not in klass.generator_class.TRANSFORMS - ): - bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.BitString - ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" - if ( - klass.tokenizer_class._HEX_STRINGS - and exp.HexString not in klass.generator_class.TRANSFORMS - ): - hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.HexString - ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" - if ( - klass.tokenizer_class._BYTE_STRINGS - and exp.ByteString not in klass.generator_class.TRANSFORMS - ): - be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.ByteString - ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}" + klass.bit_start, klass.bit_end = seq_get( + list(klass.tokenizer_class._BIT_STRINGS.items()), 0 + ) or (None, None) + + klass.hex_start, klass.hex_end = seq_get( + list(klass.tokenizer_class._HEX_STRINGS.items()), 0 + ) or (None, None) + + klass.byte_start, klass.byte_end = seq_get( + list(klass.tokenizer_class._BYTE_STRINGS.items()), 0 + ) or (None, None) return klass @@ -198,6 +186,12 @@ class Dialect(metaclass=_Dialect): **{ "quote_start": self.quote_start, "quote_end": self.quote_end, + "bit_start": self.bit_start, + "bit_end": self.bit_end, + "hex_start": self.hex_start, + "hex_end": self.hex_end, + "byte_start": self.byte_start, + "byte_end": self.byte_end, "identifier_start": self.identifier_start, "identifier_end": self.identifier_end, "string_escape": self.tokenizer_class.STRING_ESCAPES[0], diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index d7e2d88..7ad555e 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, create_with_partitions_sql, @@ -145,6 +145,7 @@ class Drill(Dialect): exp.StrPosition: str_position_sql, exp.StrToDate: _str_to_date, exp.Pow: rename_func("POW"), + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 9454db6..bce956e 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, @@ -23,52 +25,61 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _ts_or_ds_add(self, expression): +def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" -def _date_add(self, expression): +def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"{this} + {self.sql(exp.Interval(this=expression.expression, unit=unit))}" -def _array_sort_sql(self, expression): +def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") return f"ARRAY_SORT({self.sql(expression, 'this')})" -def _sort_array_sql(self, expression): +def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str: this = self.sql(expression, "this") if expression.args.get("asc") == exp.false(): return f"ARRAY_REVERSE_SORT({this})" return f"ARRAY_SORT({this})" -def _sort_array_reverse(args): +def _sort_array_reverse(args: t.Sequence) -> exp.Expression: return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) -def _struct_sql(self, expression): +def _parse_date_diff(args: t.Sequence) -> exp.Expression: + return exp.DateDiff( + this=seq_get(args, 2), + expression=seq_get(args, 1), + unit=seq_get(args, 0), + ) + + +def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: args = [ f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions ] return f"{{{', '.join(args)}}}" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) -def _regexp_extract_sql(self, expression): +def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str: bad_args = list(filter(expression.args.get, ("position", "occurrence"))) if bad_args: self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") + return self.func( "REGEXP_EXTRACT", expression.args.get("this"), @@ -108,6 +119,8 @@ class DuckDB(Dialect): "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, "ARRAY_REVERSE_SORT": _sort_array_reverse, + "DATEDIFF": _parse_date_diff, + "DATE_DIFF": _parse_date_diff, "EPOCH": exp.TimeToUnix.from_arg_list, "EPOCH_MS": lambda args: exp.UnixToTime( this=exp.Div( @@ -115,18 +128,18 @@ class DuckDB(Dialect): expression=exp.Literal.number(1000), ) ), - "LIST_SORT": exp.SortArray.from_arg_list, "LIST_REVERSE_SORT": _sort_array_reverse, + "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), - "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), - "STR_SPLIT": exp.Split.from_arg_list, "STRING_SPLIT": exp.Split.from_arg_list, - "STRING_TO_ARRAY": exp.Split.from_arg_list, - "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRING_TO_ARRAY": exp.Split.from_arg_list, + "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), "STRUCT_PACK": exp.Struct.from_arg_list, + "STR_SPLIT": exp.Split.from_arg_list, + "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, "UNNEST": exp.Explode.from_arg_list, } @@ -142,10 +155,11 @@ class DuckDB(Dialect): class Generator(generator.Generator): JOIN_HINTS = False TABLE_HINTS = False + LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) if isinstance(seq_get(e.expressions, 0), exp.Select) @@ -154,13 +168,16 @@ class DuckDB(Dialect): exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), exp.CommentColumnConstraint: no_comment_column_constraint_sql, + exp.CurrentDate: lambda self, e: "CURRENT_DATE", + exp.CurrentTime: lambda self, e: "CURRENT_TIME", + exp.CurrentTimestamp: lambda self, e: "CURRENT_TIMESTAMP", exp.DayOfMonth: rename_func("DAYOFMONTH"), exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.DayOfYear: rename_func("DAYOFYEAR"), exp.DataType: _datatype_sql, - exp.DateAdd: _date_add, + exp.DateAdd: _date_add_sql, exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", e.args.get("unit") or exp.Literal.string("day"), e.expression, e.this + "DATE_DIFF", f"'{e.args.get('unit', 'day')}'", e.expression, e.this ), exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", @@ -192,7 +209,7 @@ class DuckDB(Dialect): exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("EPOCH"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: _ts_or_ds_add, + exp.TsOrDsAdd: _ts_or_ds_add_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("duckdb"), exp.UnixToStr: lambda self, e: f"STRFTIME(TO_TIMESTAMP({self.sql(e, 'this')}), {self.format_time(e)})", exp.UnixToTime: rename_func("TO_TIMESTAMP"), @@ -201,7 +218,7 @@ class DuckDB(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.CHAR: "TEXT", exp.DataType.Type.FLOAT: "REAL", @@ -212,17 +229,14 @@ class DuckDB(Dialect): exp.DataType.Type.VARCHAR: "TEXT", } - STAR_MAPPING = { - **generator.Generator.STAR_MAPPING, - "except": "EXCLUDE", - } + STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"} PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, # type: ignore exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "LIMIT" - - def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: - return super().tablesample_sql(expression, seed_prefix="REPEATABLE") + def tablesample_sql( + self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + ) -> str: + return super().tablesample_sql(expression, seed_prefix="REPEATABLE", sep=sep) diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 6746fcf..871a180 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -81,7 +81,20 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"{diff_sql}{multiplier_sql}" -def _array_sort(self: generator.Generator, expression: exp.ArraySort) -> str: +def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: + this = expression.this + + if not this.type: + from sqlglot.optimizer.annotate_types import annotate_types + + annotate_types(this) + + if this.type.is_type(exp.DataType.Type.JSON): + return self.sql(this) + return self.func("TO_JSON", this, expression.args.get("options")) + + +def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" @@ -91,11 +104,11 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix(self: generator.Generator, expression: exp.StrToUnix) -> str: +def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) -def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -103,7 +116,7 @@ def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: return f"CAST({this} AS DATE)" -def _str_to_time(self: generator.Generator, expression: exp.StrToTime) -> str: +def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.time_format, Hive.date_format): @@ -214,6 +227,7 @@ class Hive(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), @@ -251,6 +265,7 @@ class Hive(Dialect): "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), "TO_JSON": exp.JSONFormat.from_arg_list, + "UNBASE64": exp.FromBase64.from_arg_list, "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } @@ -280,16 +295,20 @@ class Hive(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.UNALIAS_GROUP, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Select: transforms.preprocess( - [transforms.eliminate_qualify, transforms.unnest_to_explode] + [ + transforms.eliminate_qualify, + transforms.eliminate_distinct_on, + transforms.unnest_to_explode, + ] ), exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), + exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), exp.ArraySize: rename_func("SIZE"), - exp.ArraySort: _array_sort, + exp.ArraySort: _array_sort_sql, exp.With: no_recursive_cte_sql, exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, @@ -298,12 +317,13 @@ class Hive(Dialect): exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", + exp.FromBase64: rename_func("UNBASE64"), exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), - exp.JSONFormat: rename_func("TO_JSON"), + exp.JSONFormat: _json_format_sql, exp.Map: var_map_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, @@ -318,9 +338,9 @@ class Hive(Dialect): exp.SetAgg: rename_func("COLLECT_SET"), exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", exp.StrPosition: strposition_to_locate_sql, - exp.StrToDate: _str_to_date, - exp.StrToTime: _str_to_time, - exp.StrToUnix: _str_to_unix, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: _str_to_unix_sql, exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), @@ -328,6 +348,7 @@ class Hive(Dialect): exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: _time_to_str, exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.ToBase64: rename_func("BASE64"), exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.TsOrDsToDate: _to_date_sql, diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 666e740..5342624 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -403,6 +403,7 @@ class MySQL(Dialect): exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 9ccd02e..c8af1c6 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -34,6 +34,8 @@ def _parse_xml_table(self) -> exp.XMLTable: class Oracle(Dialect): + alias_post_tablesample = True + # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes time_mapping = { @@ -121,21 +123,23 @@ class Oracle(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.UNALIAS_GROUP, # type: ignore exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), + exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, + exp.IfNull: rename_func("NVL"), + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), exp.Table: lambda self, e: self.table_sql(e, sep=" "), + exp.TableSample: lambda self, e: self.tablesample_sql(e, sep=" "), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", - exp.IfNull: rename_func("NVL"), } PROPERTIES_LOCATION = { @@ -164,14 +168,19 @@ class Oracle(Dialect): return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" class Tokenizer(tokens.Tokenizer): + VAR_SINGLE_TOKENS = {"@"} + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "(+)": TokenType.JOIN_MARKER, + "BINARY_DOUBLE": TokenType.DOUBLE, + "BINARY_FLOAT": TokenType.FLOAT, "COLUMNS": TokenType.COLUMN, "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, "MINUS": TokenType.EXCEPT, "NVARCHAR2": TokenType.NVARCHAR, "RETURNING": TokenType.RETURNING, + "SAMPLE": TokenType.TABLE_SAMPLE, "START": TokenType.BEGIN, "TOP": TokenType.TOP, "VARCHAR2": TokenType.VARCHAR, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index c47ff51..2132778 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -1,6 +1,8 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +import typing as t + +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arrow_json_extract_scalar_sql, @@ -20,7 +22,6 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get from sqlglot.parser import binary_range_parser from sqlglot.tokens import TokenType -from sqlglot.transforms import preprocess, remove_target_from_merge DATE_DIFF_FACTOR = { "MICROSECOND": " * 1000000", @@ -274,8 +275,7 @@ class Postgres(Dialect): TokenType.HASH: exp.BitwiseXor, } - FACTOR = { - **parser.Parser.FACTOR, + EXPONENT = { TokenType.CARET: exp.Pow, } @@ -286,6 +286,12 @@ class Postgres(Dialect): TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } + def _parse_factor(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_exponent, self.FACTOR) + + def _parse_exponent(self) -> t.Optional[exp.Expression]: + return self._parse_tokens(self._parse_unary, self.EXPONENT) + def _parse_date_part(self) -> exp.Expression: part = self._parse_type() self._match(TokenType.COMMA) @@ -316,7 +322,7 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore exp.BitwiseXor: lambda self, e: self.binary(e, "#"), - exp.ColumnDef: preprocess( + exp.ColumnDef: transforms.preprocess( [ _auto_increment_to_serial, _serial_to_generated, @@ -341,7 +347,7 @@ class Postgres(Dialect): exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), - exp.Merge: preprocess([remove_target_from_merge]), + exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 489d439..6133a27 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -130,7 +130,7 @@ def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> s def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> str: start = expression.args["start"] end = expression.args["end"] - step = expression.args.get("step", 1) # Postgres defaults to 1 for generate_series + step = expression.args.get("step") target_type = None @@ -147,7 +147,11 @@ def _sequence_sql(self: generator.Generator, expression: exp.GenerateSeries) -> else: start = exp.Cast(this=start, to=to) - return self.func("SEQUENCE", start, end, step) + sql = self.func("SEQUENCE", start, end, step) + if isinstance(expression.parent, exp.Table): + sql = f"UNNEST({sql})" + + return sql def _ensure_utf8(charset: exp.Literal) -> None: @@ -204,6 +208,7 @@ class Presto(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, + "APPROX_PERCENTILE": _approx_percentile, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, "DATE_ADD": lambda args: exp.DateAdd( @@ -219,23 +224,23 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, + "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, + "FROM_UTF8": lambda args: exp.Decode( + this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") + ), "NOW": exp.CurrentTimestamp.from_arg_list, + "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2), ), "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, - "APPROX_PERCENTILE": _approx_percentile, - "FROM_HEX": exp.Unhex.from_arg_list, "TO_HEX": exp.Hex.from_arg_list, "TO_UTF8": lambda args: exp.Encode( this=seq_get(args, 0), charset=exp.Literal.string("utf-8") ), - "FROM_UTF8": lambda args: exp.Decode( - this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") - ), } FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("TRIM") @@ -264,7 +269,6 @@ class Presto(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", exp.ArrayConcat: rename_func("CONCAT"), @@ -290,6 +294,7 @@ class Presto(Dialect): exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", exp.Encode: _encode_sql, exp.GenerateSeries: _sequence_sql, + exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, @@ -303,7 +308,11 @@ class Presto(Dialect): exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_qualify, transforms.explode_to_unnest] + [ + transforms.eliminate_qualify, + transforms.eliminate_distinct_on, + transforms.explode_to_unnest, + ] ), exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), @@ -327,6 +336,9 @@ class Presto(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", exp.VariancePop: rename_func("VAR_POP"), + exp.WithinGroup: transforms.preprocess( + [transforms.remove_within_group_for_percentiles] + ), } def interval_sql(self, expression: exp.Interval) -> str: diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index a9c4f62..1b7cf31 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -52,6 +52,8 @@ class Redshift(Postgres): return this class Tokenizer(Postgres.Tokenizer): + BIT_STRINGS = [] + HEX_STRINGS = [] STRING_ESCAPES = ["\\"] KEYWORDS = { @@ -90,7 +92,6 @@ class Redshift(Postgres): TRANSFORMS = { **Postgres.Generator.TRANSFORMS, # type: ignore - **transforms.ELIMINATE_DISTINCT_ON, # type: ignore exp.CurrentTimestamp: lambda self, e: "SYSDATE", exp.DateAdd: lambda self, e: self.func( "DATEADD", exp.var(e.text("unit") or "day"), e.expression, e.this @@ -102,6 +103,7 @@ class Redshift(Postgres): exp.DistStyleProperty: lambda self, e: self.naked_property(e), exp.JSONExtract: _json_sql, exp.JSONExtractScalar: _json_sql, + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 0829669..70dcaa9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, date_trunc_to_time, @@ -252,6 +252,7 @@ class Snowflake(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] + HEX_STRINGS = [("x'", "'"), ("X'", "'")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -305,6 +306,7 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StrPosition: lambda self, e: self.func( "POSITION", e.args.get("substr"), e.this, e.args.get("position") diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index a3e4cce..939f2fd 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -2,222 +2,54 @@ from __future__ import annotations import typing as t -from sqlglot import exp, parser -from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql -from sqlglot.dialects.hive import Hive +from sqlglot import exp +from sqlglot.dialects.spark2 import Spark2 from sqlglot.helper import seq_get -def _create_sql(self: Hive.Generator, e: exp.Create) -> str: - kind = e.args["kind"] - properties = e.args.get("properties") +def _parse_datediff(args: t.Sequence) -> exp.Expression: + """ + Although Spark docs don't mention the "unit" argument, Spark3 added support for + it at some point. Databricks also supports this variation (see below). - if kind.upper() == "TABLE" and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ): - return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return create_with_partitions_sql(self, e) + For example, in spark-sql (v3.3.1): + - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4 + - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4 + See also: + - https://docs.databricks.com/sql/language-manual/functions/datediff3.html + - https://docs.databricks.com/sql/language-manual/functions/datediff.html + """ + unit = None + this = seq_get(args, 0) + expression = seq_get(args, 1) -def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: - keys = self.sql(expression.args["keys"]) - values = self.sql(expression.args["values"]) - return f"MAP_FROM_ARRAYS({keys}, {values})" + if len(args) == 3: + unit = this + this = args[2] + return exp.DateDiff( + this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit + ) -def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: - this = self.sql(expression, "this") - time_format = self.format_time(expression) - if time_format == Hive.date_format: - return f"TO_DATE({this})" - return f"TO_DATE({this}, {time_format})" - -def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") - if scale is None: - return f"FROM_UNIXTIME({timestamp})" - if scale == exp.UnixToTime.SECONDS: - return f"TIMESTAMP_SECONDS({timestamp})" - if scale == exp.UnixToTime.MILLIS: - return f"TIMESTAMP_MILLIS({timestamp})" - if scale == exp.UnixToTime.MICROS: - return f"TIMESTAMP_MICROS({timestamp})" - - raise ValueError("Improper scale for timestamp") - - -class Spark(Hive): - class Parser(Hive.Parser): +class Spark(Spark2): + class Parser(Spark2.Parser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, # type: ignore - "MAP_FROM_ARRAYS": exp.Map.from_arg_list, - "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, - "LEFT": lambda args: exp.Substring( - this=seq_get(args, 0), - start=exp.Literal.number(1), - length=seq_get(args, 1), - ), - "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( - this=seq_get(args, 0), - expression=seq_get(args, 1), - ), - "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( - this=seq_get(args, 0), - expression=seq_get(args, 1), - ), - "RIGHT": lambda args: exp.Substring( - this=seq_get(args, 0), - start=exp.Sub( - this=exp.Length(this=seq_get(args, 0)), - expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), - ), - length=seq_get(args, 1), - ), - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, - "BOOLEAN": lambda args: exp.Cast( - this=seq_get(args, 0), to=exp.DataType.build("boolean") - ), - "IIF": exp.If.from_arg_list, - "INT": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("int")), - "AGGREGATE": exp.Reduce.from_arg_list, - "DAYOFWEEK": lambda args: exp.DayOfWeek( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DAYOFMONTH": lambda args: exp.DayOfMonth( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DAYOFYEAR": lambda args: exp.DayOfYear( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "WEEKOFYEAR": lambda args: exp.WeekOfYear( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - ), - "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), - "DATE_TRUNC": lambda args: exp.TimestampTrunc( - this=seq_get(args, 1), - unit=exp.var(seq_get(args, 0)), - ), - "STRING": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("string")), - "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), - "TIMESTAMP": lambda args: exp.Cast( - this=seq_get(args, 0), to=exp.DataType.build("timestamp") - ), - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, # type: ignore - "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), - "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), - "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), - "MERGE": lambda self: self._parse_join_hint("MERGE"), - "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"), - "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"), - "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"), - "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), - } - - def _parse_add_column(self) -> t.Optional[exp.Expression]: - return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() - - def _parse_drop_column(self) -> t.Optional[exp.Expression]: - return self._match_text_seq("DROP", "COLUMNS") and self.expression( - exp.Drop, - this=self._parse_schema(), - kind="COLUMNS", - ) - - def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: - # Spark doesn't add a suffix to the pivot columns when there's a single aggregation - if len(pivot_columns) == 1: - return [""] - - names = [] - for agg in pivot_columns: - if isinstance(agg, exp.Alias): - names.append(agg.alias) - else: - """ - This case corresponds to aggregations without aliases being used as suffixes - (e.g. col_avg(foo)). We need to unquote identifiers because they're going to - be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. - Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). - - Moreover, function names are lowercased in order to mimic Spark's naming scheme. - """ - agg_all_unquoted = agg.transform( - lambda node: exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node - ) - names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) - - return names - - class Generator(Hive.Generator): - TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, # type: ignore - exp.DataType.Type.TINYINT: "BYTE", - exp.DataType.Type.SMALLINT: "SHORT", - exp.DataType.Type.BIGINT: "LONG", - } - - PROPERTIES_LOCATION = { - **Hive.Generator.PROPERTIES_LOCATION, # type: ignore - exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, - exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, - exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, - exp.CollateProperty: exp.Properties.Location.UNSUPPORTED, - } - - TRANSFORMS = { - **Hive.Generator.TRANSFORMS, # type: ignore - exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", - exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", - exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), - exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), - exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", - exp.StrToDate: _str_to_date, - exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", - exp.UnixToTime: _unix_to_time_sql, - exp.Create: _create_sql, - exp.Map: _map_sql, - exp.Reduce: rename_func("AGGREGATE"), - exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", - exp.TimestampTrunc: lambda self, e: self.func( - "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this - ), - exp.Trim: trim_sql, - exp.VariancePop: rename_func("VAR_POP"), - exp.DateFromParts: rename_func("MAKE_DATE"), - exp.LogicalOr: rename_func("BOOL_OR"), - exp.LogicalAnd: rename_func("BOOL_AND"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), - exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + **Spark2.Parser.FUNCTIONS, # type: ignore + "DATEDIFF": _parse_datediff, } - TRANSFORMS.pop(exp.ArraySort) - TRANSFORMS.pop(exp.ILike) - WRAP_DERIVED_VALUES = False - CREATE_FUNCTION_RETURN_AS = False + class Generator(Spark2.Generator): + TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() + TRANSFORMS.pop(exp.DateDiff) - def cast_sql(self, expression: exp.Cast) -> str: - if isinstance(expression.this, exp.Cast) and expression.this.is_type( - exp.DataType.Type.JSON - ): - schema = f"'{self.sql(expression, 'to')}'" - return self.func("FROM_JSON", expression.this.this, schema) - if expression.to.is_type(exp.DataType.Type.JSON): - return self.func("TO_JSON", expression.this) + def datediff_sql(self, expression: exp.DateDiff) -> str: + unit = self.sql(expression, "unit") + end = self.sql(expression, "this") + start = self.sql(expression, "expression") - return super(Spark.Generator, self).cast_sql(expression) + if unit: + return self.func("DATEDIFF", unit, start, end) - class Tokenizer(Hive.Tokenizer): - HEX_STRINGS = [("X'", "'")] + return self.func("DATEDIFF", end, start) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py new file mode 100644 index 0000000..584671f --- /dev/null +++ b/sqlglot/dialects/spark2.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp, parser, transforms +from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func, trim_sql +from sqlglot.dialects.hive import Hive +from sqlglot.helper import seq_get + + +def _create_sql(self: Hive.Generator, e: exp.Create) -> str: + kind = e.args["kind"] + properties = e.args.get("properties") + + if kind.upper() == "TABLE" and any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ): + return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" + return create_with_partitions_sql(self, e) + + +def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: + keys = self.sql(expression.args["keys"]) + values = self.sql(expression.args["values"]) + return f"MAP_FROM_ARRAYS({keys}, {values})" + + +def _parse_as_cast(to_type: str) -> t.Callable[[t.Sequence], exp.Expression]: + return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) + + +def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Hive.date_format: + return f"TO_DATE({this})" + return f"TO_DATE({this}, {time_format})" + + +def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale is None: + return f"CAST(FROM_UNIXTIME({timestamp}) AS TIMESTAMP)" + if scale == exp.UnixToTime.SECONDS: + return f"TIMESTAMP_SECONDS({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TIMESTAMP_MILLIS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"TIMESTAMP_MICROS({timestamp})" + + raise ValueError("Improper scale for timestamp") + + +class Spark2(Hive): + class Parser(Hive.Parser): + FUNCTIONS = { + **Hive.Parser.FUNCTIONS, # type: ignore + "MAP_FROM_ARRAYS": exp.Map.from_arg_list, + "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "LEFT": lambda args: exp.Substring( + this=seq_get(args, 0), + start=exp.Literal.number(1), + length=seq_get(args, 1), + ), + "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( + this=seq_get(args, 0), + expression=seq_get(args, 1), + ), + "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( + this=seq_get(args, 0), + expression=seq_get(args, 1), + ), + "RIGHT": lambda args: exp.Substring( + this=seq_get(args, 0), + start=exp.Sub( + this=exp.Length(this=seq_get(args, 0)), + expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), + ), + length=seq_get(args, 1), + ), + "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, + "IIF": exp.If.from_arg_list, + "AGGREGATE": exp.Reduce.from_arg_list, + "DAYOFWEEK": lambda args: exp.DayOfWeek( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFMONTH": lambda args: exp.DayOfMonth( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DAYOFYEAR": lambda args: exp.DayOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + ), + "DATE": lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build("date")), + "DATE_TRUNC": lambda args: exp.TimestampTrunc( + this=seq_get(args, 1), + unit=exp.var(seq_get(args, 0)), + ), + "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), + "BOOLEAN": _parse_as_cast("boolean"), + "DOUBLE": _parse_as_cast("double"), + "FLOAT": _parse_as_cast("float"), + "INT": _parse_as_cast("int"), + "STRING": _parse_as_cast("string"), + "TIMESTAMP": _parse_as_cast("timestamp"), + } + + FUNCTION_PARSERS = { + **parser.Parser.FUNCTION_PARSERS, # type: ignore + "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), + "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), + "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), + "MERGE": lambda self: self._parse_join_hint("MERGE"), + "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"), + "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"), + "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"), + "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), + } + + def _parse_add_column(self) -> t.Optional[exp.Expression]: + return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() + + def _parse_drop_column(self) -> t.Optional[exp.Expression]: + return self._match_text_seq("DROP", "COLUMNS") and self.expression( + exp.Drop, + this=self._parse_schema(), + kind="COLUMNS", + ) + + def _pivot_column_names(self, pivot_columns: t.List[exp.Expression]) -> t.List[str]: + # Spark doesn't add a suffix to the pivot columns when there's a single aggregation + if len(pivot_columns) == 1: + return [""] + + names = [] + for agg in pivot_columns: + if isinstance(agg, exp.Alias): + names.append(agg.alias) + else: + """ + This case corresponds to aggregations without aliases being used as suffixes + (e.g. col_avg(foo)). We need to unquote identifiers because they're going to + be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. + Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). + + Moreover, function names are lowercased in order to mimic Spark's naming scheme. + """ + agg_all_unquoted = agg.transform( + lambda node: exp.Identifier(this=node.name, quoted=False) + if isinstance(node, exp.Identifier) + else node + ) + names.append(agg_all_unquoted.sql(dialect="spark", normalize_functions="lower")) + + return names + + class Generator(Hive.Generator): + TYPE_MAPPING = { + **Hive.Generator.TYPE_MAPPING, # type: ignore + exp.DataType.Type.TINYINT: "BYTE", + exp.DataType.Type.SMALLINT: "SHORT", + exp.DataType.Type.BIGINT: "LONG", + } + + PROPERTIES_LOCATION = { + **Hive.Generator.PROPERTIES_LOCATION, # type: ignore + exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, + exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, + exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, + exp.CollateProperty: exp.Properties.Location.UNSUPPORTED, + } + + TRANSFORMS = { + **Hive.Generator.TRANSFORMS, # type: ignore + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), + exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", + exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), + exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.Create: _create_sql, + exp.DateFromParts: rename_func("MAKE_DATE"), + exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), + exp.DayOfMonth: rename_func("DAYOFMONTH"), + exp.DayOfWeek: rename_func("DAYOFWEEK"), + exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", + exp.LogicalAnd: rename_func("BOOL_AND"), + exp.LogicalOr: rename_func("BOOL_OR"), + exp.Map: _map_sql, + exp.Pivot: transforms.preprocess([transforms.unqualify_pivot_columns]), + exp.Reduce: rename_func("AGGREGATE"), + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this + ), + exp.Trim: trim_sql, + exp.UnixToTime: _unix_to_time_sql, + exp.VariancePop: rename_func("VAR_POP"), + exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.WithinGroup: transforms.preprocess( + [transforms.remove_within_group_for_percentiles] + ), + } + TRANSFORMS.pop(exp.ArrayJoin) + TRANSFORMS.pop(exp.ArraySort) + TRANSFORMS.pop(exp.ILike) + + WRAP_DERIVED_VALUES = False + CREATE_FUNCTION_RETURN_AS = False + + def cast_sql(self, expression: exp.Cast) -> str: + if isinstance(expression.this, exp.Cast) and expression.this.is_type( + exp.DataType.Type.JSON + ): + schema = f"'{self.sql(expression, 'to')}'" + return self.func("FROM_JSON", expression.this.this, schema) + if expression.to.is_type(exp.DataType.Type.JSON): + return self.func("TO_JSON", expression.this) + + return super(Hive.Generator, self).cast_sql(expression) + + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + return super().columndef_sql( + expression, + sep=": " + if isinstance(expression.parent, exp.DataType) + and expression.parent.is_type(exp.DataType.Type.STRUCT) + else sep, + ) + + class Tokenizer(Hive.Tokenizer): + HEX_STRINGS = [("X'", "'")] diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 4437f82..f2efe32 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -22,6 +22,40 @@ def _date_add_sql(self, expression): return self.func("DATE", expression.this, modifier) +def _transform_create(expression: exp.Expression) -> exp.Expression: + """Move primary key to a column and enforce auto_increment on primary keys.""" + schema = expression.this + + if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema): + defs = {} + primary_key = None + + for e in schema.expressions: + if isinstance(e, exp.ColumnDef): + defs[e.name] = e + elif isinstance(e, exp.PrimaryKey): + primary_key = e + + if primary_key and len(primary_key.expressions) == 1: + column = defs[primary_key.expressions[0].name] + column.append( + "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint()) + ) + schema.expressions.remove(primary_key) + else: + for column in defs.values(): + auto_increment = None + for constraint in column.constraints.copy(): + if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint): + break + if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint): + auto_increment = constraint + if auto_increment: + column.constraints.remove(auto_increment) + + return expression + + class SQLite(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]"), "`"] @@ -65,8 +99,8 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore - **transforms.ELIMINATE_QUALIFY, # type: ignore exp.CountIf: count_if_to_sum, + exp.Create: transforms.preprocess([_transform_create]), exp.CurrentDate: lambda *_: "CURRENT_DATE", exp.CurrentTime: lambda *_: "CURRENT_TIME", exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", @@ -80,14 +114,17 @@ class SQLite(Dialect): exp.Levenshtein: rename_func("EDITDIST3"), exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), + exp.Select: transforms.preprocess( + [transforms.eliminate_distinct_on, transforms.eliminate_qualify] + ), exp.TableSample: no_tablesample_sql, exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), exp.TryCast: no_trycast_sql, } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + k: exp.Properties.Location.UNSUPPORTED + for k, v in generator.Generator.PROPERTIES_LOCATION.items() } LIMIT_FETCH = "LIMIT" diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index ff19dab..895588a 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -34,6 +34,7 @@ class StarRocks(MySQL): exp.JSONExtractScalar: arrow_json_extract_sql, exp.JSONExtract: arrow_json_extract_sql, exp.DateDiff: rename_func("DATEDIFF"), + exp.RegexpLike: rename_func("REGEXP"), exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimestampTrunc: lambda self, e: self.func( "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 792c2b4..51e685b 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sqlglot import exp, generator, parser +from sqlglot import exp, generator, parser, transforms from sqlglot.dialects.dialect import Dialect @@ -29,6 +29,7 @@ class Tableau(Dialect): exp.If: _if_sql, exp.Coalesce: _coalesce_sql, exp.Count: _count_sql, + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 331e105..a79eaeb 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, format_time_lambda, @@ -148,6 +148,7 @@ class Teradata(Dialect): **generator.Generator.TRANSFORMS, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.TimeToStr: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 9cf56e1..03de99c 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -3,7 +3,7 @@ from __future__ import annotations import re import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, max_or_greatest, @@ -259,8 +259,8 @@ class TSQL(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] - QUOTES = ["'", '"'] + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -463,17 +463,18 @@ class TSQL(Dialect): exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), - exp.If: rename_func("IIF"), - exp.NumberToStr: _format_sql, - exp.TimeToStr: _format_sql, exp.GroupConcat: _string_agg_sql, + exp.If: rename_func("IIF"), exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), exp.Min: min_or_least, + exp.NumberToStr: _format_sql, + exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this ), + exp.TimeToStr: _format_sql, } TRANSFORMS.pop(exp.ReturnsProperty) |