diff options
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/duckdb.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/spark.py | 3 | ||||
-rw-r--r-- | sqlglot/expressions.py | 11 | ||||
-rw-r--r-- | sqlglot/generator.py | 58 | ||||
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 16 | ||||
-rw-r--r-- | sqlglot/parser.py | 50 | ||||
-rw-r--r-- | sqlglot/schema.py | 23 | ||||
-rw-r--r-- | sqlglot/tokens.py | 20 |
10 files changed, 133 insertions, 54 deletions
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d7ba729..e61ac4f 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -333,6 +333,7 @@ class DuckDB(Dialect): IGNORE_NULLS_IN_FUNC = True JSON_PATH_BRACKETED_KEY_SUPPORTED = False SUPPORTS_CREATE_TABLE_LIKE = False + MULTI_ARG_DISTINCT = False TRANSFORMS = { **generator.Generator.TRANSFORMS, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 0404c78..68e2c6d 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -232,6 +232,9 @@ class Postgres(Dialect): BYTE_STRINGS = [("e'", "'"), ("E'", "'")] HEREDOC_STRINGS = ["$"] + HEREDOC_TAG_IS_IDENTIFIER = True + HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER + KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "~~": TokenType.LIKE, @@ -381,6 +384,7 @@ class Postgres(Dialect): JSON_TYPE_REQUIRED_FOR_EXTRACTION = True SUPPORTS_UNLOGGED_TABLES = True LIKE_PROPERTY_INSIDE_SCHEMA = True + MULTI_ARG_DISTINCT = False SUPPORTED_JSON_PATH_PARTS = { exp.JSONPathKey, diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 8691192..609103e 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -292,6 +292,7 @@ class Presto(Dialect): LIMIT_ONLY_LITERALS = True SUPPORTS_SINGLE_ARG_CONCAT = False LIKE_PROPERTY_INSIDE_SCHEMA = True + MULTI_ARG_DISTINCT = False PROPERTIES_LOCATION = { **generator.Generator.PROPERTIES_LOCATION, diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 4c5c131..44bd12d 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -50,9 +50,6 @@ class Spark(Spark2): "DATEDIFF": _parse_datediff, } - FUNCTION_PARSERS = Spark2.Parser.FUNCTION_PARSERS.copy() - FUNCTION_PARSERS.pop("ANY_VALUE") - def _parse_generated_as_identity( self, ) -> ( diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 3234c99..11ebbaf 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1796,7 +1796,7 
@@ class Lambda(Expression): class Limit(Expression): - arg_types = {"this": False, "expression": True, "offset": False} + arg_types = {"this": False, "expression": True, "offset": False, "expressions": False} class Literal(Condition): @@ -1969,7 +1969,7 @@ class Final(Expression): class Offset(Expression): - arg_types = {"this": False, "expression": True} + arg_types = {"this": False, "expression": True, "expressions": False} class Order(Expression): @@ -4291,6 +4291,11 @@ class RespectNulls(Expression): pass +# https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate-function-calls#max_min_clause +class HavingMax(Expression): + arg_types = {"this": True, "expression": True, "max": True} + + # Functions class Func(Condition): """ @@ -4491,7 +4496,7 @@ class Avg(AggFunc): class AnyValue(AggFunc): - arg_types = {"this": True, "having": False, "max": False} + pass class Lag(AggFunc): diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 568dcb4..318d782 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -296,6 +296,10 @@ class Generator(metaclass=_Generator): # Whether or not the LikeProperty needs to be specified inside of the schema clause LIKE_PROPERTY_INSIDE_SCHEMA = False + # Whether or not DISTINCT can be followed by multiple args in an AggFunc. 
If not, it will be + # transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple consisting of the args + MULTI_ARG_DISTINCT = True + # Whether or not the JSON extraction operators expect a value of type JSON JSON_TYPE_REQUIRED_FOR_EXTRACTION = False @@ -1841,15 +1845,18 @@ class Generator(metaclass=_Generator): args_sql = ", ".join(self.sql(e) for e in args) args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql - return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}" + expressions = self.expressions(expression, flat=True) + expressions = f" BY {expressions}" if expressions else "" + + return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{expressions}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") - expression = expression.expression - expression = ( - self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression - ) - return f"{this}{self.seg('OFFSET')} {self.sql(expression)}" + value = expression.expression + value = self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value + expressions = self.expressions(expression, flat=True) + expressions = f" BY {expressions}" if expressions else "" + return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}" def setitem_sql(self, expression: exp.SetItem) -> str: kind = self.sql(expression, "kind") @@ -2834,6 +2841,13 @@ class Generator(metaclass=_Generator): def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) + + if not self.MULTI_ARG_DISTINCT and len(expression.expressions) > 1: + case = exp.case() + for arg in expression.expressions: + case = case.when(arg.is_(exp.null()), exp.null()) + this = self.sql(case.else_(f"({this})")) + this = f" {this}" if this else "" on = self.sql(expression, "on") @@ -2846,13 +2860,33 @@ class Generator(metaclass=_Generator): def respectnulls_sql(self, expression: exp.RespectNulls) -> 
str: return self._embed_ignore_nulls(expression, "RESPECT NULLS") + def havingmax_sql(self, expression: exp.HavingMax) -> str: + this_sql = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + kind = "MAX" if expression.args.get("max") else "MIN" + return f"{this_sql} HAVING {kind} {expression_sql}" + def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: - if self.IGNORE_NULLS_IN_FUNC: - this = expression.find(exp.AggFunc) - if this: - sql = self.sql(this) - sql = sql[:-1] + f" {text})" - return sql + if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"): + # The first modifier here will be the one closest to the AggFunc's arg + mods = sorted( + expression.find_all(exp.HavingMax, exp.Order, exp.Limit), + key=lambda x: 0 + if isinstance(x, exp.HavingMax) + else (1 if isinstance(x, exp.Order) else 2), + ) + + if mods: + mod = mods[0] + this = expression.__class__(this=mod.this.copy()) + this.meta["inline"] = True + mod.this.replace(this) + return self.sql(expression.this) + + agg_func = expression.find(exp.AggFunc) + + if agg_func: + return self.sql(agg_func)[:-1] + f" {text})" return f"{self.sql(expression, 'this')} {text}" diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index a2a86cd..cb9312c 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Div: lambda self, e: self._annotate_div(e), + exp.Explode: lambda self, e: self._annotate_explode(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), @@ -333,9 +334,9 @@ class 
TypeAnnotator(metaclass=_TypeAnnotator): self._visited: t.Set[int] = set() def _set_type( - self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type + self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type] ) -> None: - expression.type = target_type # type: ignore + expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore self._visited.add(id(expression)) def annotate(self, expression: E) -> E: @@ -564,13 +565,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): if isinstance(bracket_arg, exp.Slice): self._set_type(expression, this.type) elif this.type.is_type(exp.DataType.Type.ARRAY): - contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN - self._set_type(expression, contained_type) + self._set_type(expression, seq_get(this.type.expressions, 0)) elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: index = this.keys.index(bracket_arg) value = seq_get(this.values, index) - value_type = value.type if value else exp.DataType.Type.UNKNOWN - self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) + self._set_type(expression, value.type if value else None) else: self._set_type(expression, exp.DataType.Type.UNKNOWN) @@ -591,3 +590,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, self._maybe_coerce(left_type, right_type)) return expression + + def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: + self._annotate_args(expression) + self._set_type(expression, seq_get(expression.this.type.expressions, 0)) + return expression diff --git a/sqlglot/parser.py b/sqlglot/parser.py index a89e4fa..dfa3024 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -872,7 +872,6 @@ class Parser(metaclass=_Parser): FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} FUNCTION_PARSERS = { - "ANY_VALUE": lambda self: self._parse_any_value(), "CAST": lambda self: self._parse_cast(self.STRICT_CAST), "CONVERT": lambda self: 
self._parse_convert(self.STRICT_CAST), "DECODE": lambda self: self._parse_decode(), @@ -2465,8 +2464,14 @@ class Parser(metaclass=_Parser): this.set(key, expression) if key == "limit": offset = expression.args.pop("offset", None) + if offset: - this.set("offset", exp.Offset(expression=offset)) + offset = exp.Offset(expression=offset) + this.set("offset", offset) + + limit_by_expressions = expression.expressions + expression.set("expressions", None) + offset.set("expressions", limit_by_expressions) continue break return this @@ -3341,7 +3346,12 @@ class Parser(metaclass=_Parser): offset = None limit_exp = self.expression( - exp.Limit, this=this, expression=expression, offset=offset, comments=comments + exp.Limit, + this=this, + expression=expression, + offset=offset, + comments=comments, + expressions=self._parse_limit_by(), ) return limit_exp @@ -3377,7 +3387,13 @@ class Parser(metaclass=_Parser): count = self._parse_term() self._match_set((TokenType.ROW, TokenType.ROWS)) - return self.expression(exp.Offset, this=this, expression=count) + + return self.expression( + exp.Offset, this=this, expression=count, expressions=self._parse_limit_by() + ) + + def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]: + return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise) def _parse_locks(self) -> t.List[exp.Lock]: locks = [] @@ -4115,7 +4131,9 @@ class Parser(metaclass=_Parser): else: this = self._parse_select_or_expression(alias=alias) - return self._parse_limit(self._parse_order(self._parse_respect_or_ignore_nulls(this))) + return self._parse_limit( + self._parse_order(self._parse_having_max(self._parse_respect_or_ignore_nulls(this))) + ) def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index @@ -4549,18 +4567,6 @@ class Parser(metaclass=_Parser): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - def _parse_any_value(self) -> exp.AnyValue: - this = 
self._parse_lambda() - is_max = None - having = None - - if self._match(TokenType.HAVING): - self._match_texts(("MAX", "MIN")) - is_max = self._prev.text == "MAX" - having = self._parse_column() - - return self.expression(exp.AnyValue, this=this, having=having, max=is_max) - def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression: this = self._parse_conjunction() @@ -4941,6 +4947,16 @@ class Parser(metaclass=_Parser): return self.expression(exp.RespectNulls, this=this) return this + def _parse_having_max(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + if self._match(TokenType.HAVING): + self._match_texts(("MAX", "MIN")) + max = self._prev.text.upper() != "MIN" + return self.expression( + exp.HavingMax, this=this, expression=self._parse_column(), max=max + ) + + return this + def _parse_window( self, this: t.Optional[exp.Expression], alias: bool = False ) -> t.Optional[exp.Expression]: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 13f72d6..1fd4025 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -106,19 +106,6 @@ class Schema(abc.ABC): name = column if isinstance(column, str) else column.name return name in self.column_names(table, dialect=dialect, normalize=normalize) - @abc.abstractmethod - def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: - """ - Returns the schema of a given table. - - Args: - table: the target table. - raise_on_missing: whether or not to raise in case the schema is not found. - - Returns: - The schema of the target table. - """ - @property @abc.abstractmethod def supported_table_args(self) -> t.Tuple[str, ...]: @@ -170,6 +157,16 @@ class AbstractMappingSchema: return [table.text(part) for part in exp.TABLE_PARTS if table.text(part)] def find(self, table: exp.Table, raise_on_missing: bool = True) -> t.Optional[t.Any]: + """ + Returns the schema of a given table. + + Args: + table: the target table. 
+ raise_on_missing: whether or not to raise in case the schema is not found. + + Returns: + The schema of the target table. + """ parts = self.table_parts(table)[0 : len(self.supported_table_args)] value, trie = in_trie(self.mapping_trie, parts) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 87a4924..b064957 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -504,6 +504,7 @@ class _Tokenizer(type): command_prefix_tokens={ _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS }, + heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER, ) token_types = RsTokenTypeSettings( bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING], @@ -517,6 +518,7 @@ class _Tokenizer(type): semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON], string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING], var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR], + heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[klass.HEREDOC_STRING_ALTERNATIVE], ) klass._RS_TOKENIZER = RsTokenizer(settings, token_types) else: @@ -573,6 +575,12 @@ class Tokenizer(metaclass=_Tokenizer): STRING_ESCAPES = ["'"] VAR_SINGLE_TOKENS: t.Set[str] = set() + # Whether or not the heredoc tags follow the same lexical rules as unquoted identifiers + HEREDOC_TAG_IS_IDENTIFIER = False + + # Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc + HEREDOC_STRING_ALTERNATIVE = TokenType.VAR + # Autofilled _COMMENTS: t.Dict[str, str] = {} _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} @@ -1249,6 +1257,18 @@ class Tokenizer(metaclass=_Tokenizer): elif token_type == TokenType.BIT_STRING: base = 2 elif token_type == TokenType.HEREDOC_STRING: + if ( + self.HEREDOC_TAG_IS_IDENTIFIER + and not self._peek.isidentifier() + and not self._peek == end + ): + if self.HEREDOC_STRING_ALTERNATIVE != token_type.VAR: + self._add(self.HEREDOC_STRING_ALTERNATIVE) + else: + self._scan_var() + + return True + self._advance() tag = "" if self._char == end else self._extract_string(end) end = 
f"{start}{tag}{end}" |