diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 86 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 52 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 15 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/doris.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/drill.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 38 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 55 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 32 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 38 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 54 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 78 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 31 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/teradata.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/trino.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 157 |
20 files changed, 537 insertions, 176 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 71977dd..d763ed0 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, format_time_lambda, inline_array_sql, + json_keyvalue_comma_sql, max_or_greatest, min_or_least, no_ilike_sql, @@ -29,8 +30,8 @@ logger = logging.getLogger("sqlglot") def _date_add_sql( data_type: str, kind: str -) -> t.Callable[[generator.Generator, exp.Expression], str]: - def func(self, expression): +) -> t.Callable[[BigQuery.Generator, exp.Expression], str]: + def func(self: BigQuery.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.args.get("unit") unit = exp.var(unit.name.upper() if unit else "DAY") @@ -40,7 +41,7 @@ def _date_add_sql( return func -def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.Values) -> str: +def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str: if not expression.find_ancestor(exp.From, exp.Join): return self.values_sql(expression) @@ -64,7 +65,7 @@ def _derived_table_values_to_unnest(self: generator.Generator, expression: exp.V return self.unnest_sql(exp.Unnest(expressions=[exp.Array(expressions=structs)])) -def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsProperty) -> str: +def _returnsproperty_sql(self: BigQuery.Generator, expression: exp.ReturnsProperty) -> str: this = expression.this if isinstance(this, exp.Schema): this = f"{this.this} <{self.expressions(this)}>" @@ -73,7 +74,7 @@ def _returnsproperty_sql(self: generator.Generator, expression: exp.ReturnsPrope return f"RETURNS {this}" -def _create_sql(self: generator.Generator, expression: exp.Create) -> str: +def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: kind = expression.args["kind"] returns = expression.find(exp.ReturnsProperty) @@ -94,14 +95,20 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: These are added by the optimizer's qualify_column step. """ - from sqlglot.optimizer.scope import Scope + from sqlglot.optimizer.scope import find_all_in_scope if isinstance(expression, exp.Select): - for unnest in expression.find_all(exp.Unnest): - if isinstance(unnest.parent, (exp.From, exp.Join)) and unnest.alias: - for column in Scope(expression).find_all(exp.Column): - if column.table == unnest.alias: - column.set("table", None) + unnest_aliases = { + unnest.alias + for unnest in find_all_in_scope(expression, exp.Unnest) + if isinstance(unnest.parent, (exp.From, exp.Join)) + } + if unnest_aliases: + for column in expression.find_all(exp.Column): + if column.table in unnest_aliases: + column.set("table", None) + elif column.db in unnest_aliases: + column.set("db", None) return expression @@ -261,6 +268,7 @@ class BigQuery(Dialect): "TIMESTAMP": TokenType.TIMESTAMPTZ, "NOT DETERMINISTIC": TokenType.VOLATILE, "UNKNOWN": TokenType.NULL, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } KEYWORDS.pop("DIV") @@ -270,6 +278,8 @@ class BigQuery(Dialect): LOG_BASE_FIRST = False LOG_DEFAULTS_TO_LN = True + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE": _parse_date, @@ -299,6 +309,8 @@ class BigQuery(Dialect): if re.compile(str(seq_get(args, 1))).groups == 1 else None, ), + "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), + "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), "SPLIT": lambda args: exp.Split( # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split this=seq_get(args, 0), @@ -346,7 +358,7 @@ class BigQuery(Dialect): } def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: - this = super()._parse_table_part(schema=schema) + this = super()._parse_table_part(schema=schema) or self._parse_number() # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names if isinstance(this, exp.Identifier): @@ -356,6 +368,17 @@ class BigQuery(Dialect): table_name += f"-{self._prev.text}" this = exp.Identifier(this=table_name, quoted=this.args.get("quoted")) + elif isinstance(this, exp.Literal): + table_name = this.name + + if ( + self._curr + and self._prev.end == self._curr.start - 1 + and self._parse_var(any_token=True) + ): + table_name += self._prev.text + + this = exp.Identifier(this=table_name, quoted=True) return this @@ -374,6 +397,27 @@ class BigQuery(Dialect): return table + def _parse_json_object(self) -> exp.JSONObject: + json_object = super()._parse_json_object() + array_kv_pair = seq_get(json_object.expressions, 0) + + # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation + # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2 + if ( + array_kv_pair + and isinstance(array_kv_pair.this, exp.Array) + and isinstance(array_kv_pair.expression, exp.Array) + ): + keys = array_kv_pair.this.expressions + values = array_kv_pair.expression.expressions + + json_object.set( + "expressions", + [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)], + ) + + return json_object + class Generator(generator.Generator): EXPLICIT_UNION = True INTERVAL_ALLOWS_PLURAL_FORM = False @@ -383,6 +427,7 @@ class BigQuery(Dialect): LIMIT_FETCH = "LIMIT" RENAME_TABLE_WITH_DB = False ESCAPE_LINE_BREAK = True + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -405,6 +450,7 @@ class BigQuery(Dialect): exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.JSONFormat: rename_func("TO_JSON_STRING"), + exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), exp.MD5Digest: rename_func("MD5"), @@ -428,6 +474,9 @@ class BigQuery(Dialect): _alias_ordered_group, ] ), + exp.SHA2: lambda self, e: self.func( + f"SHA256" if e.text("length") == "256" else "SHA512", e.this + ), exp.StabilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC", @@ -591,6 +640,13 @@ class BigQuery(Dialect): return super().attimezone_sql(expression) + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#json_literals + if expression.is_type("json"): + return f"JSON {self.sql(expression, 'this')}" + + return super().cast_sql(expression, safe_prefix=safe_prefix) + def trycast_sql(self, expression: exp.TryCast) -> str: return self.cast_sql(expression, safe_prefix="SAFE_") @@ -630,3 +686,9 @@ class BigQuery(Dialect): def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("OPTIONS")) + + def version_sql(self, expression: exp.Version) -> str: + if expression.name == "TIMESTAMP": + expression = expression.copy() + expression.set("this", "SYSTEM_TIME") + return super().version_sql(expression) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index cfde5fd..a38a239 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( var_map_sql, ) from sqlglot.errors import ParseError +from sqlglot.helper import seq_get from sqlglot.parser import parse_var_map from sqlglot.tokens import Token, TokenType @@ -63,9 +64,23 @@ class ClickHouse(Dialect): } class Parser(parser.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ANY": exp.AnyValue.from_arg_list, + "DATE_ADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATEADD": lambda args: exp.DateAdd( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), + "DATEDIFF": lambda args: exp.DateDiff( + this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) + ), "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, @@ -147,7 +162,7 @@ class ClickHouse(Dialect): this = self._parse_id_var() self._match(TokenType.COLON) - kind = self._parse_types(check_func=False) or ( + kind = self._parse_types(check_func=False, allow_identifiers=False) or ( self._match_text_seq("IDENTIFIER") and "Identifier" ) @@ -249,7 +264,7 @@ class ClickHouse(Dialect): def _parse_func_params( self, this: t.Optional[exp.Func] = None - ) -> t.Optional[t.List[t.Optional[exp.Expression]]]: + ) -> t.Optional[t.List[exp.Expression]]: if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): return self._parse_csv(self._parse_lambda) @@ -267,9 +282,7 @@ class ClickHouse(Dialect): return self.expression(exp.Quantile, this=params[0], quantile=this) return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) - def _parse_wrapped_id_vars( - self, optional: bool = False - ) -> t.List[t.Optional[exp.Expression]]: + def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]: return super()._parse_wrapped_id_vars(optional=True) def _parse_primary_key( @@ -292,9 +305,22 @@ class ClickHouse(Dialect): class Generator(generator.Generator): QUERY_HINTS = False STRUCT_DELIMITER = ("(", ")") + NVL2_SUPPORTED = False + + STRING_TYPE_MAPPING = { + exp.DataType.Type.CHAR: "String", + exp.DataType.Type.LONGBLOB: "String", + exp.DataType.Type.LONGTEXT: "String", + exp.DataType.Type.MEDIUMBLOB: "String", + exp.DataType.Type.MEDIUMTEXT: "String", + exp.DataType.Type.TEXT: "String", + exp.DataType.Type.VARBINARY: "String", + exp.DataType.Type.VARCHAR: "String", + } TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + **STRING_TYPE_MAPPING, exp.DataType.Type.ARRAY: "Array", exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.DATETIME64: "DateTime64", @@ -328,6 +354,12 @@ class ClickHouse(Dialect): exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, exp.CastToStrType: rename_func("CAST"), + exp.DateAdd: lambda self, e: self.func( + "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), + exp.DateDiff: lambda self, e: self.func( + "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this + ), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -364,6 +396,16 @@ class ClickHouse(Dialect): "NAMED COLLECTION", } + def datatype_sql(self, expression: exp.DataType) -> str: + # String is the standard ClickHouse type, every other variant is just an alias. + # Additionally, any supplied length parameter will be ignored. + # + # https://clickhouse.com/docs/en/sql-reference/data-types/string + if expression.this in self.STRING_TYPE_MAPPING: + return "String" + + return super().datatype_sql(expression) + def safeconcat_sql(self, expression: exp.SafeConcat) -> str: # Clickhouse errors out if we try to cast a NULL value to TEXT expression = expression.copy() diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 2149aca..6ec0487 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, transforms -from sqlglot.dialects.dialect import parse_date_delta +from sqlglot.dialects.dialect import parse_date_delta, timestamptrunc_sql from sqlglot.dialects.spark import Spark from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql from sqlglot.tokens import TokenType @@ -28,6 +28,19 @@ class Databricks(Spark): **Spark.Generator.TRANSFORMS, exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, + exp.DatetimeAdd: lambda self, e: self.func( + "TIMESTAMPADD", e.text("unit"), e.expression, e.this + ), + exp.DatetimeSub: lambda self, e: self.func( + "TIMESTAMPADD", + e.text("unit"), + exp.Mul(this=e.expression.copy(), expression=exp.Literal.number(-1)), + e.this, + ), + exp.DatetimeDiff: lambda self, e: self.func( + "TIMESTAMPDIFF", e.text("unit"), e.expression, e.this + ), + exp.DatetimeTrunc: timestamptrunc_sql, exp.JSONExtract: lambda self, e: self.binary(e, ":"), exp.Select: transforms.preprocess( [ diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 132496f..1bfbfef 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -109,8 +109,7 @@ class _Dialect(type): for k, v in vars(klass).items() if not callable(v) and not isinstance(v, classmethod) and not k.startswith("__") }, - "STRING_ESCAPE": klass.tokenizer_class.STRING_ESCAPES[0], - "IDENTIFIER_ESCAPE": klass.tokenizer_class.IDENTIFIER_ESCAPES[0], + "TOKENIZER_CLASS": klass.tokenizer_class, } if enum not in ("", "bigquery"): @@ -345,7 +344,7 @@ def arrow_json_extract_scalar_sql( def inline_array_sql(self: Generator, expression: exp.Array) -> str: - return f"[{self.expressions(expression)}]" + return f"[{self.expressions(expression, flat=True)}]" def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: @@ -415,9 +414,9 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: - this = self.sql(expression, "this") - struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True)) - return f"{this}.{struct_key}" + return ( + f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" + ) def var_map_sql( @@ -722,3 +721,12 @@ def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: # Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) + + +def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: + return self.func("MAX", expression.this) + + +# Used to generate JSON_OBJECT with a comma in BigQuery and MySQL instead of colon +def json_keyvalue_comma_sql(self, expression: exp.JSONKeyValue) -> str: + return f"{self.sql(expression, 'this')}, {self.sql(expression, 'expression')}" diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py index 160c23c..4b8919c 100644 --- a/sqlglot/dialects/doris.py +++ b/sqlglot/dialects/doris.py @@ -37,7 +37,6 @@ class Doris(MySQL): **MySQL.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), - exp.Coalesce: rename_func("NVL"), exp.CurrentTimestamp: lambda *_: "NOW()", exp.DateTrunc: lambda self, e: self.func( "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 1b2681d..c811c86 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -16,8 +16,8 @@ from sqlglot.dialects.dialect import ( ) -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Drill.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: Drill.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = exp.var(expression.text("unit").upper() or "DAY") return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -25,7 +25,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e return func -def _str_to_date(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Drill.DATE_FORMAT: @@ -73,7 +73,6 @@ class Drill(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = ["'"] IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" @@ -81,6 +80,7 @@ class Drill(Dialect): class Parser(parser.Parser): STRICT_CAST = False CONCAT_NULL_OUTPUTS_STRING = True + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -95,6 +95,7 @@ class Drill(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 8253b52..684e35e 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -13,6 +13,7 @@ from sqlglot.dialects.dialect import ( datestrtodate_sql, encode_decode_sql, format_time_lambda, + inline_array_sql, no_comment_column_constraint_sql, no_properties_sql, no_safe_divide_sql, @@ -30,13 +31,13 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: +def _ts_or_ds_add_sql(self: DuckDB.Generator, expression: exp.TsOrDsAdd) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" return f"CAST({this} AS DATE) + {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))}" -def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_delta_sql(self: DuckDB.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = self.sql(expression, "unit").strip("'") or "DAY" op = "+" if isinstance(expression, exp.DateAdd) else "-" @@ -44,7 +45,7 @@ def _date_delta_sql(self: generator.Generator, expression: exp.DateAdd | exp.Dat # BigQuery -> DuckDB conversion for the DATE function -def _date_sql(self: generator.Generator, expression: exp.Date) -> str: +def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str: result = f"CAST({self.sql(expression, 'this')} AS DATE)" zone = self.sql(expression, "zone") @@ -58,13 +59,13 @@ def _date_sql(self: generator.Generator, expression: exp.Date) -> str: return result -def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: +def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") return f"ARRAY_SORT({self.sql(expression, 'this')})" -def _sort_array_sql(self: generator.Generator, expression: exp.SortArray) -> str: +def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str: this = self.sql(expression, "this") if expression.args.get("asc") == exp.false(): return f"ARRAY_REVERSE_SORT({this})" @@ -79,14 +80,14 @@ def _parse_date_diff(args: t.List) -> exp.Expression: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: +def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: args = [ f"'{e.name or e.this.name}': {self.sql(e, 'expression')}" for e in expression.expressions ] return f"{{{', '.join(args)}}}" -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" @@ -97,7 +98,7 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: +def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str: sql = self.func("TO_JSON", expression.this, expression.args.get("options")) return f"CAST({sql} AS TEXT)" @@ -134,6 +135,7 @@ class DuckDB(Dialect): class Parser(parser.Parser): CONCAT_NULL_OUTPUTS_STRING = True + SUPPORTS_USER_DEFINED_TYPES = False BITWISE = { **parser.Parser.BITWISE, @@ -183,18 +185,12 @@ class DuckDB(Dialect): ), } - TYPE_TOKENS = { - *parser.Parser.TYPE_TOKENS, - TokenType.UBIGINT, - TokenType.UINT, - TokenType.USMALLINT, - TokenType.UTINYINT, - } - def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3) # See: https://duckdb.org/docs/sql/data_types/numeric @@ -207,6 +203,9 @@ class DuckDB(Dialect): return this + def _parse_struct_types(self) -> t.Optional[exp.Expression]: + return self._parse_field_def() + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: if len(aggregations) == 1: return super()._pivot_column_names(aggregations) @@ -219,13 +218,14 @@ class DuckDB(Dialect): LIMIT_FETCH = "LIMIT" STRUCT_DELIMITER = ("(", ")") RENAME_TABLE_WITH_DB = False + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.ApproxDistinct: approx_count_distinct_sql, exp.Array: lambda self, e: self.func("ARRAY", e.expressions[0]) if e.expressions and e.expressions[0].find(exp.Select) - else rename_func("LIST_VALUE")(self, e), + else inline_array_sql(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 584acc6..8b17c06 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -50,7 +50,7 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -69,7 +69,7 @@ def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateS return self.func(func, expression.this, modified_increment) -def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = TIME_DIFF_FACTOR.get(unit) @@ -87,7 +87,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"{diff_sql}{multiplier_sql}" -def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: +def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: this = expression.this if isinstance(this, exp.Cast) and this.is_type("json") and this.this.is_string: # Since FROM_JSON requires a nested type, we always wrap the json string with @@ -103,21 +103,21 @@ def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> s return self.func("TO_JSON", this, expression.args.get("options")) -def _array_sort_sql(self: generator.Generator, expression: exp.ArraySort) -> str: +def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: if expression.expression: self.unsupported("Hive SORT_ARRAY does not support a comparator") return f"SORT_ARRAY({self.sql(expression, 'this')})" -def _property_sql(self: generator.Generator, expression: exp.Property) -> str: +def _property_sql(self: Hive.Generator, expression: exp.Property) -> str: return f"'{expression.name}'={self.sql(expression, 'value')}" -def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str: +def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str: return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression)) -def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: +def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -125,7 +125,7 @@ def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> st return f"CAST({this} AS DATE)" -def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> str: +def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -133,13 +133,13 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st return f"CAST({this} AS TIMESTAMP)" -def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: +def _time_to_str(self: Hive.Generator, expression: exp.TimeToStr) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) return f"DATE_FORMAT({this}, {time_format})" -def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: +def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): @@ -206,6 +206,8 @@ class Hive(Dialect): "MSCK REPAIR": TokenType.COMMAND, "REFRESH": TokenType.COMMAND, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, + "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT, + "VERSION AS OF": TokenType.VERSION_SNAPSHOT, } NUMERIC_LITERALS = { @@ -220,6 +222,7 @@ class Hive(Dialect): class Parser(parser.Parser): LOG_DEFAULTS_TO_LN = True STRICT_CAST = False + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -257,6 +260,11 @@ class Hive(Dialect): ), "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, + "STR_TO_MAP": lambda args: exp.StrToMap( + this=seq_get(args, 0), + pair_delim=seq_get(args, 1) or exp.Literal.string(","), + key_value_delim=seq_get(args, 2) or exp.Literal.string(":"), + ), "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), "TO_JSON": exp.JSONFormat.from_arg_list, "UNBASE64": exp.FromBase64.from_arg_list, @@ -313,7 +321,7 @@ class Hive(Dialect): ) def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: """ Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to @@ -333,7 +341,9 @@ class Hive(Dialect): Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html """ - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) if this and not schema: return this.transform( @@ -345,6 +355,16 @@ class Hive(Dialect): return this + def _parse_partition_and_order( + self, + ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: + return ( + self._parse_csv(self._parse_conjunction) + if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) + else [], + super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)), + ) + class Generator(generator.Generator): LIMIT_FETCH = "LIMIT" TABLESAMPLE_WITH_METHOD = False @@ -354,6 +374,7 @@ class Hive(Dialect): QUERY_HINTS = False INDEX_ON = "ON TABLE" EXTRACT_ALLOWS_QUOTES = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -376,6 +397,7 @@ class Hive(Dialect): ] ), exp.Property: _property_sql, + exp.AnyValue: rename_func("FIRST"), exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), @@ -402,6 +424,9 @@ class Hive(Dialect): exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), exp.Min: min_or_least, exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), + exp.NotNullColumnConstraint: lambda self, e: "" + if e.args.get("allow_null") + else "NOT NULL", exp.VarMap: var_map_sql, exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), @@ -472,7 +497,7 @@ class Hive(Dialect): elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) elif expression.is_type("float"): - size_expression = expression.find(exp.DataTypeSize) + size_expression = expression.find(exp.DataTypeParam) if size_expression: size = int(size_expression.name) expression = ( @@ -480,3 +505,7 @@ class Hive(Dialect): ) return super().datatype_sql(expression) + + def version_sql(self, expression: exp.Version) -> str: + sql = super().version_sql(expression) + return sql.replace("FOR ", "", 1) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 9ab4ce8..f9249eb 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -8,6 +8,7 @@ from sqlglot.dialects.dialect import ( arrow_json_extract_scalar_sql, datestrtodate_sql, format_time_lambda, + json_keyvalue_comma_sql, locate_to_strposition, max_or_greatest, min_or_least, @@ -32,7 +33,7 @@ def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], ex return _parse -def _date_trunc_sql(self: generator.Generator, expression: exp.DateTrunc) -> str: +def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: expr = self.sql(expression, "this") unit = expression.text("unit") @@ -63,12 +64,12 @@ def _str_to_date(args: t.List) -> exp.StrToDate: return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: +def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" -def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: +def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") remove_chars = self.sql(expression, "expression") @@ -83,8 +84,8 @@ def _trim_sql(self: generator.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -93,6 +94,9 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e class MySQL(Dialect): + # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html + IDENTIFIERS_CAN_START_WITH_DIGIT = True + TIME_FORMAT = "'%Y-%m-%d %T'" DPIPE_IS_STRING_CONCAT = False @@ -129,6 +133,7 @@ class MySQL(Dialect): "LONGTEXT": TokenType.LONGTEXT, "MEDIUMBLOB": TokenType.MEDIUMBLOB, "MEDIUMTEXT": TokenType.MEDIUMTEXT, + "MEDIUMINT": TokenType.MEDIUMINT, "MEMBER OF": TokenType.MEMBER_OF, "SEPARATOR": TokenType.SEPARATOR, "START": TokenType.BEGIN, @@ -136,6 +141,7 @@ class MySQL(Dialect): "SIGNED INTEGER": TokenType.BIGINT, "UNSIGNED": TokenType.UBIGINT, "UNSIGNED INTEGER": TokenType.UBIGINT, + "YEAR": TokenType.YEAR, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -185,6 +191,8 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} class Parser(parser.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + FUNC_TOKENS = { *parser.Parser.FUNC_TOKENS, TokenType.DATABASE, @@ -492,6 +500,17 @@ class MySQL(Dialect): return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES") + def _parse_type(self) -> t.Optional[exp.Expression]: + # mysql binary is special and can work anywhere, even in order by operations + # it operates like a no paren func + if self._match(TokenType.BINARY, advance=False): + data_type = self._parse_types(check_func=True, allow_identifiers=False) + + if isinstance(data_type, exp.DataType): + return self.expression(exp.Cast, this=self._parse_column(), to=data_type) + + return super()._parse_type() + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -500,6 +519,7 @@ class MySQL(Dialect): DUPLICATE_KEY_UPDATE_WITH_SET = False QUERY_HINT_SEP = " " VALUES_AS_TABLE = False + NVL2_SUPPORTED = False TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -515,6 +535,7 @@ class MySQL(Dialect): exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), @@ -524,6 +545,7 @@ class MySQL(Dialect): exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, exp.StrToTime: _str_to_date_sql, + exp.Stuff: rename_func("INSERT"), exp.TableSample: no_tablesample_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 1f63e9f..279ed31 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -8,7 +8,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: +def _parse_xml_table(self: Oracle.Parser) -> exp.XMLTable: this = self._parse_string() passing = None @@ -22,7 +22,7 @@ def _parse_xml_table(self: parser.Parser) -> exp.XMLTable: by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") if self._match_text_seq("COLUMNS"): - columns = self._parse_csv(lambda: self._parse_column_def(self._parse_field(any_token=True))) + columns = self._parse_csv(self._parse_field_def) return self.expression(exp.XMLTable, this=this, passing=passing, columns=columns, by_ref=by_ref) @@ -78,6 +78,10 @@ class Oracle(Dialect): ) } + # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT .. + # Reference: https://stackoverflow.com/a/336455 + DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE} + def _parse_column(self) -> t.Optional[exp.Expression]: column = super()._parse_column() if column: @@ -129,7 +133,6 @@ class Oracle(Dialect): ), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.ILike: no_ilike_sql, - exp.Coalesce: rename_func("NVL"), exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), @@ -162,7 +165,7 @@ class Oracle(Dialect): return f"XMLTABLE({self.sep('')}{self.indent(this + passing + by_ref + columns)}{self.seg(')', sep='')}" class Tokenizer(tokens.Tokenizer): - VAR_SINGLE_TOKENS = {"@"} + VAR_SINGLE_TOKENS = {"@", "$", "#"} KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 73ca4e5..c26e121 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, datestrtodate_sql, @@ -39,8 +40,8 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: expression = expression.copy() this = self.sql(expression, "this") @@ -56,7 +57,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e return func -def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: +def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) @@ -82,7 +83,7 @@ def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: return f"CAST({unit} AS BIGINT)" -def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: +def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str: this = self.sql(expression, "this") start = self.sql(expression, "start") length = self.sql(expression, "length") @@ -93,7 +94,7 @@ def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: return f"SUBSTRING({this}{from_part}{for_part})" -def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: +def _string_agg_sql(self: Postgres.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -107,7 +108,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s return f"STRING_AGG({self.format_args(this, separator)}{order})" -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: Postgres.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) @@ -254,6 +255,7 @@ class Postgres(Dialect): "~~*": TokenType.ILIKE, "~*": TokenType.IRLIKE, "~": TokenType.RLIKE, + "@@": TokenType.DAT, "@>": TokenType.AT_GT, "<@": TokenType.LT_AT, "BEGIN": TokenType.COMMAND, @@ -273,6 +275,18 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, "CSTRING": TokenType.PSEUDO_TYPE, + "OID": TokenType.OBJECT_IDENTIFIER, + "REGCLASS": TokenType.OBJECT_IDENTIFIER, + "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, + "REGCONFIG": TokenType.OBJECT_IDENTIFIER, + "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER, + "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER, + "REGOPER": TokenType.OBJECT_IDENTIFIER, + "REGOPERATOR": TokenType.OBJECT_IDENTIFIER, + "REGPROC": TokenType.OBJECT_IDENTIFIER, + "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER, + "REGROLE": TokenType.OBJECT_IDENTIFIER, + "REGTYPE": TokenType.OBJECT_IDENTIFIER, } SINGLE_TOKENS = { @@ -312,6 +326,9 @@ class Postgres(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), + TokenType.DAT: lambda self, this: self.expression( + exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this] + ), TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), } @@ -343,6 +360,7 @@ class Postgres(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { @@ -357,6 +375,8 @@ class Postgres(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, + exp.ArrayConcat: rename_func("ARRAY_CAT"), exp.BitwiseXor: lambda self, e: self.binary(e, "#"), exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.Explode: rename_func("UNNEST"), @@ -416,3 +436,9 @@ class Postgres(Dialect): expression.set("this", exp.paren(expression.this, copy=False)) return super().bracket_sql(expression) + + def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: + this = self.sql(expression, "this") + expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions] + sql = " OR ".join(expressions) + return f"({sql})" if len(expressions) > 1 else sql diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 078da0b..4b54e95 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -26,13 +26,13 @@ from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType -def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistinct) -> str: +def _approx_distinct_sql(self: Presto.Generator, expression: exp.ApproxDistinct) -> str: accuracy = expression.args.get("accuracy") accuracy = ", " + self.sql(accuracy) if accuracy else "" return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: +def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): expression = expression.copy() return self.sql( @@ -48,12 +48,12 @@ def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) - return self.lateral_sql(expression) -def _initcap_sql(self: generator.Generator, expression: exp.Initcap) -> str: +def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str: regex = r"(\w)(\w*)" return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" -def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: +def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str: if expression.args.get("asc") == exp.false(): comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" else: @@ -61,7 +61,7 @@ def _no_sort_array(self: generator.Generator, expression: exp.SortArray) -> str: return self.func("ARRAY_SORT", expression.this, comparator) -def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: +def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: if isinstance(expression.parent, exp.Property): columns = ", ".join(f"'{c.name}'" for c in expression.expressions) return f"ARRAY[{columns}]" @@ -75,25 +75,25 @@ def _schema_sql(self: generator.Generator, expression: exp.Schema) -> str: return self.schema_sql(expression) -def _quantile_sql(self: generator.Generator, expression: exp.Quantile) -> str: +def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str: self.unsupported("Presto does not support exact quantiles") return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" def _str_to_time_sql( - self: generator.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate + self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate ) -> str: return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" -def _ts_or_ds_to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str: +def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: time_format = self.format_time(expression) if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): return exp.cast(_str_to_time_sql(self, expression), "DATE").sql(dialect="presto") return exp.cast(exp.cast(expression.this, "TIMESTAMP", copy=True), "DATE").sql(dialect="presto") -def _ts_or_ds_add_sql(self: generator.Generator, expression: exp.TsOrDsAdd) -> str: +def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: this = expression.this if not isinstance(this, exp.CurrentDate): @@ -153,6 +153,20 @@ def _unnest_sequence(expression: exp.Expression) -> exp.Expression: return expression +def _first_last_sql(self: Presto.Generator, expression: exp.First | exp.Last) -> str: + """ + Trino doesn't support FIRST / LAST as functions, but they're valid in the context + of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases + they're converted into an ARBITRARY call. + + Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions + """ + if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize): + return self.function_fallback_sql(expression) + + return rename_func("ARBITRARY")(self, expression) + + class Presto(Dialect): INDEX_OFFSET = 1 NULL_ORDERING = "nulls_are_last" @@ -178,6 +192,7 @@ class Presto(Dialect): class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "ARBITRARY": exp.AnyValue.from_arg_list, "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "APPROX_PERCENTILE": _approx_percentile, "BITWISE_AND": binary_from_function(exp.BitwiseAnd), @@ -205,7 +220,14 @@ class Presto(Dialect): "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) ), + "REGEXP_REPLACE": lambda args: exp.RegexpReplace( + this=seq_get(args, 0), + expression=seq_get(args, 1), + replacement=seq_get(args, 2) or exp.Literal.string(""), + ), + "ROW": exp.Struct.from_arg_list, "SEQUENCE": exp.GenerateSeries.from_arg_list, + "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) ), @@ -225,6 +247,7 @@ class Presto(Dialect): QUERY_HINTS = False IS_BOOL_ALLOWED = False TZ_TO_WITH_TIME_ZONE = True + NVL2_SUPPORTED = False STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -242,10 +265,13 @@ class Presto(Dialect): exp.DataType.Type.TIMETZ: "TIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.STRUCT: "ROW", + exp.DataType.Type.DATETIME: "TIMESTAMP", + exp.DataType.Type.DATETIME64: "TIMESTAMP", } TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: rename_func("ARBITRARY"), exp.ApproxDistinct: _approx_distinct_sql, exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", @@ -268,15 +294,23 @@ class Presto(Dialect): ), exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.DATE_FORMAT}) AS DATE)", exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", + exp.DateSub: lambda self, e: self.func( + "DATE_ADD", + exp.Literal.string(e.text("unit") or "day"), + e.expression * -1, + e.this, + ), exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", + exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql, exp.ILike: no_ilike_sql, exp.Initcap: _initcap_sql, + exp.Last: _first_last_sql, exp.Lateral: _explode_to_unnest_sql, exp.Left: left_to_substring_sql, exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), @@ -301,8 +335,10 @@ class Presto(Dialect): exp.SortArray: _no_sort_array, exp.StrPosition: rename_func("STRPOS"), exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToMap: rename_func("SPLIT_TO_MAP"), exp.StrToTime: _str_to_time_sql, exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: rename_func("ROW"), exp.StructExtract: struct_extract_sql, exp.Table: transforms.preprocess([_unnest_sequence]), exp.TimestampTrunc: timestamptrunc_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 30731e1..351c5df 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -13,7 +13,7 @@ from sqlglot.helper import seq_get from sqlglot.tokens import TokenType -def _json_sql(self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: +def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str: return f'{self.sql(expression, "this")}."{expression.expression.name}"' @@ -37,6 +37,8 @@ class Redshift(Postgres): } class Parser(Postgres.Parser): + SUPPORTS_USER_DEFINED_TYPES = False + FUNCTIONS = { **Postgres.Parser.FUNCTIONS, "ADD_MONTHS": lambda args: exp.DateAdd( @@ -55,9 +57,11 @@ class Redshift(Postgres): } def _parse_types( - self, check_func: bool = False, schema: bool = False + self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: - this = super()._parse_types(check_func=check_func, schema=schema) + this = super()._parse_types( + check_func=check_func, schema=schema, allow_identifiers=allow_identifiers + ) if ( isinstance(this, exp.DataType) @@ -100,6 +104,7 @@ class Redshift(Postgres): QUERY_HINTS = False VALUES_AS_TABLE = False TZ_TO_WITH_TIME_ZONE = True + NVL2_SUPPORTED = True TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, @@ -142,6 +147,9 @@ class Redshift(Postgres): # Redshift uses the POW | POWER (expr1, expr2) syntax instead of expr1 ^ expr2 (postgres) TRANSFORMS.pop(exp.Pow) + # Redshift supports ANY_VALUE(..) + TRANSFORMS.pop(exp.AnyValue) + RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} def with_properties(self, properties: exp.Properties) -> str: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 9733a85..8d8183c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -90,7 +90,7 @@ def _parse_datediff(args: t.List) -> exp.DateDiff: return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) -def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str: +def _unix_to_time_sql(self: Snowflake.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale in [None, exp.UnixToTime.SECONDS]: @@ -105,7 +105,7 @@ def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> # https://docs.snowflake.com/en/sql-reference/functions/date_part.html # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts -def _parse_date_part(self: parser.Parser) -> t.Optional[exp.Expression]: +def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: this = self._parse_var() or self._parse_type() if not this: @@ -156,7 +156,7 @@ def _nullifzero_to_if(args: t.List) -> exp.If: return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: +def _datatype_sql(self: Snowflake.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return "ARRAY" elif expression.is_type("map"): @@ -164,6 +164,17 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) +def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: + flag = expression.text("flag") + + if "i" not in flag: + flag += "i" + + return self.func( + "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag) + ) + + def _parse_convert_timezone(args: t.List) -> t.Union[exp.Anonymous, exp.AtTimeZone]: if len(args) == 3: return exp.Anonymous(this="CONVERT_TIMEZONE", expressions=args) @@ -179,6 +190,13 @@ def _parse_regexp_replace(args: t.List) -> exp.RegexpReplace: return regexp_replace +def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]: + def _parse(self: Snowflake.Parser) -> exp.Show: + return self._parse_show_snowflake(*args, **kwargs) + + return _parse + + class Snowflake(Dialect): # https://docs.snowflake.com/en/sql-reference/identifiers-syntax RESOLVES_IDENTIFIERS_AS_UPPERCASE = True @@ -216,6 +234,7 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + SUPPORTS_USER_DEFINED_TYPES = False FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -230,6 +249,7 @@ class Snowflake(Dialect): "DATEDIFF": _parse_datediff, "DIV0": _div0_to_if, "IFF": exp.If.from_arg_list, + "LISTAGG": exp.GroupConcat.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, "REGEXP_REPLACE": _parse_regexp_replace, @@ -250,11 +270,6 @@ class Snowflake(Dialect): } FUNCTION_PARSERS.pop("TRIM") - FUNC_TOKENS = { - *parser.Parser.FUNC_TOKENS, - TokenType.TABLE, - } - COLUMN_OPERATORS = { **parser.Parser.COLUMN_OPERATORS, TokenType.COLON: lambda self, this, path: self.expression( @@ -281,6 +296,16 @@ class Snowflake(Dialect): ), } + STATEMENT_PARSERS = { + **parser.Parser.STATEMENT_PARSERS, + TokenType.SHOW: lambda self: self._parse_show(), + } + + SHOW_PARSERS = { + "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), + } + def _parse_id_var( self, any_token: bool = True, @@ -296,8 +321,24 @@ class Snowflake(Dialect): return super()._parse_id_var(any_token=any_token, tokens=tokens) + def _parse_show_snowflake(self, this: str) -> exp.Show: + scope = None + scope_kind = None + + if self._match(TokenType.IN): + if self._match_text_seq("ACCOUNT"): + scope_kind = "ACCOUNT" + elif self._match_set(self.DB_CREATABLES): + scope_kind = self._prev.text + if self._curr: + scope = self._parse_table() + elif self._curr: + scope_kind = "TABLE" + scope = self._parse_table() + + return self.expression(exp.Show, this=this, scope=scope, scope_kind=scope_kind) + class Tokenizer(tokens.Tokenizer): - QUOTES = ["'"] STRING_ESCAPES = ["\\", "'"] HEX_STRINGS = [("x'", "'"), ("X'", "'")] RAW_STRINGS = ["$$"] @@ -331,6 +372,8 @@ class Snowflake(Dialect): VAR_SINGLE_TOKENS = {"$"} + COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} + class Generator(generator.Generator): PARAMETER_TOKEN = "$" MATCHED_BY_SOURCE = False @@ -355,6 +398,7 @@ class Snowflake(Dialect): exp.DataType: _datatype_sql, exp.DayOfWeek: rename_func("DAYOFWEEK"), exp.Extract: rename_func("DATE_PART"), + exp.GroupConcat: rename_func("LISTAGG"), exp.If: rename_func("IFF"), exp.LogicalAnd: rename_func("BOOLAND_AGG"), exp.LogicalOr: rename_func("BOOLOR_AGG"), @@ -362,6 +406,7 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.RegexpILike: _regexpilike_sql, exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StarMap: rename_func("OBJECT_CONSTRUCT"), exp.StartsWith: rename_func("STARTSWITH"), @@ -373,6 +418,7 @@ class Snowflake(Dialect): "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), ), + exp.Stuff: rename_func("INSERT"), exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: lambda self, e: self.func( @@ -403,6 +449,16 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def show_sql(self, expression: exp.Show) -> str: + scope = self.sql(expression, "scope") + scope = f" {scope}" if scope else "" + + scope_kind = self.sql(expression, "scope_kind") + if scope_kind: + scope_kind = f" IN {scope_kind}" + + return f"SHOW {expression.name}{scope_kind}{scope}" + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: # Other dialects don't support all of the following parameters, so we need to # generate default values as necessary to ensure the transpilation is correct @@ -436,7 +492,9 @@ class Snowflake(Dialect): kind_value = expression.args.get("kind") or "TABLE" kind = f" {kind_value}" if kind_value else "" this = f" {self.sql(expression, 'this')}" - return f"DESCRIBE{kind}{this}" + expressions = self.expressions(expression, flat=True) + expressions = f" {expressions}" if expressions else "" + return f"DESCRIBE{kind}{this}{expressions}" def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 7c8982b..a4435f6 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -38,9 +38,15 @@ class Spark(Spark2): class Parser(Spark2.Parser): FUNCTIONS = { **Spark2.Parser.FUNCTIONS, + "ANY_VALUE": lambda args: exp.AnyValue( + this=seq_get(args, 0), ignore_nulls=seq_get(args, 1) + ), "DATEDIFF": _parse_datediff, } + FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy() + FUNCTION_PARSERS.pop("ANY_VALUE") + class Generator(Spark2.Generator): TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, @@ -56,9 +62,13 @@ class Spark(Spark2): "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this ), } + TRANSFORMS.pop(exp.AnyValue) TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) + def anyvalue_sql(self, expression: exp.AnyValue) -> str: + return self.function_fallback_sql(expression) + def datediff_sql(self, expression: exp.DateDiff) -> str: unit = self.sql(expression, "unit") end = self.sql(expression, "this") diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index ceb48f8..4489b6b 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -15,7 +15,7 @@ from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self: Hive.Generator, e: exp.Create) -> str: +def _create_sql(self: Spark2.Generator, e: exp.Create) -> str: kind = e.args["kind"] properties = e.args.get("properties") @@ -31,17 +31,21 @@ def _create_sql(self: Hive.Generator, e: exp.Create) -> str: return create_with_partitions_sql(self, e) -def _map_sql(self: Hive.Generator, expression: exp.Map) -> str: - keys = self.sql(expression.args["keys"]) - values = self.sql(expression.args["values"]) - return f"MAP_FROM_ARRAYS({keys}, {values})" +def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: + keys = expression.args.get("keys") + values = expression.args.get("values") + + if not keys or not values: + return "MAP()" + + return f"MAP_FROM_ARRAYS({self.sql(keys)}, {self.sql(values)})" def _parse_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) -def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: +def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) if time_format == Hive.DATE_FORMAT: @@ -49,7 +53,7 @@ def _str_to_date(self: Hive.Generator, expression: exp.StrToDate) -> str: return f"TO_DATE({this}, {time_format})" -def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: +def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str: scale = expression.args.get("scale") timestamp = self.sql(expression, "this") if scale is None: @@ -110,6 +114,13 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: return expression +def _insert_sql(self: Spark2.Generator, expression: exp.Insert) -> str: + if expression.expression.args.get("with"): + expression = expression.copy() + expression.set("with", expression.expression.args.pop("with")) + return self.insert_sql(expression) + + class Spark2(Hive): class Parser(Hive.Parser): FUNCTIONS = { @@ -169,10 +180,7 @@ class Spark2(Hive): class Generator(Hive.Generator): QUERY_HINTS = True - - TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, - } + NVL2_SUPPORTED = True PROPERTIES_LOCATION = { **Hive.Generator.PROPERTIES_LOCATION, @@ -197,6 +205,7 @@ class Spark2(Hive): exp.DayOfYear: rename_func("DAYOFYEAR"), exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", exp.From: transforms.preprocess([_unalias_pivot]), + exp.Insert: _insert_sql, exp.LogicalAnd: rename_func("BOOL_AND"), exp.LogicalOr: rename_func("BOOL_OR"), exp.Map: _map_sql, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 90b774e..7bfdf1c 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -5,6 +5,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, concat_to_dpipe_sql, @@ -18,7 +19,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.tokens import TokenType -def _date_add_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str: modifier = expression.expression modifier = modifier.name if modifier.is_string else self.sql(modifier) unit = expression.args.get("unit") @@ -78,6 +79,7 @@ class SQLite(Dialect): JOIN_HINTS = False TABLE_HINTS = False QUERY_HINTS = False + NVL2_SUPPORTED = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -103,6 +105,7 @@ class SQLite(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, exp.Concat: concat_to_dpipe_sql, exp.CountIf: count_if_to_sum, exp.Create: transforms.preprocess([_transform_create]), diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py index 2be1a62..163cc13 100644 --- a/sqlglot/dialects/teradata.py +++ b/sqlglot/dialects/teradata.py @@ -95,6 +95,9 @@ class Teradata(Dialect): STATEMENT_PARSERS = { **parser.Parser.STATEMENT_PARSERS, + TokenType.DATABASE: lambda self: self.expression( + exp.Use, this=self._parse_table(schema=False) + ), TokenType.REPLACE: lambda self: self._parse_create(), } @@ -165,6 +168,7 @@ class Teradata(Dialect): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.StrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), + exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", } def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py index af0f78d..0c953a1 100644 --- a/sqlglot/dialects/trino.py +++ b/sqlglot/dialects/trino.py @@ -13,3 +13,6 @@ class Trino(Presto): class Tokenizer(Presto.Tokenizer): HEX_STRINGS = [("X'", "'")] + + class Parser(Presto.Parser): + SUPPORTS_USER_DEFINED_TYPES = False diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 131307f..b26f499 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -7,6 +7,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, + any_value_to_max_sql, max_or_greatest, min_or_least, parse_date_delta, @@ -79,22 +80,23 @@ def _format_time_lambda( def _parse_format(args: t.List) -> exp.Expression: - assert len(args) == 2 + this = seq_get(args, 0) + fmt = seq_get(args, 1) + culture = seq_get(args, 2) - fmt = args[1] - number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name) + number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)) if number_fmt: - return exp.NumberToStr(this=args[0], format=fmt) + return exp.NumberToStr(this=this, format=fmt, culture=culture) - return exp.TimeToStr( - this=args[0], - format=exp.Literal.string( + if fmt: + fmt = exp.Literal.string( format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING) if len(fmt.name) == 1 else format_time(fmt.name, TSQL.TIME_MAPPING) - ), - ) + ) + + return exp.TimeToStr(this=this, format=fmt, culture=culture) def _parse_eomonth(args: t.List) -> exp.Expression: @@ -130,13 +132,13 @@ def _parse_hashbytes(args: t.List) -> exp.Expression: def generate_date_delta_with_unit_sql( - self: generator.Generator, expression: exp.DateAdd | exp.DateDiff + self: TSQL.Generator, expression: exp.DateAdd | exp.DateDiff ) -> str: func = "DATEADD" if isinstance(expression, exp.DateAdd) else "DATEDIFF" return self.func(func, expression.text("unit"), expression.expression, expression.this) -def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: +def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: fmt = ( expression.args["format"] if isinstance(expression, exp.NumberToStr) @@ -147,10 +149,10 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim ) ) ) - return self.func("FORMAT", expression.this, fmt) + return self.func("FORMAT", expression.this, fmt, expression.args.get("culture")) -def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: +def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() this = expression.this @@ -332,10 +334,12 @@ class TSQL(Dialect): "SQL_VARIANT": TokenType.VARIANT, "TOP": TokenType.TOP, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, + "UPDATE STATISTICS": TokenType.COMMAND, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, + "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, } class Parser(parser.Parser): @@ -395,7 +399,9 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True - def _parse_projections(self) -> t.List[t.Optional[exp.Expression]]: + ALTER_TABLE_ADD_COLUMN_KEYWORD = False + + def _parse_projections(self) -> t.List[exp.Expression]: """ T-SQL supports the syntax alias = expression in the SELECT's projection list, so we transform all parsed Selects to convert their EQ projections into Aliases. @@ -458,43 +464,6 @@ class TSQL(Dialect): return self._parse_as_command(self._prev) - def _parse_system_time(self) -> t.Optional[exp.Expression]: - if not self._match_text_seq("FOR", "SYSTEM_TIME"): - return None - - if self._match_text_seq("AS", "OF"): - system_time = self.expression( - exp.SystemTime, this=self._parse_bitwise(), kind="AS OF" - ) - elif self._match_set((TokenType.FROM, TokenType.BETWEEN)): - kind = self._prev.text - this = self._parse_bitwise() - self._match_texts(("TO", "AND")) - expression = self._parse_bitwise() - system_time = self.expression( - exp.SystemTime, this=this, expression=expression, kind=kind - ) - elif self._match_text_seq("CONTAINED", "IN"): - args = self._parse_wrapped_csv(self._parse_bitwise) - system_time = self.expression( - exp.SystemTime, - this=seq_get(args, 0), - expression=seq_get(args, 1), - kind="CONTAINED IN", - ) - elif self._match(TokenType.ALL): - system_time = self.expression(exp.SystemTime, kind="ALL") - else: - system_time = None - self.raise_error("Unable to parse FOR SYSTEM_TIME clause") - - return system_time - - def _parse_table_parts(self, schema: bool = False) -> exp.Table: - table = super()._parse_table_parts(schema=schema) - table.set("system_time", self._parse_system_time()) - return table - def _parse_returns(self) -> exp.ReturnsProperty: table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) returns = super()._parse_returns() @@ -589,14 +558,36 @@ class TSQL(Dialect): return create + def _parse_if(self) -> t.Optional[exp.Expression]: + index = self._index + + if self._match_text_seq("OBJECT_ID"): + self._parse_wrapped_csv(self._parse_string) + if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP): + return self._parse_drop(exists=True) + self._retreat(index) + + return super()._parse_if() + + def _parse_unique(self) -> exp.UniqueColumnConstraint: + return self.expression( + exp.UniqueColumnConstraint, + this=None + if self._curr and self._curr.text.upper() in {"CLUSTERED", "NONCLUSTERED"} + else self._parse_schema(self._parse_id_var(any_token=False)), + ) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False RETURNING_END = False + NVL2_SUPPORTED = False + ALTER_TABLE_ADD_COLUMN_KEYWORD = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.DECIMAL: "NUMERIC", exp.DataType.Type.DATETIME: "DATETIME2", exp.DataType.Type.INT: "INTEGER", @@ -607,6 +598,8 @@ class TSQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.AnyValue: any_value_to_max_sql, + exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", exp.DateAdd: generate_date_delta_with_unit_sql, exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), @@ -651,25 +644,44 @@ class TSQL(Dialect): return sql - def offset_sql(self, expression: exp.Offset) -> str: - return f"{super().offset_sql(expression)} ROWS" + def create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + kind = self.sql(expression, "kind").upper() + exists = expression.args.pop("exists", None) + sql = super().create_sql(expression) + + if exists: + table = expression.find(exp.Table) + identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) + if kind == "SCHEMA": + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC('{sql}')""" + elif kind == "TABLE": + sql = f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = {identifier}) EXEC('{sql}')""" + elif kind == "INDEX": + index = self.sql(exp.Literal.string(expression.this.text("this"))) + sql = f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC('{sql}')""" + elif expression.args.get("replace"): + sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - def systemtime_sql(self, expression: exp.SystemTime) -> str: - kind = expression.args["kind"] - if kind == "ALL": - return "FOR SYSTEM_TIME ALL" + return sql - start = self.sql(expression, "this") - if kind == "AS OF": - return f"FOR SYSTEM_TIME AS OF {start}" + def offset_sql(self, expression: exp.Offset) -> str: + return f"{super().offset_sql(expression)} ROWS" - end = self.sql(expression, "expression") - if kind == "FROM": - return f"FOR SYSTEM_TIME FROM {start} TO {end}" - if kind == "BETWEEN": - return f"FOR SYSTEM_TIME BETWEEN {start} AND {end}" + def version_sql(self, expression: exp.Version) -> str: + name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name + this = f"FOR {name}" + expr = expression.expression + kind = expression.text("kind") + if kind in ("FROM", "BETWEEN"): + args = expr.expressions + sep = "TO" if kind == "FROM" else "AND" + expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}" + else: + expr_sql = self.sql(expr) - return f"FOR SYSTEM_TIME CONTAINED IN ({start}, {end})" + expr_sql = f" {expr_sql}" if expr_sql else "" + return f"{this} {kind}{expr_sql}" def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: table = expression.args.get("table") @@ -713,3 +725,16 @@ class TSQL(Dialect): identifier = f"#{identifier}" return identifier + + def constraint_sql(self, expression: exp.Constraint) -> str: + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True, sep=" ") + return f"CONSTRAINT {this} {expressions}" + + # https://learn.microsoft.com/en-us/answers/questions/448821/create-table-in-sql-server + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + start = self.sql(expression, "start") or "1" + increment = self.sql(expression, "increment") or "1" + return f"IDENTITY({start}, {increment})" |