diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/bigquery.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/databricks.py | 20 | ||||
-rw-r--r-- | sqlglot/dialects/dialect.py | 9 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 33 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 107 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 8 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 14 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/redshift.py | 2 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 47 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/spark2.py | 26 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 9 |
14 files changed, 290 insertions, 59 deletions
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1349c56..0d741b5 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -190,6 +190,16 @@ class BigQuery(Dialect): "%D": "%m/%d/%y", } + ESCAPE_SEQUENCES = { + "\\a": "\a", + "\\b": "\b", + "\\f": "\f", + "\\n": "\n", + "\\r": "\r", + "\\t": "\t", + "\\v": "\v", + } + FORMAT_MAPPING = { "DD": "%d", "MM": "%m", @@ -212,15 +222,14 @@ class BigQuery(Dialect): @classmethod def normalize_identifier(cls, expression: E) -> E: - # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). - # The following check is essentially a heuristic to detect tables based on whether or - # not they're qualified. if isinstance(expression, exp.Identifier): parent = expression.parent - while isinstance(parent, exp.Dot): parent = parent.parent + # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). + # The following check is essentially a heuristic to detect tables based on whether or + # not they're qualified. It also avoids normalizing UDFs, because they're case-sensitive. if ( not isinstance(parent, exp.UserDefinedFunction) and not (isinstance(parent, exp.Table) and parent.db) @@ -419,6 +428,7 @@ class BigQuery(Dialect): RENAME_TABLE_WITH_DB = False NVL2_SUPPORTED = False UNNEST_WITH_ORDINALITY = False + COLLATE_IS_FUNC = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -520,18 +530,6 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - UNESCAPED_SEQUENCE_TABLE = str.maketrans( # type: ignore - { - "\a": "\\a", - "\b": "\\b", - "\f": "\\f", - "\n": "\\n", - "\r": "\\r", - "\t": "\\t", - "\v": "\\v", - } - ) - # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords RESERVED_KEYWORDS = { *generator.Generator.RESERVED_KEYWORDS, diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7446081..e9d9326 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -2,7 +2,7 @@ from __future__ import annotations import typing as t -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, inline_array_sql, @@ -21,18 +21,33 @@ def _lower_func(sql: str) -> str: return sql[:index].lower() + sql[index:] +def _quantile_sql(self, e): + quantile = e.args["quantile"] + args = f"({self.sql(e, 'this')})" + if isinstance(quantile, exp.Array): + func = self.func("quantiles", *quantile) + else: + func = self.func("quantile", quantile) + return func + args + + class ClickHouse(Dialect): NORMALIZE_FUNCTIONS: bool | str = False NULL_ORDERING = "nulls_are_last" STRICT_STRING_CONCAT = True SUPPORTS_USER_DEFINED_TYPES = False + ESCAPE_SEQUENCES = { + "\\0": "\0", + } + class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] STRING_ESCAPES = ["'", "\\"] BIT_STRINGS = [("0b", "")] HEX_STRINGS = [("0x", ""), ("0X", "")] + HEREDOC_STRINGS = ["$"] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -55,6 +70,7 @@ class ClickHouse(Dialect): "LOWCARDINALITY": TokenType.LOWCARDINALITY, "MAP": TokenType.MAP, "NESTED": TokenType.NESTED, + "SAMPLE": TokenType.TABLE_SAMPLE, "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, "UINT16": TokenType.USMALLINT, @@ -64,6 +80,11 @@ class ClickHouse(Dialect): "UINT8": TokenType.UTINYINT, } + SINGLE_TOKENS = { + **tokens.Tokenizer.SINGLE_TOKENS, + "$": TokenType.HEREDOC_STRING, + } + class Parser(parser.Parser): FUNCTIONS = { **parser.Parser.FUNCTIONS, @@ -301,6 +322,7 @@ class ClickHouse(Dialect): QUERY_HINTS = False STRUCT_DELIMITER = ("(", ")") NVL2_SUPPORTED = False + TABLESAMPLE_REQUIRES_PARENS = False STRING_TYPE_MAPPING = { exp.DataType.Type.CHAR: "String", @@ -348,6 +370,7 @@ class ClickHouse(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.Select: transforms.preprocess([transforms.eliminate_qualify]), exp.AnyValue: rename_func("any"), exp.ApproxDistinct: rename_func("uniq"), exp.Array: inline_array_sql, @@ -359,12 +382,13 @@ class ClickHouse(Dialect): "DATE_DIFF", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + exp.IsNan: rename_func("isNaN"), exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.Pivot: no_pivot_sql, - exp.Quantile: lambda self, e: self.func("quantile", e.args.get("quantile")) - + f"({self.sql(e, 'this')})", + exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", + exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 39daad7..a044bc0 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -51,6 +51,26 @@ class Databricks(Spark): exp.ToChar: lambda self, e: self.function_fallback_sql(e), } + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: + constraint = expression.find(exp.GeneratedAsIdentityColumnConstraint) + kind = expression.args.get("kind") + if ( + constraint + and isinstance(kind, exp.DataType) + and kind.this in exp.DataType.INTEGER_TYPES + ): + # only BIGINT generated identity constraints are supported + expression = expression.copy() + expression.set("kind", exp.DataType.build("bigint")) + return super().columndef_sql(expression, sep) + + def generatedasidentitycolumnconstraint_sql( + self, expression: exp.GeneratedAsIdentityColumnConstraint + ) -> str: + expression = expression.copy() + expression.set("this", True) # trigger ALWAYS in super class + return super().generatedasidentitycolumnconstraint_sql(expression) + class Tokenizer(Spark.Tokenizer): HEX_STRINGS = [] diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index ccf04da..bd839af 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -81,6 +81,8 @@ class _Dialect(type): klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) + klass.INVERSE_ESCAPE_SEQUENCES = {v: k for k, v in klass.ESCAPE_SEQUENCES.items()} + klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) klass.parser_class = getattr(klass, "Parser", Parser) klass.generator_class = getattr(klass, "Generator", Generator) @@ -188,6 +190,9 @@ class Dialect(metaclass=_Dialect): # special syntax cast(x as date format 'yyyy') defaults to time_mapping FORMAT_MAPPING: t.Dict[str, str] = {} + # Mapping of an unescaped escape sequence to the corresponding character + ESCAPE_SEQUENCES: t.Dict[str, str] = {} + # Columns that are auto-generated by the engine corresponding to this dialect # Such columns may be excluded from SELECT * queries, for example PSEUDOCOLUMNS: t.Set[str] = set() @@ -204,6 +209,8 @@ class Dialect(metaclass=_Dialect): INVERSE_TIME_MAPPING: t.Dict[str, str] = {} INVERSE_TIME_TRIE: t.Dict = {} + INVERSE_ESCAPE_SEQUENCES: t.Dict[str, str] = {} + def __eq__(self, other: t.Any) -> bool: return type(self) == other @@ -245,7 +252,7 @@ class Dialect(metaclass=_Dialect): """ Normalizes an unquoted identifier to either lower or upper case, thus essentially making it case-insensitive. If a dialect treats all identifiers as case-insensitive, - they will be normalized regardless of being quoted or not. + they will be normalized to lowercase regardless of being quoted or not. """ if isinstance(expression, exp.Identifier) and ( not expression.quoted or cls.RESOLVES_IDENTIFIERS_AS_UPPERCASE is None diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index a427870..3f925a7 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -51,6 +51,32 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") +def _create_sql(self, expression: exp.Create) -> str: + expression = expression.copy() + + # remove UNIQUE column constraints + for constraint in expression.find_all(exp.UniqueColumnConstraint): + if constraint.parent: + constraint.parent.pop() + + properties = expression.args.get("properties") + temporary = any( + isinstance(prop, exp.TemporaryProperty) + for prop in (properties.expressions if properties else []) + ) + + # CTAS with temp tables map to CREATE TEMPORARY VIEW + kind = expression.args["kind"] + if kind.upper() == "TABLE" and temporary: + if expression.expression: + return f"CREATE TEMPORARY VIEW {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" + else: + # CREATE TEMPORARY TABLE may require storage provider + expression = self.temporary_storage_provider(expression) + + return create_with_partitions_sql(self, expression) + + def _add_date_sql(self: Hive.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) @@ -429,7 +455,7 @@ class Hive(Dialect): if e.args.get("allow_null") else "NOT NULL", exp.VarMap: var_map_sql, - exp.Create: create_with_partitions_sql, + exp.Create: _create_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), exp.RegexpExtract: regexp_extract_sql, @@ -478,8 +504,13 @@ class Hive(Dialect): exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, + exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED, } + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: + # Hive has no temporary storage provider (there are hive settings though) + return expression + def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") parent = expression.parent diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 554241d..59a0a2a 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -66,7 +66,9 @@ def _str_to_date(args: t.List) -> exp.StrToDate: return exp.StrToDate(this=seq_get(args, 0), format=date_format) -def _str_to_date_sql(self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime) -> str: +def _str_to_date_sql( + self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate +) -> str: date_format = self.format_time(expression) return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" @@ -86,8 +88,10 @@ def _trim_sql(self: MySQL.Generator, expression: exp.Trim) -> str: return f"TRIM({trim_type}{remove_chars}{from_part}{target})" -def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: MySQL.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql( + kind: str, +) -> t.Callable[[MySQL.Generator, exp.Expression], str]: + def func(self: MySQL.Generator, expression: exp.Expression) -> str: this = self.sql(expression, "this") unit = expression.text("unit").upper() or "DAY" return f"DATE_{kind}({this}, {self.sql(exp.Interval(this=expression.expression.copy(), unit=unit))})" @@ -95,6 +99,30 @@ def _date_add_sql(kind: str) -> t.Callable[[MySQL.Generator, exp.DateAdd | exp.D return func +def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: + time_format = expression.args.get("format") + if time_format: + return _str_to_date_sql(self, expression) + return f"DATE({self.sql(expression, 'this')})" + + +def _remove_ts_or_ds_to_date( + to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None, + args: t.Tuple[str, ...] = ("this",), +) -> t.Callable[[MySQL.Generator, exp.Func], str]: + def func(self: MySQL.Generator, expression: exp.Func) -> str: + expression = expression.copy() + + for arg_key in args: + arg = expression.args.get(arg_key) + if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): + expression.set(arg_key, arg.this) + + return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression) + + return func + + class MySQL(Dialect): # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html IDENTIFIERS_CAN_START_WITH_DIGIT = True @@ -233,6 +261,7 @@ class MySQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, + "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), "DATE_ADD": parse_date_delta_with_interval(exp.DateAdd), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "mysql"), "DATE_SUB": parse_date_delta_with_interval(exp.DateSub), @@ -240,14 +269,33 @@ class MySQL(Dialect): "ISNULL": isnull_to_is_null, "LOCATE": locate_to_strposition, "MONTHNAME": lambda args: exp.TimeToStr( - this=seq_get(args, 0), + this=exp.TsOrDsToDate(this=seq_get(args, 0)), format=exp.Literal.string("%B"), ), "STR_TO_DATE": _str_to_date, + "TO_DAYS": lambda args: exp.paren( + exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")), + unit=exp.var("DAY"), + ) + + 1 + ), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "WEEK": lambda args: exp.Week( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) + ), + "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))), } FUNCTION_PARSERS = { **parser.Parser.FUNCTION_PARSERS, + "CHAR": lambda self: self._parse_chr(), "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -531,6 +579,18 @@ class MySQL(Dialect): return super()._parse_type(parse_interval=parse_interval) + def _parse_chr(self) -> t.Optional[exp.Expression]: + expressions = self._parse_csv(self._parse_conjunction) + kwargs: t.Dict[str, t.Any] = {"this": seq_get(expressions, 0)} + + if len(expressions) > 1: + kwargs["expressions"] = expressions[1:] + + if self._match(TokenType.USING): + kwargs["charset"] = self._parse_var() + + return self.expression(exp.Chr, **kwargs) + class Generator(generator.Generator): LOCKING_READS_SUPPORTED = True NULL_ORDERING_SUPPORTED = False @@ -544,25 +604,33 @@ class MySQL(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, exp.CurrentDate: no_paren_current_date_sql, - exp.DateDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), - exp.DateAdd: _date_add_sql("ADD"), + exp.DateDiff: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression") + ), + exp.DateAdd: _remove_ts_or_ds_to_date(_date_add_sql("ADD")), exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("SUB"), + exp.DateSub: _remove_ts_or_ds_to_date(_date_add_sql("SUB")), exp.DateTrunc: _date_trunc_sql, - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfYear: rename_func("DAYOFYEAR"), + exp.Day: _remove_ts_or_ds_to_date(), + exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), + exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), + exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", exp.ILike: no_ilike_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONKeyValue: json_keyvalue_comma_sql, exp.Max: max_or_greatest, exp.Min: min_or_least, + exp.Month: _remove_ts_or_ds_to_date(), exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")), exp.Pivot: no_pivot_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] ), exp.StrPosition: strposition_to_locate_sql, exp.StrToDate: _str_to_date_sql, @@ -573,10 +641,16 @@ class MySQL(Dialect): exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)), - exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), + exp.TimeToStr: _remove_ts_or_ds_to_date( + lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)) + ), exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, - exp.WeekOfYear: rename_func("WEEKOFYEAR"), + exp.TsOrDsAdd: _date_add_sql("ADD"), + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.Week: _remove_ts_or_ds_to_date(), + exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), + exp.Year: _remove_ts_or_ds_to_date(), } UNSIGNED_TYPE_MAPPING = { @@ -585,6 +659,7 @@ class MySQL(Dialect): exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", exp.DataType.Type.USMALLINT: "SMALLINT", exp.DataType.Type.UTINYINT: "TINYINT", + exp.DataType.Type.UDECIMAL: "DECIMAL", } TIMESTAMP_TYPE_MAPPING = { @@ -717,3 +792,9 @@ class MySQL(Dialect): limit_offset = f"{offset}, {limit}" if offset else limit return f" LIMIT {limit_offset}" return "" + + def chr_sql(self, expression: exp.Chr) -> str: + this = self.expressions(sqls=[expression.this] + expression.expressions) + charset = expression.args.get("charset") + using = f" USING {self.sql(charset)}" if charset else "" + return f"CHAR({this}{using})" diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 0a4926d..6a007ab 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -153,6 +153,7 @@ class Oracle(Dialect): JOIN_HINTS = False TABLE_HINTS = False COLUMN_JOIN_MARKS_SUPPORTED = True + DATA_TYPE_SPECIFIERS_ALLOWED = True LIMIT_FETCH = "FETCH" @@ -179,7 +180,12 @@ class Oracle(Dialect): ), exp.Group: transforms.preprocess([transforms.unalias_group]), exp.ILike: no_ilike_sql, - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_distinct_on, + transforms.eliminate_qualify, + ] + ), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), exp.Substring: rename_func("SUBSTR"), diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 342fd95..008727c 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -22,6 +22,7 @@ from sqlglot.dialects.dialect import ( rename_func, simplify_literal, str_position_sql, + struct_extract_sql, timestamptrunc_sql, timestrtotime_sql, trim_sql, @@ -248,11 +249,10 @@ class Postgres(Dialect): } class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", "$$"] - BIT_STRINGS = [("b'", "'"), ("B'", "'")] HEX_STRINGS = [("x'", "'"), ("X'", "'")] BYTE_STRINGS = [("e'", "'"), ("E'", "'")] + HEREDOC_STRINGS = ["$"] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -296,7 +296,7 @@ class Postgres(Dialect): SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, + "$": TokenType.HEREDOC_STRING, } VAR_SINGLE_TOKENS = {"$"} @@ -420,9 +420,15 @@ class Postgres(Dialect): exp.Pow: lambda self, e: self.binary(e, "^"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), - exp.Select: transforms.preprocess([transforms.eliminate_semi_and_anti_joins]), + exp.Select: transforms.preprocess( + [ + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] + ), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.StructExtract: struct_extract_sql, exp.Substring: _substring_sql, exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 0d8d4ab..e5cfa1c 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -309,6 +309,9 @@ class Presto(Dialect): exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", exp.First: _first_last_sql, exp.Group: transforms.preprocess([transforms.unalias_group]), + exp.GroupConcat: lambda self, e: self.func( + "ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator") + ), exp.Hex: rename_func("TO_HEX"), exp.If: if_sql(), exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 2145844..88e4448 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -83,7 +83,7 @@ class Redshift(Postgres): class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] - STRING_ESCAPES = ["\\"] + STRING_ESCAPES = ["\\", "'"] KEYWORDS = { **Postgres.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 5c49331..fc3e0fa 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -239,6 +239,8 @@ class Snowflake(Dialect): class Parser(parser.Parser): IDENTIFY_PIVOT_STRINGS = True + TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} + FUNCTIONS = { **parser.Parser.FUNCTIONS, "ARRAYAGG": exp.ArrayAgg.from_arg_list, @@ -318,6 +320,43 @@ class Snowflake(Dialect): "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), } + STAGED_FILE_SINGLE_TOKENS = { + TokenType.DOT, + TokenType.MOD, + TokenType.SLASH, + } + + def _parse_table_parts(self, schema: bool = False) -> exp.Table: + # https://docs.snowflake.com/en/user-guide/querying-stage + table: t.Optional[exp.Expression] = None + if self._match_text_seq("@"): + table_name = "@" + while True: + self._advance() + table_name += self._prev.text + if not self._match_set(self.STAGED_FILE_SINGLE_TOKENS, advance=False): + break + while self._match_set(self.STAGED_FILE_SINGLE_TOKENS): + table_name += self._prev.text + + table = exp.var(table_name) + elif self._match(TokenType.STRING, advance=False): + table = self._parse_string() + + if table: + file_format = None + pattern = None + + if self._match_text_seq("(", "FILE_FORMAT", "=>"): + file_format = self._parse_string() or super()._parse_table_parts() + if self._match_text_seq(",", "PATTERN", "=>"): + pattern = self._parse_string() + self._match_r_paren() + + return self.expression(exp.Table, this=table, format=file_format, pattern=pattern) + + return super()._parse_table_parts(schema=schema) + def _parse_id_var( self, any_token: bool = True, @@ -394,6 +433,8 @@ class Snowflake(Dialect): TABLE_HINTS = False QUERY_HINTS = False AGGREGATE_FILTER_SUPPORTED = False + SUPPORTS_TABLE_COPY = False + COLLATE_IS_FUNC = True TRANSFORMS = { **generator.Generator.TRANSFORMS, @@ -423,6 +464,12 @@ class Snowflake(Dialect): exp.Max: max_or_greatest, exp.Min: min_or_least, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", + exp.PercentileCont: transforms.preprocess( + [transforms.add_within_group_for_percentiles] + ), + exp.PercentileDisc: transforms.preprocess( + [transforms.add_within_group_for_percentiles] + ), exp.RegexpILike: _regexpilike_sql, exp.Select: transforms.preprocess( [ diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 9d4a1ab..2eaa2ae 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -54,6 +54,14 @@ class Spark(Spark2): FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy() FUNCTION_PARSERS.pop("ANY_VALUE") + def _parse_generated_as_identity( + self, + ) -> exp.GeneratedAsIdentityColumnConstraint | exp.ComputedColumnConstraint: + this = super()._parse_generated_as_identity() + if this.expression: + return self.expression(exp.ComputedColumnConstraint, this=this.expression) + return this + class Generator(Spark2.Generator): TYPE_MAPPING = { **Spark2.Generator.TYPE_MAPPING, @@ -73,6 +81,9 @@ class Spark(Spark2): TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) + def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: + return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})" + def anyvalue_sql(self, expression: exp.AnyValue) -> str: return self.function_fallback_sql(expression) diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 3dc9838..4130375 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -5,7 +5,6 @@ import typing as t from sqlglot import exp, transforms from sqlglot.dialects.dialect import ( binary_from_function, - create_with_partitions_sql, format_time_lambda, is_parse_json, move_insert_cte_sql, @@ -17,22 +16,6 @@ from sqlglot.dialects.hive import Hive from sqlglot.helper import seq_get -def _create_sql(self: Spark2.Generator, e: exp.Create) -> str: - kind = e.args["kind"] - properties = e.args.get("properties") - - if ( - kind.upper() == "TABLE" - and e.expression - and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - ): - return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" - return create_with_partitions_sql(self, e) - - def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: keys = expression.args.get("keys") values = expression.args.get("values") @@ -118,6 +101,8 @@ def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: class Spark2(Hive): class Parser(Hive.Parser): + TRIM_PATTERN_FIRST = True + FUNCTIONS = { **Hive.Parser.FUNCTIONS, "AGGREGATE": exp.Reduce.from_arg_list, @@ -192,7 +177,6 @@ class Spark2(Hive): exp.AtTimeZone: lambda self, e: f"FROM_UTC_TIMESTAMP({self.sql(e, 'this')}, {self.sql(e, 'zone')})", exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.Create: _create_sql, exp.DateFromParts: rename_func("MAKE_DATE"), exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")), exp.DayOfMonth: rename_func("DAYOFMONTH"), @@ -236,6 +220,12 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False + def temporary_storage_provider(self, expression: exp.Create) -> exp.Create: + # spark2, spark, Databricks require a storage provider for temporary tables + provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) + expression.args["properties"].append("expressions", provider) + return expression + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if is_parse_json(expression.this): schema = f"'{self.sql(expression, 'to')}'" diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index fa62e78..6aa49e4 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -14,6 +14,7 @@ from sqlglot.dialects.dialect import ( parse_date_delta, rename_func, timestrtotime_sql, + ts_or_ds_to_date_sql, ) from sqlglot.expressions import DataType from sqlglot.helper import seq_get @@ -590,6 +591,7 @@ class TSQL(Dialect): NVL2_SUPPORTED = False ALTER_TABLE_ADD_COLUMN_KEYWORD = False LIMIT_FETCH = "FETCH" + COMPUTED_COLUMN_WITH_TYPE = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -619,7 +621,11 @@ class TSQL(Dialect): exp.Min: min_or_least, exp.NumberToStr: _format_sql, exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] + [ + transforms.eliminate_distinct_on, + transforms.eliminate_semi_and_anti_joins, + transforms.eliminate_qualify, + ] ), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( @@ -630,6 +636,7 @@ class TSQL(Dialect): exp.TemporaryProperty: lambda self, e: "", exp.TimeStrToTime: timestrtotime_sql, exp.TimeToStr: _format_sql, + exp.TsOrDsToDate: ts_or_ds_to_date_sql("tsql"), } TRANSFORMS.pop(exp.ReturnsProperty) |