From cf49728f719975144a958f23ba5f3336fb81ae55 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 27 Apr 2024 04:50:25 +0200 Subject: Merging upstream version 23.12.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 34 ++++++++- sqlglot/dialects/clickhouse.py | 100 ++++++++++++++++--------- sqlglot/dialects/dialect.py | 2 +- sqlglot/dialects/duckdb.py | 33 +++++---- sqlglot/dialects/hive.py | 1 + sqlglot/dialects/mysql.py | 1 + sqlglot/dialects/oracle.py | 1 + sqlglot/dialects/postgres.py | 1 + sqlglot/dialects/presto.py | 61 ++++++++++++++++ sqlglot/dialects/prql.py | 15 ++++ sqlglot/dialects/redshift.py | 162 ++++++++++++++++++++++++++++++++++++++++- sqlglot/dialects/snowflake.py | 6 +- sqlglot/dialects/tsql.py | 8 ++ 13 files changed, 371 insertions(+), 54 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index a7b4895..3d2fb3d 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -222,7 +222,6 @@ class BigQuery(Dialect): # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time TIME_MAPPING = { "%D": "%m/%d/%y", - "%E*S": "%S.%f", "%E6S": "%S.%f", } @@ -474,11 +473,31 @@ class BigQuery(Dialect): if rest and this: this = exp.Dot.build([this, *rest]) # type: ignore - table = exp.Table(this=this, db=db, catalog=catalog) + table = exp.Table( + this=this, db=db, catalog=catalog, pivots=table.args.get("pivots") + ) table.meta["quoted_table"] = True return table + def _parse_column(self) -> t.Optional[exp.Expression]: + column = super()._parse_column() + if isinstance(column, exp.Column): + parts = column.parts + if any("." in p.name for p in parts): + catalog, db, table, this, *rest = ( + exp.to_identifier(p, quoted=True) + for p in split_num_words(".".join(p.name for p in parts), ".", 4) + ) + + if rest and this: + this = exp.Dot.build([this, *rest]) # type: ignore + + column = exp.Column(this=this, table=table, db=db, catalog=catalog) + column.meta["quoted_column"] = True + + return column + @t.overload def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... @@ -670,6 +689,7 @@ class BigQuery(Dialect): exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.VARBINARY: "BYTES", + exp.DataType.Type.ROWVERSION: "BYTES", exp.DataType.Type.VARCHAR: "STRING", exp.DataType.Type.VARIANT: "ANY TYPE", } @@ -781,6 +801,16 @@ class BigQuery(Dialect): "within", } + def column_parts(self, expression: exp.Column) -> str: + if expression.meta.get("quoted_column"): + # If a column reference is of the form `dataset.table`.name, we need + # to preserve the quoted table path, otherwise the reference breaks + table_parts = ".".join(p.name for p in expression.parts[:-1]) + table_path = self.sql(exp.Identifier(this=table_parts, quoted=True)) + return f"{table_path}.{self.sql(expression, 'this')}" + + return super().column_parts(expression) + def table_parts(self, expression: exp.Table) -> str: # Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so # we need to make sure the correct quoting is used in each case. diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 34ee529..67e28d0 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -6,6 +6,7 @@ from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, arg_max_or_min_no_count, + build_formatted_time, date_delta_sql, inline_array_sql, json_extract_segments, @@ -19,6 +20,16 @@ from sqlglot.helper import is_int, seq_get from sqlglot.tokens import Token, TokenType +def _build_date_format(args: t.List) -> exp.TimeToStr: + expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args) + + timezone = seq_get(args, 2) + if timezone: + expr.set("timezone", timezone) + + return expr + + def _lower_func(sql: str) -> str: index = sql.index("(") return sql[:index].lower() + sql[index:] @@ -124,6 +135,8 @@ class ClickHouse(Dialect): "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) ), + "DATE_FORMAT": _build_date_format, + "FORMATDATETIME": _build_date_format, "JSONEXTRACTSTRING": build_json_extract_path( exp.JSONExtractScalar, zero_based_indexing=False ), @@ -241,6 +254,14 @@ class ClickHouse(Dialect): "sparkBar", "sumCount", "largestTriangleThreeBuckets", + "histogram", + "sequenceMatch", + "sequenceCount", + "windowFunnel", + "retention", + "uniqUpTo", + "sequenceNextNode", + "exponentialTimeDecayedAvg", } AGG_FUNCTIONS_SUFFIXES = [ @@ -383,6 +404,7 @@ class ClickHouse(Dialect): alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, is_db_reference: bool = False, + parse_partition: bool = False, ) -> t.Optional[exp.Expression]: this = super()._parse_table( schema=schema, @@ -447,46 +469,53 @@ class ClickHouse(Dialect): functions: t.Optional[t.Dict[str, t.Callable]] = None, anonymous: bool = False, optional_parens: bool = True, + any_token: bool = False, ) -> t.Optional[exp.Expression]: - func = super()._parse_function( - functions=functions, anonymous=anonymous, optional_parens=optional_parens + expr = super()._parse_function( + functions=functions, + anonymous=anonymous, + optional_parens=optional_parens, + any_token=any_token, + ) + + func = expr.this if isinstance(expr, exp.Window) else expr + + # Aggregate functions can be split in 2 parts: + parts = ( + self.AGG_FUNC_MAPPING.get(func.this) if isinstance(func, exp.Anonymous) else None ) - if isinstance(func, exp.Anonymous): - parts = self.AGG_FUNC_MAPPING.get(func.this) + if parts: params = self._parse_func_params(func) + kwargs = { + "this": func.this, + "expressions": func.expressions, + } + if parts[1]: + kwargs["parts"] = parts + exp_class = exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc + else: + exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc + + kwargs["exp_class"] = exp_class if params: - if parts and parts[1]: - return self.expression( - exp.CombinedParameterizedAgg, - this=func.this, - expressions=func.expressions, - params=params, - parts=parts, - ) - return self.expression( - exp.ParameterizedAgg, - this=func.this, - expressions=func.expressions, - params=params, - ) - - if parts: - if parts[1]: - return self.expression( - exp.CombinedAggFunc, - this=func.this, - expressions=func.expressions, - parts=parts, - ) - return self.expression( - exp.AnonymousAggFunc, - this=func.this, - expressions=func.expressions, - ) - - return func + kwargs["params"] = params + + func = self.expression(**kwargs) + + if isinstance(expr, exp.Window): + # The window's func was parsed as Anonymous in base parser, fix its + # type to be CH style CombinedAnonymousAggFunc / AnonymousAggFunc + expr.set("this", func) + elif params: + # Params have blocked super()._parse_function() from parsing the following window + # (if that exists) as they're standing between the function call and the window spec + expr = self._parse_window(func) + else: + expr = func + + return expr def _parse_func_params( self, this: t.Optional[exp.Func] = None @@ -653,6 +682,9 @@ class ClickHouse(Dialect): exp.StrPosition: lambda self, e: self.func( "position", e.this, e.args.get("substr"), e.args.get("position") ), + exp.TimeToStr: lambda self, e: self.func( + "DATE_FORMAT", e.this, self.format_time(e), e.args.get("timezone") + ), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 5a47438..0fd1a47 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -568,7 +568,7 @@ def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> st def inline_array_sql(self: Generator, expression: exp.Array) -> str: - return f"[{self.expressions(expression, flat=True)}]" + return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 6486dda..9f54826 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -28,7 +28,7 @@ from sqlglot.dialects.dialect import ( timestrtotime_sql, unit_to_var, ) -from sqlglot.helper import flatten, seq_get +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -155,16 +155,6 @@ def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))) -def _rename_unless_within_group( - a: str, b: str -) -> t.Callable[[DuckDB.Generator, exp.Expression], str]: - return lambda self, expression: ( - self.func(a, *flatten(expression.args.values())) - if isinstance(expression.find_ancestor(exp.Select, exp.WithinGroup), exp.WithinGroup) - else self.func(b, *flatten(expression.args.values())) - ) - - class DuckDB(Dialect): NULL_ORDERING = "nulls_are_last" SUPPORTS_USER_DEFINED_TYPES = False @@ -425,8 +415,8 @@ class DuckDB(Dialect): exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True), ), exp.ParseJSON: rename_func("JSON"), - exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"), - exp.PercentileDisc: _rename_unless_within_group("PERCENTILE_DISC", "QUANTILE_DISC"), + exp.PercentileCont: rename_func("QUANTILE_CONT"), + exp.PercentileDisc: rename_func("QUANTILE_DISC"), # DuckDB doesn't allow qualified columns inside of PIVOT expressions. # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), @@ -499,6 +489,7 @@ class DuckDB(Dialect): exp.DataType.Type.NVARCHAR: "TEXT", exp.DataType.Type.UINT: "UINTEGER", exp.DataType.Type.VARBINARY: "BLOB", + exp.DataType.Type.ROWVERSION: "BLOB", exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S", exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS", @@ -616,3 +607,19 @@ class DuckDB(Dialect): bracket = f"({bracket})[1]" return bracket + + def withingroup_sql(self, expression: exp.WithinGroup) -> str: + expression_sql = self.sql(expression, "expression") + + func = expression.this + if isinstance(func, exp.PERCENTILES): + # Make the order key the first arg and slide the fraction to the right + # https://duckdb.org/docs/sql/aggregates#ordered-set-aggregate-functions + order_col = expression.find(exp.Ordered) + if order_col: + func.set("expression", func.this) + func.set("this", order_col.this) + + this = self.sql(expression, "this").rstrip(")") + + return f"{this}{expression_sql})" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index cc7debb..d86691e 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -457,6 +457,7 @@ class Hive(Dialect): exp.DataType.Type.TIME: "TIMESTAMP", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", + exp.DataType.Type.ROWVERSION: "BINARY", } TRANSFORMS = { diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 611a155..03576d2 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -443,6 +443,7 @@ class MySQL(Dialect): LOG_DEFAULTS_TO_LN = True STRING_ALIASES = True VALUES_FOLLOWED_BY_PAREN = False + SUPPORTS_PARTITION_SELECTION = True def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index e038400..13f86ac 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -222,6 +222,7 @@ class Oracle(Dialect): exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.VARBINARY: "BLOB", + exp.DataType.Type.ROWVERSION: "BLOB", } TRANSFORMS = { diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 7cbcc23..71339b8 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -431,6 +431,7 @@ class Postgres(Dialect): exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", exp.DataType.Type.VARBINARY: "BYTEA", + exp.DataType.Type.ROWVERSION: "BYTEA", exp.DataType.Type.DATETIME: "TIMESTAMP", } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 6c23bdf..1f02831 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -443,6 +443,67 @@ class Presto(Dialect): exp.Xor: bool_xor_sql, } + RESERVED_KEYWORDS = { + "alter", + "and", + "as", + "between", + "by", + "case", + "cast", + "constraint", + "create", + "cross", + "current_time", + "current_timestamp", + "deallocate", + "delete", + "describe", + "distinct", + "drop", + "else", + "end", + "escape", + "except", + "execute", + "exists", + "extract", + "false", + "for", + "from", + "full", + "group", + "having", + "in", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "left", + "like", + "natural", + "not", + "null", + "on", + "or", + "order", + "outer", + "prepare", + "right", + "select", + "table", + "then", + "true", + "union", + "using", + "values", + "when", + "where", + "with", + } + def strtounix_sql(self, expression: exp.StrToUnix) -> str: # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one. # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a diff --git a/sqlglot/dialects/prql.py b/sqlglot/dialects/prql.py index 3ee91a8..028c309 100644 --- a/sqlglot/dialects/prql.py +++ b/sqlglot/dialects/prql.py @@ -55,6 +55,20 @@ class PRQL(Dialect): "SORT": lambda self, query: self._parse_order_by(query), } + def _parse_equality(self) -> t.Optional[exp.Expression]: + eq = self._parse_tokens(self._parse_comparison, self.EQUALITY) + if not isinstance(eq, (exp.EQ, exp.NEQ)): + return eq + + # https://prql-lang.org/book/reference/spec/null.html + if isinstance(eq.expression, exp.Null): + is_exp = exp.Is(this=eq.this, expression=eq.expression) + return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp) + if isinstance(eq.this, exp.Null): + is_exp = exp.Is(this=eq.expression, expression=eq.this) + return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp) + return eq + def _parse_statement(self) -> t.Optional[exp.Expression]: expression = self._parse_expression() expression = expression if expression else self._parse_query() @@ -136,6 +150,7 @@ class PRQL(Dialect): alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, is_db_reference: bool = False, + parse_partition: bool = False, ) -> t.Optional[exp.Expression]: return self._parse_table_parts() diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 7a86c61..7b98ed4 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -79,6 +79,7 @@ class Redshift(Postgres): alias_tokens: t.Optional[t.Collection[TokenType]] = None, parse_bracket: bool = False, is_db_reference: bool = False, + parse_partition: bool = False, ) -> t.Optional[exp.Expression]: # Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr` unpivot = self._match(TokenType.UNPIVOT) @@ -145,6 +146,7 @@ class Redshift(Postgres): exp.DataType.Type.TIMETZ: "TIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.VARBINARY: "VARBYTE", + exp.DataType.Type.ROWVERSION: "VARBYTE", } TRANSFORMS = { @@ -196,7 +198,165 @@ class Redshift(Postgres): # Redshift supports LAST_DAY(..) TRANSFORMS.pop(exp.LastDay) - RESERVED_KEYWORDS = {*Postgres.Generator.RESERVED_KEYWORDS, "snapshot", "type"} + RESERVED_KEYWORDS = { + "aes128", + "aes256", + "all", + "allowoverwrite", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "authorization", + "az64", + "backup", + "between", + "binary", + "blanksasnull", + "both", + "bytedict", + "bzip2", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "credentials", + "cross", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "current_user_id", + "default", + "deferrable", + "deflate", + "defrag", + "delta", + "delta32k", + "desc", + "disable", + "distinct", + "do", + "else", + "emptyasnull", + "enable", + "encode", + "encrypt ", + "encryption", + "end", + "except", + "explicit", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "globaldict256", + "globaldict64k", + "grant", + "group", + "gzip", + "having", + "identity", + "ignore", + "ilike", + "in", + "initially", + "inner", + "intersect", + "interval", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "lun", + "luns", + "lzo", + "lzop", + "minus", + "mostly16", + "mostly32", + "mostly8", + "natural", + "new", + "not", + "notnull", + "null", + "nulls", + "off", + "offline", + "offset", + "oid", + "old", + "on", + "only", + "open", + "or", + "order", + "outer", + "overlaps", + "parallel", + "partition", + "percent", + "permissions", + "pivot", + "placing", + "primary", + "raw", + "readratio", + "recover", + "references", + "rejectlog", + "resort", + "respect", + "restore", + "right", + "select", + "session_user", + "similar", + "snapshot", + "some", + "sysdate", + "system", + "table", + "tag", + "tdes", + "text255", + "text32k", + "then", + "timestamp", + "to", + "top", + "trailing", + "true", + "truncatecolumns", + "type", + "union", + "unique", + "unnest", + "unpivot", + "user", + "using", + "verbose", + "wallet", + "when", + "where", + "with", + "without", + } def unnest_sql(self, expression: exp.Unnest) -> str: args = expression.expressions diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 41d5b65..dba56c4 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -33,10 +33,9 @@ def _build_datetime( ) -> t.Callable[[t.List], exp.Func]: def _builder(args: t.List) -> exp.Func: value = seq_get(args, 0) + int_value = value is not None and is_int(value.name) if isinstance(value, exp.Literal): - int_value = is_int(value.this) - # Converts calls like `TO_TIME('01:02:03')` into casts if len(args) == 1 and value.is_string and not int_value: return exp.cast(value, kind) @@ -49,7 +48,7 @@ def _build_datetime( if not is_float(value.this): return build_formatted_time(exp.StrToTime, "snowflake")(args) - if len(args) == 2 and kind == exp.DataType.Type.DATE: + if kind == exp.DataType.Type.DATE and not int_value: formatted_exp = build_formatted_time(exp.TsOrDsToDate, "snowflake")(args) formatted_exp.set("safe", safe) return formatted_exp @@ -749,6 +748,7 @@ class Snowflake(Dialect): TRANSFORMS = { **generator.Generator.TRANSFORMS, + exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArgMax: rename_func("MAX_BY"), exp.ArgMin: rename_func("MIN_BY"), exp.Array: inline_array_sql, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 6eed46d..a699f2b 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -464,6 +464,7 @@ class TSQL(Dialect): "SMALLMONEY": TokenType.SMALLMONEY, "SQL_VARIANT": TokenType.VARIANT, "TOP": TokenType.TOP, + "TIMESTAMP": TokenType.ROWVERSION, "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "UPDATE STATISTICS": TokenType.COMMAND, "XML": TokenType.XML, @@ -755,6 +756,7 @@ class TSQL(Dialect): exp.DataType.Type.TIMESTAMP: "DATETIME2", exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", exp.DataType.Type.VARIANT: "SQL_VARIANT", + exp.DataType.Type.ROWVERSION: "ROWVERSION", } TYPE_MAPPING.pop(exp.DataType.Type.NCHAR) @@ -1052,3 +1054,9 @@ class TSQL(Dialect): def partition_sql(self, expression: exp.Partition) -> str: return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))" + + def altertable_sql(self, expression: exp.AlterTable) -> str: + action = seq_get(expression.args.get("actions") or [], 0) + if isinstance(action, exp.RenameTable): + return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'" + return super().altertable_sql(expression) -- cgit v1.2.3