From 9b39dac84e82bf473216939e50b8836170f01d23 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 29 Jun 2023 15:02:29 +0200 Subject: Merging upstream version 16.7.3. Signed-off-by: Daniel Baumann --- sqlglot/dialects/bigquery.py | 205 +++++++++++++++++++++++++++++++++++++++--- sqlglot/dialects/dialect.py | 5 ++ sqlglot/dialects/mysql.py | 15 ++++ sqlglot/dialects/postgres.py | 10 +++ sqlglot/dialects/presto.py | 13 ++- sqlglot/dialects/redshift.py | 8 +- sqlglot/dialects/snowflake.py | 31 ++++--- sqlglot/dialects/spark.py | 1 + sqlglot/dialects/spark2.py | 4 +- sqlglot/dialects/sqlite.py | 4 +- sqlglot/dialects/tsql.py | 1 + 11 files changed, 269 insertions(+), 28 deletions(-) (limited to 'sqlglot/dialects') diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 52d4a88..8786063 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import re import typing as t @@ -21,6 +22,8 @@ from sqlglot.dialects.dialect import ( from sqlglot.helper import seq_get, split_num_words from sqlglot.tokens import TokenType +logger = logging.getLogger("sqlglot") + def _date_add_sql( data_type: str, kind: str @@ -104,12 +107,70 @@ def _unqualify_unnest(expression: exp.Expression) -> exp.Expression: return expression +# https://issuetracker.google.com/issues/162294746 +# workaround for bigquery bug when grouping by an expression and then ordering +# WITH x AS (SELECT 1 y) +# SELECT y + 1 z +# FROM x +# GROUP BY x + 1 +# ORDER by z +def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Select): + group = expression.args.get("group") + order = expression.args.get("order") + + if group and order: + aliases = { + select.this: select.args["alias"] + for select in expression.selects + if isinstance(select, exp.Alias) + } + + for e in group.expressions: + alias = aliases.get(e) + + if alias: + e.replace(exp.column(alias)) + + return expression + + +def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: + """BigQuery doesn't allow column names when defining a CTE, so we try to push them down.""" + if isinstance(expression, exp.CTE) and expression.alias_column_names: + cte_query = expression.this + + if cte_query.is_star: + logger.warning( + "Can't push down CTE column names for star queries. Run the query through" + " the optimizer or use 'qualify' to expand the star projections first." + ) + return expression + + column_names = expression.alias_column_names + expression.args["alias"].set("columns", None) + + for name, select in zip(column_names, cte_query.selects): + to_replace = select + + if isinstance(select, exp.Alias): + select = select.this + + # Inner aliases are shadowed by the CTE column names + to_replace.replace(exp.alias_(select, name)) + + return expression + + class BigQuery(Dialect): UNNEST_COLUMN_ONLY = True # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity RESOLVES_IDENTIFIERS_AS_UPPERCASE = None + # bigquery udfs are case sensitive + NORMALIZE_FUNCTIONS = False + TIME_MAPPING = { "%D": "%m/%d/%y", } @@ -135,12 +196,16 @@ class BigQuery(Dialect): # In BigQuery, CTEs aren't case-sensitive, but table names are (by default, at least). # The following check is essentially a heuristic to detect tables based on whether or # not they're qualified. - if ( - isinstance(expression, exp.Identifier) - and not (isinstance(expression.parent, exp.Table) and expression.parent.db) - and not expression.meta.get("is_table") - ): - expression.set("this", expression.this.lower()) + if isinstance(expression, exp.Identifier): + parent = expression.parent + + while isinstance(parent, exp.Dot): + parent = parent.parent + + if not (isinstance(parent, exp.Table) and parent.db) and not expression.meta.get( + "is_table" + ): + expression.set("this", expression.this.lower()) return expression @@ -298,10 +363,8 @@ class BigQuery(Dialect): **generator.Generator.TRANSFORMS, exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySize: rename_func("ARRAY_LENGTH"), - exp.AtTimeZone: lambda self, e: self.func( - "TIMESTAMP", self.func("DATETIME", e.this, e.args.get("zone")) - ), exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), + exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), @@ -325,7 +388,12 @@ class BigQuery(Dialect): ), exp.RegexpLike: rename_func("REGEXP_CONTAINS"), exp.Select: transforms.preprocess( - [_unqualify_unnest, transforms.eliminate_distinct_on] + [ + transforms.explode_to_unnest, + _unqualify_unnest, + transforms.eliminate_distinct_on, + _alias_ordered_group, + ] ), exp.StrToDate: lambda self, e: f"PARSE_DATE({self.format_time(e)}, {self.sql(e, 'this')})", exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", @@ -334,7 +402,6 @@ class BigQuery(Dialect): exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), exp.TimeStrToTime: timestrtotime_sql, - exp.TryCast: lambda self, e: f"SAFE_CAST({self.sql(e, 'this')} AS {self.sql(e, 'to')})", exp.TsOrDsToDate: ts_or_ds_to_date_sql("bigquery"), exp.TsOrDsAdd: _date_add_sql("DATE", "ADD"), exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", @@ -378,7 +445,121 @@ class BigQuery(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - RESERVED_KEYWORDS = {*generator.Generator.RESERVED_KEYWORDS, "hash"} + # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords + RESERVED_KEYWORDS = { + *generator.Generator.RESERVED_KEYWORDS, + "all", + "and", + "any", + "array", + "as", + "asc", + "assert_rows_modified", + "at", + "between", + "by", + "case", + "cast", + "collate", + "contains", + "create", + "cross", + "cube", + "current", + "default", + "define", + "desc", + "distinct", + "else", + "end", + "enum", + "escape", + "except", + "exclude", + "exists", + "extract", + "false", + "fetch", + "following", + "for", + "from", + "full", + "group", + "grouping", + "groups", + "hash", + "having", + "if", + "ignore", + "in", + "inner", + "intersect", + "interval", + "into", + "is", + "join", + "lateral", + "left", + "like", + "limit", + "lookup", + "merge", + "natural", + "new", + "no", + "not", + "null", + "nulls", + "of", + "on", + "or", + "order", + "outer", + "over", + "partition", + "preceding", + "proto", + "qualify", + "range", + "recursive", + "respect", + "right", + "rollup", + "rows", + "select", + "set", + "some", + "struct", + "tablesample", + "then", + "to", + "treat", + "true", + "unbounded", + "union", + "unnest", + "using", + "when", + "where", + "window", + "with", + "within", + } + + def attimezone_sql(self, expression: exp.AtTimeZone) -> str: + if not isinstance(expression.parent, exp.Cast): + return self.func( + "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone")) + ) + return super().attimezone_sql(expression) + + def trycast_sql(self, expression: exp.TryCast) -> str: + return self.cast_sql(expression, safe_prefix="SAFE_") + + def cte_sql(self, expression: exp.CTE) -> str: + if expression.alias_column_names: + self.unsupported("Column names in CTE definition are not supported.") + return super().cte_sql(expression) def array_sql(self, expression: exp.Array) -> str: first_arg = seq_get(expression.expressions, 0) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 0e25b9b..d258826 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -388,6 +388,11 @@ def no_comment_column_constraint_sql( return "" +def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: + self.unsupported("MAP_FROM_ENTRIES unsupported") + return "" + + def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: this = self.sql(expression, "this") substr = self.sql(expression, "substr") diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 1dd2096..5f743ee 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -132,6 +132,10 @@ class MySQL(Dialect): "SEPARATOR": TokenType.SEPARATOR, "ENUM": TokenType.ENUM, "START": TokenType.BEGIN, + "SIGNED": TokenType.BIGINT, + "SIGNED INTEGER": TokenType.BIGINT, + "UNSIGNED": TokenType.UBIGINT, + "UNSIGNED INTEGER": TokenType.UBIGINT, "_ARMSCII8": TokenType.INTRODUCER, "_ASCII": TokenType.INTRODUCER, "_BIG5": TokenType.INTRODUCER, @@ -441,6 +445,17 @@ class MySQL(Dialect): LIMIT_FETCH = "LIMIT" + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: + """(U)BIGINT is not allowed in a CAST expression, so we use (UN)SIGNED instead.""" + if expression.to.this == exp.DataType.Type.BIGINT: + to = "SIGNED" + elif expression.to.this == exp.DataType.Type.UBIGINT: + to = "UNSIGNED" + else: + return super().cast_sql(expression) + + return f"CAST({self.sql(expression, 'this')} AS {to})" + def show_sql(self, expression: exp.Show) -> str: this = f" {expression.name}" full = " FULL" if expression.args.get("full") else "" diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 8c2a4ab..766b584 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( format_time_lambda, max_or_greatest, min_or_least, + no_map_from_entries_sql, no_paren_current_date_sql, no_pivot_sql, no_tablesample_sql, @@ -346,6 +347,7 @@ class Postgres(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Max: max_or_greatest, + exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), exp.ArrayContains: lambda self, e: self.binary(e, "@>"), @@ -378,3 +380,11 @@ class Postgres(Dialect): exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + + def bracket_sql(self, expression: exp.Bracket) -> str: + """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" + if isinstance(expression.this, exp.Array): + expression = expression.copy() + expression.set("this", exp.paren(expression.this, copy=False)) + + return super().bracket_sql(expression) diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 265780e..24c439b 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -20,7 +20,7 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.dialects.mysql import MySQL from sqlglot.errors import UnsupportedError -from sqlglot.helper import seq_get +from sqlglot.helper import apply_index_offset, seq_get from sqlglot.tokens import TokenType @@ -154,6 +154,13 @@ def _from_unixtime(args: t.List) -> exp.Expression: return exp.UnixToTime.from_arg_list(args) +def _parse_element_at(args: t.List) -> exp.SafeBracket: + this = seq_get(args, 0) + index = seq_get(args, 1) + assert isinstance(this, exp.Expression) and isinstance(index, exp.Expression) + return exp.SafeBracket(this=this, expressions=apply_index_offset(this, [index], -1)) + + def _unnest_sequence(expression: exp.Expression) -> exp.Expression: if isinstance(expression, exp.Table): if isinstance(expression.this, exp.GenerateSeries): @@ -201,6 +208,7 @@ class Presto(Dialect): "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), "DATE_TRUNC": date_trunc_to_time, + "ELEMENT_AT": _parse_element_at, "FROM_HEX": exp.Unhex.from_arg_list, "FROM_UNIXTIME": _from_unixtime, "FROM_UTF8": lambda args: exp.Decode( @@ -285,6 +293,9 @@ class Presto(Dialect): exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, exp.Right: right_to_substring_sql, + exp.SafeBracket: lambda self, e: self.func( + "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) + ), exp.SafeDivide: no_safe_divide_sql, exp.Schema: _schema_sql, exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index db6cc3f..87be42c 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -41,8 +41,6 @@ class Redshift(Postgres): "STRTOL": exp.FromBase.from_arg_list, } - CONVERT_TYPE_FIRST = True - def _parse_types( self, check_func: bool = False, schema: bool = False ) -> t.Optional[exp.Expression]: @@ -58,6 +56,12 @@ class Redshift(Postgres): return this + def _parse_convert(self, strict: bool) -> t.Optional[exp.Expression]: + to = self._parse_types() + self._match(TokenType.COMMA) + this = self._parse_bitwise() + return self.expression(exp.TryCast, this=this, to=to) + class Tokenizer(Postgres.Tokenizer): BIT_STRINGS = [] HEX_STRINGS = [] diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 1f620df..a2dbfd9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -258,14 +258,29 @@ class Snowflake(Dialect): ALTER_PARSERS = { **parser.Parser.ALTER_PARSERS, - "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True), - "SET": lambda self: self._parse_alter_table_set_tag(), + "SET": lambda self: self._parse_set(tag=self._match_text_seq("TAG")), + "UNSET": lambda self: self.expression( + exp.Set, + tag=self._match_text_seq("TAG"), + expressions=self._parse_csv(self._parse_id_var), + unset=True, + ), } - def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: - self._match_text_seq("TAG") - parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction) - return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset) + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: + if self._match_text_seq("IDENTIFIER", "("): + identifier = ( + super()._parse_id_var(any_token=any_token, tokens=tokens) + or self._parse_string() + ) + self._match_r_paren() + return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier]) + + return super()._parse_id_var(any_token=any_token, tokens=tokens) class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] @@ -380,10 +395,6 @@ class Snowflake(Dialect): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) - def settag_sql(self, expression: exp.SetTag) -> str: - action = "UNSET" if expression.args.get("unset") else "SET" - return f"{action} TAG {self.expressions(expression)}" - def describe_sql(self, expression: exp.Describe) -> str: # Default to table if kind is unknown kind_value = expression.args.get("kind") or "TABLE" diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index b7d1641..7a7ee01 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -43,6 +43,7 @@ class Spark(Spark2): class Generator(Spark2.Generator): TRANSFORMS = Spark2.Generator.TRANSFORMS.copy() TRANSFORMS.pop(exp.DateDiff) + TRANSFORMS.pop(exp.Group) def datediff_sql(self, expression: exp.DateDiff) -> str: unit = self.sql(expression, "unit") diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index ed6992d..3720b8d 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -231,14 +231,14 @@ class Spark2(Hive): WRAP_DERIVED_VALUES = False CREATE_FUNCTION_RETURN_AS = False - def cast_sql(self, expression: exp.Cast) -> str: + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if isinstance(expression.this, exp.Cast) and expression.this.is_type("json"): schema = f"'{self.sql(expression, 'to')}'" return self.func("FROM_JSON", expression.this.this, schema) if expression.is_type("json"): return self.func("TO_JSON", expression.this) - return super(Hive.Generator, self).cast_sql(expression) + return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix) def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: return super().columndef_sql( diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index 803f361..519e62a 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, @@ -133,7 +135,7 @@ class SQLite(Dialect): LIMIT_FETCH = "LIMIT" - def cast_sql(self, expression: exp.Cast) -> str: + def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: if expression.is_type("date"): return self.func("DATE", expression.this) diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 6d674f5..f671630 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -166,6 +166,7 @@ def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> s class TSQL(Dialect): + RESOLVES_IDENTIFIERS_AS_UPPERCASE = None NULL_ORDERING = "nulls_are_small" TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" -- cgit v1.2.3